From 3ef46ec4b502efcc1422878bef7e1cbda4bc575a Mon Sep 17 00:00:00 2001 From: nuglifeleoji Date: Fri, 13 Feb 2026 11:55:54 -0800 Subject: [PATCH] Add Mind2Web evaluation task with two candidate pool sizes Add Mind2Web web navigation task for ACE framework: - mind2web: ~200 candidate elements per step (199 negative + positives) - mind2web2: ~50 candidate elements per step (49 negative + positives) Each task includes: - prepare_data.py: Downloads Mind2Web from HuggingFace, converts to step-level ACE samples with candidate element selection formulation, performs stratified train/val/test split by domain - data_processor.py: Three-level evaluation (element index + operation type + value matching) with flexible parsing - run.py: Standard ACE training/evaluation script with offline, online, and eval_only modes - data/sample_config.json: Data path configuration The two versions enable studying the effect of candidate pool size on ACE's context learning performance for web agent tasks. --- eval/mind2web/data/sample_config.json | 12 + eval/mind2web/data_processor.py | 227 +++++++++++++++++ eval/mind2web/prepare_data.py | 301 +++++++++++++++++++++++ eval/mind2web/run.py | 232 ++++++++++++++++++ eval/mind2web2/data/sample_config.json | 12 + eval/mind2web2/data_processor.py | 230 ++++++++++++++++++ eval/mind2web2/prepare_data.py | 321 +++++++++++++++++++++++++ eval/mind2web2/run.py | 230 ++++++++++++++++++ 8 files changed, 1565 insertions(+) create mode 100644 eval/mind2web/data/sample_config.json create mode 100644 eval/mind2web/data_processor.py create mode 100644 eval/mind2web/prepare_data.py create mode 100644 eval/mind2web/run.py create mode 100644 eval/mind2web2/data/sample_config.json create mode 100644 eval/mind2web2/data_processor.py create mode 100644 eval/mind2web2/prepare_data.py create mode 100644 eval/mind2web2/run.py diff --git a/eval/mind2web/data/sample_config.json b/eval/mind2web/data/sample_config.json new file mode 100644 index 0000000..0881dd6 --- /dev/null +++ b/eval/mind2web/data/sample_config.json @@ -0,0 +1,12 @@ +{ + "mind2web": { + "train_data": "./eval/mind2web/data/mind2web_train.jsonl", + "val_data": "./eval/mind2web/data/mind2web_val.jsonl", + "test_data": "./eval/mind2web/data/mind2web_test.jsonl" + }, + "mind2web_small": { + "train_data": "./eval/mind2web/data/mind2web_train_200.jsonl", + "val_data": "./eval/mind2web/data/mind2web_val.jsonl", + "test_data": "./eval/mind2web/data/mind2web_test.jsonl" + } +} diff --git a/eval/mind2web/data_processor.py b/eval/mind2web/data_processor.py new file mode 100644 index 0000000..95cfdb4 --- /dev/null +++ b/eval/mind2web/data_processor.py @@ -0,0 +1,227 @@ +""" +Data processor for Mind2Web web navigation task. + +Task: Given a webpage with ~200 candidate elements and a navigation task +description with action history, select the correct element and specify +the action (CLICK, TYPE, or SELECT with value). + +Evaluation: Three-level matching (element index + operation type + value). +""" +import os +import json +import re +from typing import List, Dict, Any + + +def load_data(data_path: str) -> List[Dict[str, Any]]: + """ + Load and process data from a JSONL file. 
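+
+    Illustrative input line (abbreviated; values are made up, the field set
+    follows the format documented under process_task_data below):
+        {"context": "Candidate elements on the current webpage: ...",
+         "question": "Task: ... Website: ...",
+         "target": "[7] SELECT [combobox] Reservation type: Pickup", ...}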
+ + Args: + data_path: Path to the JSONL file + + Returns: + List of dictionaries containing the data + """ + if not os.path.exists(data_path): + raise FileNotFoundError(f"Data file not found: {data_path}") + data = [] + with open(data_path, 'r', encoding='utf-8') as f: + for line in f: + line = line.strip() + if line: + data.append(json.loads(line)) + print(f"Loaded {len(data)} samples from {data_path}") + return data + + +class DataProcessor: + """ + Processor for Mind2Web web navigation task. + + Handles element selection + action prediction on web pages. + Each sample presents ~200 candidate elements and the model must select + the correct one and specify the action (CLICK/TYPE/SELECT). + """ + + def __init__(self, task_name: str): + """ + Initialize the data processor. + + Args: + task_name: The name of the task (e.g., 'mind2web') + """ + self.task_name = task_name + + def process_task_data(self, raw_data: List[Dict]) -> List[Dict]: + """ + Convert raw Mind2Web data into standardized format for ACE. + + Input format (from JSONL, produced by prepare_data.py): + { + "context": "Candidate elements on the current webpage:\n[0] ...\\n...", + "question": "Task: ... Website: ...\\nActions completed:\\n...\\nSelect...", + "target": "[7] SELECT [combobox] Reservation type: Pickup", + "annotation_id": "...", + "step_idx": 0, + "total_steps": 11, + "domain": "Travel", + "website": "exploretock", + "action_repr": "[combobox] Reservation type -> SELECT: Pickup", + "operation": {"op": "SELECT", "value": "Pickup"}, + "n_candidates": 50, + "correct_candidate_idx": 7 + } + + Output format (standardized for ACE): + { + "context": "", + "question": "", + "target": "[7] SELECT [combobox] Reservation type: Pickup", + "others": { ... metadata ... } + } + + Args: + raw_data: Raw data loaded from JSONL + + Returns: + List of dicts in standardized format + """ + processed_data = [] + + for item in raw_data: + processed_item = { + "context": item.get("context", ""), + "question": item.get("question", ""), + "target": item.get("target", ""), + "others": { + "annotation_id": item.get("annotation_id", ""), + "step_idx": item.get("step_idx", 0), + "total_steps": item.get("total_steps", 0), + "domain": item.get("domain", ""), + "website": item.get("website", ""), + "action_repr": item.get("action_repr", ""), + "operation": item.get("operation", {}), + "n_candidates": item.get("n_candidates", 0), + "correct_candidate_idx": item.get("correct_candidate_idx", -1), + "task": self.task_name + } + } + processed_data.append(processed_item) + + return processed_data + + def _parse_prediction(self, text: str) -> dict: + """ + Parse a model prediction to extract element index, operation, and value. + + Expected format: "[idx] OP [tag] element_text: value" + But models may produce variations, so we use flexible parsing. 
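+
+        Illustrative example (made-up prediction string):
+            "[7] SELECT [combobox] Reservation type: Pickup"
+            -> {"element_idx": 7, "op": "SELECT", "value": "Pickup"}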
+ + Returns: + dict with keys: 'element_idx', 'op', 'value' (any may be None) + """ + result = {"element_idx": None, "op": None, "value": None} + text = text.strip() + + # Try to extract element index: [number] + idx_match = re.search(r'\[(\d+)\]', text) + if idx_match: + result["element_idx"] = int(idx_match.group(1)) + + # Try to extract operation type + for op in ["SELECT", "TYPE", "CLICK"]: + if op in text.upper(): + result["op"] = op + break + + # Try to extract value (after the last colon for TYPE/SELECT) + if result["op"] in ["SELECT", "TYPE"]: + # Look for ": value" pattern + value_match = re.search(r':\s*(.+?)$', text) + if value_match: + result["value"] = value_match.group(1).strip() + + return result + + def answer_is_correct(self, predicted: str, ground_truth: str) -> bool: + """ + Check if prediction matches ground truth. + + Checks three components with decreasing strictness: + 1. Element index must match (most important) + 2. Operation type must match + 3. Value must match for TYPE/SELECT (flexible matching) + + Args: + predicted: Model's answer + ground_truth: Ground truth answer + + Returns: + bool: True if answer is correct + """ + pred = self._parse_prediction(predicted) + truth = self._parse_prediction(ground_truth) + + # Element index must match + if pred["element_idx"] is None or pred["element_idx"] != truth["element_idx"]: + return False + + # Operation type must match + if pred["op"] is None or pred["op"] != truth["op"]: + return False + + # For CLICK, element + op match is sufficient + if truth["op"] == "CLICK": + return True + + # For TYPE/SELECT, value must also match + if truth["value"] is None: + return True # No value to check + + if pred["value"] is None: + return False + + # Flexible value comparison (case-insensitive, strip whitespace) + return pred["value"].strip().lower() == truth["value"].strip().lower() + + def evaluate_accuracy(self, out: List[str], target: List[str]) -> float: + """ + Calculate accuracy across multiple predictions. + + Also reports sub-metrics: + - Element selection accuracy + - Operation type accuracy + - Full match accuracy (element + op + value) + + Args: + out: List of model predictions + target: List of ground truth targets + + Returns: + Accuracy as float between 0 and 1 + """ + if len(out) != len(target): + raise ValueError("Predictions and ground truths must have the same length.") + + correct_count = 0 + elem_correct = 0 + op_correct = 0 + + for predicted, ground_truth in zip(out, target): + pred = self._parse_prediction(predicted) + truth = self._parse_prediction(ground_truth) + + if pred["element_idx"] == truth["element_idx"] and truth["element_idx"] is not None: + elem_correct += 1 + if pred["op"] == truth["op"] and truth["op"] is not None: + op_correct += 1 + if self.answer_is_correct(predicted, ground_truth): + correct_count += 1 + + n = len(out) if out else 1 + print(f" Element selection accuracy: {elem_correct}/{n} = {elem_correct/n:.3f}") + print(f" Operation type accuracy: {op_correct}/{n} = {op_correct/n:.3f}") + print(f" Full match accuracy: {correct_count}/{n} = {correct_count/n:.3f}") + + accuracy = correct_count / n + return accuracy diff --git a/eval/mind2web2/prepare_data.py b/eval/mind2web2/prepare_data.py new file mode 100644 index 0000000..2e6f38a --- /dev/null +++ b/eval/mind2web2/prepare_data.py @@ -0,0 +1,321 @@ +#!/usr/bin/env python3 +""" +Prepare Mind2Web data for ACE framework (50-candidate version). 
+ +Downloads the Mind2Web dataset from HuggingFace and converts it into +step-level ACE samples with candidate element selection formulation. + +Each web navigation task has multiple steps. Each step becomes one ACE sample: +- context: Compact list of ~50 candidate elements (tag + text + key attributes) +- question: Task description + previous action history +- target: Correct action representation (OP + element info + value) + +Usage: + python -m eval.mind2web2.prepare_data +""" +import os +import re +import json +import random +from collections import Counter + +# Number of negative candidates to sample per step (+ all positives) +# 49 negatives + 1 positive = ~50 total candidates +MAX_NEG_CANDIDATES = 49 +# Random seed for reproducibility +SEED = 42 + +OUTPUT_DIR = "./eval/mind2web2/data" + + +def extract_element_text(html: str, backend_node_id: str, max_chars: int = 200) -> str: + """ + Extract visible text content for an element identified by backend_node_id + from the cleaned HTML. + + Searches for the element tag, then finds nodes in the ~500 chars + following it. + """ + pattern = f'<\\w+\\s+backend_node_id="{backend_node_id}"[^>]*>' + match = re.search(pattern, html) + if not match: + return "" + + start = match.start() + snippet = html[start:start + 600] + + # Extract text node contents within this snippet + texts = re.findall(r'([^<]*)', snippet) + text_content = " ".join(t.strip() for t in texts if t.strip()) + + if len(text_content) > max_chars: + text_content = text_content[:max_chars] + "..." + + return text_content + + +def get_candidate_repr(candidate: dict, html: str, idx: int) -> str: + """ + Create a compact text representation of a candidate element. + + Format: [idx] "text_content" (id=..., name=..., ...) + """ + tag = candidate["tag"] + backend_id = candidate["backend_node_id"] + + # Extract text from HTML + text = extract_element_text(html, backend_id) + + # Get useful attributes + try: + attrs = json.loads(candidate["attributes"]) + except (json.JSONDecodeError, TypeError): + attrs = {} + + useful_attrs = {} + for key in ["id", "name", "aria-label", "placeholder", "alt", "title", + "type", "role", "href", "value"]: + if key in attrs: + val = str(attrs[key])[:80] + useful_attrs[key] = val + + # Build representation + parts = [f"[{idx}] <{tag}>"] + if text: + parts.append(f'"{text}"') + if useful_attrs: + attr_str = ", ".join(f'{k}="{v}"' for k, v in useful_attrs.items()) + parts.append(f"({attr_str})") + + return " ".join(parts) + + +def build_target(action_repr: str, correct_idx: int, operation: dict) -> str: + """ + Build the target answer string. + + Format: [idx] OP element_description: value + e.g.: [3] SELECT [combobox] Reservation type: Pickup + """ + op = operation["op"] + value = operation.get("value", "") + + # The action_repr format is: [tag] text -> OP: value + # Extract the element description from action_repr + # e.g. "[combobox] Reservation type -> SELECT: Pickup" + elem_desc = action_repr.split(" -> ")[0].strip() if " -> " in action_repr else action_repr + + if value: + return f"[{correct_idx}] {op} {elem_desc}: {value}" + else: + return f"[{correct_idx}] {op} {elem_desc}" + + +def process_step(task: dict, step_idx: int, rng: random.Random) -> dict: + """ + Convert a single step within a task into an ACE-format sample. + + Returns None if the step can't be processed (e.g., no pos_candidates). 
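+
+    Illustrative return value (abbreviated; values are made up, keys match the
+    dict constructed at the end of this function):
+        {"context": "Candidate elements on the current webpage: ...",
+         "question": "Task: ...",
+         "target": "[3] SELECT [combobox] Reservation type: Pickup",
+         "n_candidates": 50, "correct_candidate_idx": 3, ...}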
+ """ + action = task["actions"][step_idx] + action_repr = task["action_reprs"][step_idx] + html = action["cleaned_html"] + operation = action["operation"] + + pos_candidates = action["pos_candidates"] + neg_candidates = action["neg_candidates"] + + if not pos_candidates: + return None + + # Sample negative candidates + n_neg = min(MAX_NEG_CANDIDATES, len(neg_candidates)) + sampled_neg = rng.sample(neg_candidates, n_neg) if n_neg > 0 else [] + + # Combine and shuffle candidates + all_candidates = [] + + for pc in pos_candidates: + all_candidates.append(("pos", pc)) + for nc in sampled_neg: + all_candidates.append(("neg", nc)) + + rng.shuffle(all_candidates) + + # Build candidate list and find correct index + candidate_reprs = [] + correct_idx = -1 + for i, (label, cand) in enumerate(all_candidates): + repr_str = get_candidate_repr(cand, html, i) + candidate_reprs.append(repr_str) + if label == "pos" and correct_idx == -1: + correct_idx = i + + if correct_idx == -1: + return None + + # Build context (candidate list) + context = "Candidate elements on the current webpage:\n" + "\n".join(candidate_reprs) + + # Build question (task + history) + question_parts = [f"Task: {task['confirmed_task']}"] + question_parts.append(f"Website: {task['website']} (Domain: {task['domain']})") + + if step_idx > 0: + question_parts.append("\nActions completed so far:") + for j in range(step_idx): + question_parts.append(f" Step {j+1}: {task['action_reprs'][j]}") + + question_parts.append( + "\nFrom the candidate elements listed in the context, " + "select the correct element index and specify the action " + "(CLICK, TYPE, or SELECT with value if applicable).\n" + "Answer format: [element_index] ACTION_TYPE [element_tag] element_text: value" + ) + + question = "\n".join(question_parts) + + # Build target + target = build_target(action_repr, correct_idx, operation) + + return { + "context": context, + "question": question, + "target": target, + "annotation_id": task["annotation_id"], + "step_idx": step_idx, + "total_steps": len(task["actions"]), + "domain": task["domain"], + "website": task["website"], + "action_repr": action_repr, + "operation": operation, + "n_candidates": len(all_candidates), + "correct_candidate_idx": correct_idx + } + + +def main(): + from datasets import load_dataset + + print("=" * 60) + print("Mind2Web Data Preparation for ACE (50-candidate version)") + print("=" * 60) + + # Load dataset + print("\nLoading Mind2Web dataset from HuggingFace...") + ds = load_dataset("osunlp/Mind2Web", split="train") + print(f"Loaded {len(ds)} tasks") + + rng = random.Random(SEED) + + # Convert all steps to ACE samples + print("\nConverting steps to ACE samples...") + all_samples = [] + skipped = 0 + + for task_idx in range(len(ds)): + task = ds[task_idx] + n_steps = len(task["actions"]) + + for step_idx in range(n_steps): + sample = process_step(task, step_idx, rng) + if sample: + all_samples.append(sample) + else: + skipped += 1 + + if (task_idx + 1) % 100 == 0: + print(f" Processed {task_idx + 1}/{len(ds)} tasks, " + f"{len(all_samples)} samples so far...") + + print(f"\nTotal samples: {len(all_samples)} (skipped {skipped} steps with no pos_candidates)") + + # Domain statistics + domain_counts = Counter(s["domain"] for s in all_samples) + print(f"\nDomain distribution:") + for domain, cnt in domain_counts.most_common(): + print(f" {domain}: {cnt} samples") + + # Split by task annotation_id (stratified by domain) + # Group tasks by domain + task_ids_by_domain = {} + for task in ds: + domain = 
task["domain"] + if domain not in task_ids_by_domain: + task_ids_by_domain[domain] = [] + task_ids_by_domain[domain].append(task["annotation_id"]) + + train_task_ids = set() + val_task_ids = set() + test_task_ids = set() + + for domain, task_ids in task_ids_by_domain.items(): + rng.shuffle(task_ids) + n = len(task_ids) + n_train = int(n * 0.6) + n_val = int(n * 0.15) + + train_task_ids.update(task_ids[:n_train]) + val_task_ids.update(task_ids[n_train:n_train + n_val]) + test_task_ids.update(task_ids[n_train + n_val:]) + + # Split samples + train_samples = [s for s in all_samples if s["annotation_id"] in train_task_ids] + val_samples = [s for s in all_samples if s["annotation_id"] in val_task_ids] + test_samples = [s for s in all_samples if s["annotation_id"] in test_task_ids] + + print(f"\n=== Data Split (by task, stratified by domain) ===") + print(f"Train: {len(train_samples)} samples from {len(train_task_ids)} tasks") + print(f"Val: {len(val_samples)} samples from {len(val_task_ids)} tasks") + print(f"Test: {len(test_samples)} samples from {len(test_task_ids)} tasks") + + # Save to JSONL + os.makedirs(OUTPUT_DIR, exist_ok=True) + + for name, samples in [("train", train_samples), ("val", val_samples), ("test", test_samples)]: + path = os.path.join(OUTPUT_DIR, f"mind2web2_{name}.jsonl") + with open(path, "w", encoding="utf-8") as f: + for s in samples: + f.write(json.dumps(s, ensure_ascii=True) + "\n") + print(f"Saved {len(samples)} samples to {path}") + + # Also save a smaller train subset for quick experiments + # Take first 200 train samples (roughly ~30 tasks) + train_small = train_samples[:200] + path = os.path.join(OUTPUT_DIR, "mind2web2_train_200.jsonl") + with open(path, "w", encoding="utf-8") as f: + for s in train_small: + f.write(json.dumps(s, ensure_ascii=True) + "\n") + print(f"Saved {len(train_small)} samples to {path} (small train set)") + + # Print sample + print(f"\n=== Example Sample ===") + ex = train_samples[0] + print(f"Context (first 500 chars):\n{ex['context'][:500]}\n...") + print(f"\nQuestion:\n{ex['question']}") + print(f"\nTarget: {ex['target']}") + print(f"\nMetadata: domain={ex['domain']}, website={ex['website']}, " + f"step={ex['step_idx']+1}/{ex['total_steps']}, " + f"n_candidates={ex['n_candidates']}") + + # Save sample_config.json + config = { + "mind2web2": { + "train_data": "./eval/mind2web2/data/mind2web2_train.jsonl", + "val_data": "./eval/mind2web2/data/mind2web2_val.jsonl", + "test_data": "./eval/mind2web2/data/mind2web2_test.jsonl" + }, + "mind2web2_small": { + "train_data": "./eval/mind2web2/data/mind2web2_train_200.jsonl", + "val_data": "./eval/mind2web2/data/mind2web2_val.jsonl", + "test_data": "./eval/mind2web2/data/mind2web2_test.jsonl" + } + } + config_path = os.path.join(OUTPUT_DIR, "sample_config.json") + with open(config_path, "w") as f: + json.dump(config, f, indent=4) + print(f"\nSaved config to {config_path}") + + +if __name__ == "__main__": + main() diff --git a/eval/mind2web2/run.py b/eval/mind2web2/run.py new file mode 100644 index 0000000..99ed07c --- /dev/null +++ b/eval/mind2web2/run.py @@ -0,0 +1,230 @@ +import os +import json +import argparse +from .data_processor import DataProcessor, load_data + +from ace import ACE + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser(description='ACE System - Mind2Web (50 candidates)') + + # Task configuration + parser.add_argument("--task_name", type=str, required=True, + help="Name of the task (e.g., 'mind2web2', 'mind2web2_small')") + 
parser.add_argument("--initial_playbook_path", type=str, default=None, + help="Path to initial playbook (optional)") + parser.add_argument("--mode", type=str, default="offline", + choices=["offline", "online", "eval_only"], + help="Run mode: 'offline' for offline training with validation, " + "'online' for online training and testing on test split, " + "'eval_only' for testing only with provided playbook") + + # Model configuration + parser.add_argument("--api_provider", type=str, default="sambanova", + choices=["sambanova", "together", "openai"], help="API provider") + parser.add_argument("--generator_model", type=str, + default="DeepSeek-V3.1", + help="Model for generator") + parser.add_argument("--reflector_model", type=str, + default="DeepSeek-V3.1", + help="Model for reflector") + parser.add_argument("--curator_model", type=str, + default="DeepSeek-V3.1", + help="Model for curator") + + # Training configuration + parser.add_argument("--num_epochs", type=int, default=1, + help="Number of training epochs") + parser.add_argument("--max_num_rounds", type=int, default=3, + help="Max reflection rounds for incorrect answers") + parser.add_argument("--curator_frequency", type=int, default=1, + help="Run curator every N steps") + parser.add_argument("--eval_steps", type=int, default=100, + help="Evaluate every N steps") + parser.add_argument("--online_eval_frequency", type=int, default=15, + help="Update playbook every N samples for evaluation in online mode") + parser.add_argument("--save_steps", type=int, default=50, + help="Save intermediate playbooks every N steps") + + # System configuration + parser.add_argument("--max_tokens", type=int, default=4096, + help="Max tokens for LLM responses") + parser.add_argument("--playbook_token_budget", type=int, default=80000, + help="Total token budget for playbook") + parser.add_argument("--test_workers", type=int, default=20, + help="Number of parallel workers for testing") + + # Prompt configuration + parser.add_argument("--json_mode", action="store_true", + help="Enable JSON mode for LLM calls") + parser.add_argument("--no_ground_truth", action="store_true", + help="Don't use ground truth in reflection") + + # Bulletpoint analyzer configuration + parser.add_argument("--use_bulletpoint_analyzer", action="store_true", + help="Enable bulletpoint analyzer for deduplication and merging") + parser.add_argument("--bulletpoint_analyzer_threshold", type=float, default=0.90, + help="Similarity threshold for bulletpoint analyzer (0-1, default: 0.90)") + + # Skip initial test evaluation (useful when you already have baseline results) + parser.add_argument("--skip_initial_test", action="store_true", + help="Skip initial test evaluation in offline mode to save time") + + # Output configuration + parser.add_argument("--save_path", type=str, required=True, + help="Directory to save results") + + return parser.parse_args() + + +def preprocess_data(task_name, config, mode): + """ + Load and preprocess data for the specified task. 
+ + Args: + task_name: Name of the task + config: Configuration dictionary with data paths + mode: Run mode ('offline', 'online', or 'eval_only') + + Returns: + Tuple of (train_samples, val_samples, test_samples, data_processor) + """ + processor = DataProcessor(task_name=task_name) + + # For online and eval_only modes, only load test data + if mode in ["online", "eval_only"]: + train_samples = None + val_samples = None + + if "test_data" in config: + test_samples = load_data(config["test_data"]) + test_samples = processor.process_task_data(test_samples) + else: + raise ValueError(f"{mode} mode requires test data in config.") + + if mode == "online": + print(f"Online mode: Training and testing on {len(test_samples)} examples") + else: + print(f"Eval only mode: Testing on {len(test_samples)} examples") + + # For offline mode, load train, val, and optionally test data + else: + train_samples = load_data(config["train_data"]) + val_samples = load_data(config["val_data"]) + train_samples = processor.process_task_data(train_samples) + val_samples = processor.process_task_data(val_samples) + + if "test_data" in config: + test_samples = load_data(config["test_data"]) + test_samples = processor.process_task_data(test_samples) + else: + test_samples = [] + + print(f"Offline mode: Training on {len(train_samples)} examples, " + f"validating on {len(val_samples)}, testing on {len(test_samples)}") + + return train_samples, val_samples, test_samples, processor + + +def load_initial_playbook(path): + """Load initial playbook if provided.""" + if path and os.path.exists(path): + with open(path, 'r') as f: + return f.read() + return None + + +def main(): + """Main execution function.""" + args = parse_args() + + print(f"\n{'='*60}") + print(f"ACE SYSTEM - Mind2Web (50 candidates)") + print(f"{'='*60}") + print(f"Task: {args.task_name}") + print(f"Mode: {args.mode.upper().replace('_', ' ')}") + print(f"Generator Model: {args.generator_model}") + print(f"{'='*60}\n") + + # Load data + with open("./eval/mind2web2/data/sample_config.json", 'r') as f: + task_config = json.load(f) + + if args.task_name not in task_config: + raise ValueError(f"Unknown task: {args.task_name}. 
" + f"Available: {list(task_config.keys())}") + + train_samples, val_samples, test_samples, data_processor = preprocess_data( + args.task_name, + task_config[args.task_name], + args.mode + ) + + # Load initial playbook (or use empty if None provided) + initial_playbook = load_initial_playbook(args.initial_playbook_path) + if initial_playbook: + print(f"Loaded initial playbook from {args.initial_playbook_path}\n") + else: + print("Using empty playbook as initial playbook\n") + + # Create ACE system + ace_system = ACE( + api_provider=args.api_provider, + generator_model=args.generator_model, + reflector_model=args.reflector_model, + curator_model=args.curator_model, + max_tokens=args.max_tokens, + initial_playbook=initial_playbook, + use_bulletpoint_analyzer=args.use_bulletpoint_analyzer, + bulletpoint_analyzer_threshold=args.bulletpoint_analyzer_threshold + ) + + # Prepare configuration + config = { + 'num_epochs': args.num_epochs, + 'max_num_rounds': args.max_num_rounds, + 'curator_frequency': args.curator_frequency, + 'eval_steps': args.eval_steps, + 'online_eval_frequency': args.online_eval_frequency, + 'save_steps': args.save_steps, + 'playbook_token_budget': args.playbook_token_budget, + 'task_name': args.task_name, + 'mode': args.mode, + 'json_mode': args.json_mode, + 'no_ground_truth': args.no_ground_truth, + 'save_dir': args.save_path, + 'test_workers': args.test_workers, + 'initial_playbook_path': args.initial_playbook_path, + 'use_bulletpoint_analyzer': args.use_bulletpoint_analyzer, + 'bulletpoint_analyzer_threshold': args.bulletpoint_analyzer_threshold, + 'api_provider': args.api_provider + } + + # If skip_initial_test, don't pass test_samples during offline training + run_test_samples = test_samples + if args.mode == "offline" and args.skip_initial_test: + print("Skipping test evaluation (--skip_initial_test)\n") + run_test_samples = None + + # Execute using the unified run method + try: + results = ace_system.run( + mode=args.mode, + train_samples=train_samples, + val_samples=val_samples, + test_samples=run_test_samples, + data_processor=data_processor, + config=config + ) + except UnboundLocalError as e: + print(f"\nError: {e}. This likely means all samples failed to evaluate.") + print("Check the logs for details on individual sample failures.") + results = {"accuracy": 0.0, "correct": 0, "total": 0} + + print(f"\nFinal results: {results}") + + +if __name__ == "__main__": + main()