diff --git a/.env.example b/.env.example index c053773..47180e1 100644 --- a/.env.example +++ b/.env.example @@ -22,6 +22,24 @@ PARALLEL_PHASE2=true BATCH_SIZE=15 MAX_PAGES=5000 +# MCTS — Adaptive +MCTS_ADAPTIVE=true +MCTS_MIN_ITERATIONS=8 +MCTS_MAX_ITERATIONS=50 +MCTS_META_MIN_ITERATIONS=5 +MCTS_META_MAX_ITERATIONS=30 +MCTS_CONVERGENCE_WINDOW=4 +MCTS_CONVERGENCE_VARIANCE_THRESHOLD=0.01 +MCTS_TOP_K_STABLE_ROUNDS=3 +MCTS_PRUNING=true +MCTS_PRUNING_MIN_VISITS=3 +MCTS_PRUNING_REWARD_THRESHOLD=0.25 +MCTS_ADAPTIVE_EXPLORATION=true +MCTS_EXPLORATION_START=2.0 +MCTS_EXPLORATION_END=0.5 +MCTS_EXPLORATION_DECAY=linear +MCTS_SIMULATION_BATCH_SIZE=4 + # Storage TREERAG_DATA_DIR=.treerag_data AUTO_REINDEX=true diff --git a/tests/test_config.py b/tests/test_config.py index 008160b..7bec505 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -32,6 +32,32 @@ def test_defaults(self): assert config.top_k_documents == 3 assert config.parallel_phase2 is True + def test_adaptive_defaults(self): + config = MCTSConfig() + assert config.adaptive is True + assert config.min_iterations == 8 + assert config.max_iterations == 50 + assert config.meta_min_iterations == 5 + assert config.meta_max_iterations == 30 + assert config.convergence_window == 4 + assert config.convergence_variance_threshold == 0.01 + assert config.top_k_stable_rounds == 3 + assert config.pruning_enabled is True + assert config.pruning_min_visits == 3 + assert config.pruning_reward_threshold == 0.25 + assert config.adaptive_exploration is True + assert config.exploration_start == 2.0 + assert config.exploration_end == 0.5 + assert config.exploration_decay == "linear" + assert config.simulation_batch_size == 4 + + def test_adaptive_disabled(self): + config = MCTSConfig(adaptive=False) + assert config.adaptive is False + # Other adaptive defaults still set + assert config.min_iterations == 8 + assert config.max_iterations == 50 + class TestIndexerConfig: def test_defaults(self): @@ -71,3 +97,20 @@ def test_from_env_custom(self): assert config.mcts.iterations == 40 assert config.indexer.batch_size == 10 assert config.folder.base_dir == "/tmp/test_data" + + def test_from_env_adaptive(self): + env = { + "MCTS_ADAPTIVE": "false", + "MCTS_MIN_ITERATIONS": "10", + "MCTS_MAX_ITERATIONS": "60", + "MCTS_PRUNING": "false", + "MCTS_EXPLORATION_START": "1.8", + } + with patch.dict(os.environ, env): + config = TreeRAGConfig.from_env() + assert config.mcts.adaptive is False + assert config.mcts.min_iterations == 10 + assert config.mcts.max_iterations == 60 + assert config.mcts.pruning_enabled is False + assert config.mcts.exploration_start == 1.8 + assert config.mcts.simulation_batch_size == 4 # default unchanged diff --git a/tests/test_models.py b/tests/test_models.py index fa3e40c..820d7ea 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -4,7 +4,7 @@ import tempfile from pathlib import Path -from treerag.models import TreeNode, DocumentIndex +from treerag.models import TreeNode, DocumentIndex, SearchStats class TestTreeNode: @@ -93,6 +93,23 @@ def test_reset_mcts_state(self): assert child.visit_count == 0 assert child.total_reward == 0.0 + def test_pruned_default(self): + node = TreeNode(node_id="0001", title="Test", summary="Test", + start_page=0, end_page=5) + assert node.pruned is False + + def test_pruned_resets(self): + node = TreeNode(node_id="0000", title="Root", summary="Root", + start_page=0, end_page=10) + child = TreeNode(node_id="0001", title="Child", summary="Child", + start_page=0, end_page=5, parent=node) + node.children.append(child) + node.pruned = True + child.pruned = True + node.reset_mcts_state() + assert node.pruned is False + assert child.pruned is False + class TestDocumentIndex: def test_save_and_load(self): @@ -115,3 +132,32 @@ def test_save_and_load(self): assert loaded.root.title == "Doc" finally: tmp_path.unlink() + + +class TestSearchStats: + def test_defaults(self): + stats = SearchStats() + assert stats.iterations_used == 0 + assert stats.converged is False + assert stats.convergence_reason == "" + assert stats.pruned_branches == 0 + assert stats.coverage_pct == 0.0 + + def test_to_dict(self): + stats = SearchStats( + iterations_used=12, iterations_max=50, + converged=True, convergence_reason="top_k_stable", + convergence_iteration=11, + total_nodes=47, visited_nodes=38, + coverage_pct=80.85, + pruned_branches=3, + mean_reward=0.7234, reward_variance=0.0456, + ) + d = stats.to_dict() + assert d["iterations_used"] == 12 + assert d["converged"] is True + assert d["convergence_reason"] == "top_k_stable" + assert d["coverage_pct"] == 80.8 # rounded + assert d["pruned_branches"] == 3 + assert d["mean_reward"] == 0.7234 + assert d["reward_variance"] == 0.0456 diff --git a/treerag/config.py b/treerag/config.py index 3ce3219..e933b92 100644 --- a/treerag/config.py +++ b/treerag/config.py @@ -32,6 +32,32 @@ class MCTSConfig: max_depth: int = 10 parallel_phase2: bool = True + # Adaptive MCTS + adaptive: bool = True + min_iterations: int = 8 + max_iterations: int = 50 + meta_min_iterations: int = 5 + meta_max_iterations: int = 30 + + # Convergence detection + convergence_window: int = 4 + convergence_variance_threshold: float = 0.01 + top_k_stable_rounds: int = 3 + + # Branch pruning + pruning_enabled: bool = True + pruning_min_visits: int = 3 + pruning_reward_threshold: float = 0.25 + + # Adaptive exploration (C decay) + adaptive_exploration: bool = True + exploration_start: float = 2.0 + exploration_end: float = 0.5 + exploration_decay: str = "linear" + + # Batch simulation (parallel LLM calls per iteration) + simulation_batch_size: int = 4 + @dataclass class IndexerConfig: @@ -80,6 +106,22 @@ def from_env(cls) -> "TreeRAGConfig": confidence_threshold=float(os.getenv("CONFIDENCE_THRESHOLD", "0.7")), top_k_documents=int(os.getenv("TOP_K_DOCUMENTS", "3")), parallel_phase2=os.getenv("PARALLEL_PHASE2", "true").lower() == "true", + adaptive=os.getenv("MCTS_ADAPTIVE", "true").lower() == "true", + min_iterations=int(os.getenv("MCTS_MIN_ITERATIONS", "8")), + max_iterations=int(os.getenv("MCTS_MAX_ITERATIONS", "50")), + meta_min_iterations=int(os.getenv("MCTS_META_MIN_ITERATIONS", "5")), + meta_max_iterations=int(os.getenv("MCTS_META_MAX_ITERATIONS", "30")), + convergence_window=int(os.getenv("MCTS_CONVERGENCE_WINDOW", "4")), + convergence_variance_threshold=float(os.getenv("MCTS_CONVERGENCE_VARIANCE_THRESHOLD", "0.01")), + top_k_stable_rounds=int(os.getenv("MCTS_TOP_K_STABLE_ROUNDS", "3")), + pruning_enabled=os.getenv("MCTS_PRUNING", "true").lower() == "true", + pruning_min_visits=int(os.getenv("MCTS_PRUNING_MIN_VISITS", "3")), + pruning_reward_threshold=float(os.getenv("MCTS_PRUNING_REWARD_THRESHOLD", "0.25")), + adaptive_exploration=os.getenv("MCTS_ADAPTIVE_EXPLORATION", "true").lower() == "true", + exploration_start=float(os.getenv("MCTS_EXPLORATION_START", "2.0")), + exploration_end=float(os.getenv("MCTS_EXPLORATION_END", "0.5")), + exploration_decay=os.getenv("MCTS_EXPLORATION_DECAY", "linear"), + simulation_batch_size=int(os.getenv("MCTS_SIMULATION_BATCH_SIZE", "4")), ), indexer=IndexerConfig( batch_size=int(os.getenv("BATCH_SIZE", "15")), diff --git a/treerag/mcts.py b/treerag/mcts.py index a21fd1f..8c9826a 100644 --- a/treerag/mcts.py +++ b/treerag/mcts.py @@ -1,18 +1,25 @@ """ -Two-Phase MCTS Engine. +Two-Phase MCTS Engine with Adaptive Search. Phase 1: Score document summaries → pick top-K docs Phase 2: Search within selected docs → find sections (parallel) + +Adaptive features: +- Dynamic iteration bounds (min/max instead of fixed) +- Multi-signal convergence detection (top-k stability, variance stability, confidence) +- Branch pruning (skip consistently low-scoring subtrees) +- Exploration constant decay (high C → low C over search lifetime) """ import json import math import random +import statistics import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Optional from rich.console import Console from .config import MCTSConfig -from .models import TreeNode, DocumentIndex, FolderIndex, FolderDocEntry, SearchResult, DocumentScore +from .models import TreeNode, DocumentIndex, FolderIndex, FolderDocEntry, SearchResult, DocumentScore, SearchStats from .llm_client import LLMClient console = Console() @@ -38,35 +45,249 @@ def __init__(self, config: MCTSConfig, llm: LLMClient, search_model: str = "gpt- self.llm = llm self.search_model = search_model + # ========================================================================= + # Adaptive helpers + # ========================================================================= + + def _get_exploration_constant(self, iteration: int, max_iterations: int) -> float: + """Compute decaying exploration constant for this iteration.""" + if not self.config.adaptive or not self.config.adaptive_exploration: + return self.config.exploration_constant + + progress = iteration / max(max_iterations - 1, 1) + c_start = self.config.exploration_start + c_end = self.config.exploration_end + + if self.config.exploration_decay == "cosine": + return c_end + (c_start - c_end) * 0.5 * (1 + math.cos(math.pi * progress)) + else: + return c_start + (c_end - c_start) * progress + + def _check_convergence(self, root: TreeNode, iteration: int, tracker: dict) -> tuple[bool, str]: + """Multi-signal convergence detection. + + Signals (any one triggers stop): + 1. Top-k node IDs unchanged for top_k_stable_rounds consecutive iterations + 2. Reward variance stabilized across convergence_window + 3. Original confidence check (3+ leaves with high scores) + """ + if not self.config.adaptive: + if self._should_stop_early_legacy(root): + return True, "confidence_met" + return False, "" + + min_iters = tracker.get("min_iterations", self.config.min_iterations) + if iteration < min_iters: + return False, "" + + # Collect visited, non-pruned nodes + all_nodes = [] + def _collect_visited(n): + if n.visit_count > 0 and not n.pruned: + all_nodes.append(n) + for c in n.children: + if not c.pruned: + _collect_visited(c) + _collect_visited(root) + + # Signal 1: Top-K stability + all_nodes.sort(key=lambda n: n.average_reward * math.log(n.visit_count + 1), reverse=True) + current_top_k = tuple(n.node_id for n in all_nodes[:self.config.top_k_results]) + + prev_top_k = tracker.get("prev_top_k") + if prev_top_k == current_top_k: + tracker["top_k_stable_count"] = tracker.get("top_k_stable_count", 0) + 1 + else: + tracker["top_k_stable_count"] = 0 + tracker["prev_top_k"] = current_top_k + + if tracker["top_k_stable_count"] >= self.config.top_k_stable_rounds: + return True, "top_k_stable" + + # Signal 2: Variance stability + rewards = [n.average_reward for n in all_nodes if n.visit_count > 0] + if len(rewards) >= 2: + current_variance = statistics.variance(rewards) + variance_history = tracker.setdefault("variance_history", []) + variance_history.append(current_variance) + + window = self.config.convergence_window + if len(variance_history) >= window * 2: + old_window = variance_history[-(window * 2):-window] + new_window = variance_history[-window:] + old_mean = sum(old_window) / len(old_window) + new_mean = sum(new_window) / len(new_window) + if abs(new_mean - old_mean) < self.config.convergence_variance_threshold: + return True, "variance_stable" + + # Signal 3: Original confidence check + leaves = self._get_all_leaves(root) + confident = [l for l in leaves if l.visit_count >= 3 + and l.average_reward >= self.config.confidence_threshold + and not l.pruned] + if len(confident) >= self.config.top_k_results: + return True, "confidence_met" + + return False, "" + + def _apply_virtual_loss(self, node: TreeNode): + """Add a fake visit with 0 reward to discourage re-selection in the same batch.""" + current = node + while current is not None: + current.visit_count += 1 + current = current.parent + + def _remove_virtual_loss(self, node: TreeNode): + """Remove the fake visit added by _apply_virtual_loss.""" + current = node + while current is not None: + current.visit_count -= 1 + current = current.parent + + def _simulate_batch(self, query: str, targets: list[TreeNode]) -> list[float]: + """Run multiple simulations in parallel using threads.""" + if len(targets) == 1: + try: + return [self._simulate_section(query, targets[0])] + except Exception: + return [0.3] + + results = [0.3] * len(targets) + with ThreadPoolExecutor(max_workers=len(targets)) as executor: + futures = {executor.submit(self._simulate_section, query, t): i for i, t in enumerate(targets)} + for future in as_completed(futures): + idx = futures[future] + try: + results[idx] = future.result(timeout=30) + except Exception: + results[idx] = 0.3 + return results + + def _prune_branches(self, root: TreeNode, iteration: int, stats: SearchStats): + """Mark consistently low-scoring internal nodes as pruned.""" + if not self.config.adaptive or not self.config.pruning_enabled: + return + + def _prune_recursive(node: TreeNode): + if node.pruned: + return + if (node.visit_count >= self.config.pruning_min_visits + and node.average_reward < self.config.pruning_reward_threshold + and not node.is_leaf): + node.pruned = True + stats.pruned_branches += 1 + stats.pruned_at_iterations.append(iteration) + # Mark entire subtree + def _mark_subtree(n): + n.pruned = True + for c in n.children: + _mark_subtree(c) + for c in node.children: + _mark_subtree(c) + return + for c in node.children: + _prune_recursive(c) + + _prune_recursive(root) + # ========================================================================= # Phase 1: Meta-tree search # ========================================================================= - def search_meta(self, query: str, folder_index: FolderIndex, verbose: bool = True) -> list[DocumentScore]: + def search_meta(self, query: str, folder_index: FolderIndex, verbose: bool = True) -> tuple[list[DocumentScore], SearchStats]: if not folder_index.documents: - return [] + return [], SearchStats() folder_index.reset_mcts_state() total_visits = 0 - iterations = min(self.config.meta_iterations, len(folder_index.documents) * 5) + num_docs = len(folder_index.documents) + + if self.config.adaptive: + min_iter = min(self.config.meta_min_iterations, num_docs * 2) + max_iter = min(self.config.meta_max_iterations, num_docs * 5) + else: + max_iter = min(self.config.meta_iterations, num_docs * 5) + min_iter = max_iter + + stats = SearchStats( + iterations_max=max_iter, + total_nodes=num_docs, + exploration_start=self.config.exploration_start if self.config.adaptive_exploration else self.config.exploration_constant, + ) if verbose: console.print(f"\n[bold cyan]Phase 1 — Document Selection[/bold cyan]") - console.print(f" Scoring {folder_index.total_documents} documents ({iterations} iterations)") + console.print(f" Scoring {num_docs} documents (up to {max_iter} iterations)") + + convergence_tracker = {"min_iterations": min_iter} + + for i in range(max_iter): + c = self._get_exploration_constant(i, max_iter) + + active_docs = [d for d in folder_index.documents if not d.pruned] + if not active_docs: + break - for _ in range(iterations): - selected = max(folder_index.documents, key=lambda d: d.ucb1(total_visits, self.config.exploration_constant)) + selected = max(active_docs, key=lambda d: d.ucb1(total_visits, c)) try: score = self._simulate_document(query, selected) except Exception: - score = 0.3 # Neutral on LLM failure + score = 0.3 selected.visit_count += 1 selected.total_reward += score total_visits += 1 + # Prune low-scoring docs + if self.config.adaptive and self.config.pruning_enabled and i >= min_iter and i % 3 == 0: + for d in folder_index.documents: + if (not d.pruned + and d.visit_count >= self.config.pruning_min_visits + and d.average_reward < self.config.pruning_reward_threshold): + d.pruned = True + stats.pruned_branches += 1 + stats.pruned_at_iterations.append(i) + + # Convergence: top-k doc stability + if self.config.adaptive and i >= min_iter: + scored_docs = sorted( + [d for d in folder_index.documents if d.visit_count > 0 and not d.pruned], + key=lambda d: d.average_reward, reverse=True, + ) + current_top = tuple(d.document_id for d in scored_docs[:self.config.top_k_documents]) + prev_top = convergence_tracker.get("prev_top_docs") + if prev_top == current_top: + convergence_tracker["doc_stable_count"] = convergence_tracker.get("doc_stable_count", 0) + 1 + else: + convergence_tracker["doc_stable_count"] = 0 + convergence_tracker["prev_top_docs"] = current_top + + if convergence_tracker["doc_stable_count"] >= self.config.top_k_stable_rounds: + stats.converged = True + stats.convergence_reason = "top_k_stable" + stats.convergence_iteration = i + stats.iterations_used = i + 1 + if verbose: + console.print(f" [dim]Converged at iteration {i+1}/{max_iter} ({stats.convergence_reason})[/dim]") + break + else: + stats.iterations_used = max_iter + stats.convergence_reason = "max_reached" + + if not stats.iterations_used: + stats.iterations_used = total_visits + + # Final stats + visited_docs = [d for d in folder_index.documents if d.visit_count > 0] + stats.visited_nodes = len(visited_docs) + stats.coverage_pct = (len(visited_docs) / num_docs * 100) if num_docs else 0.0 + if visited_docs: + rewards = [d.average_reward for d in visited_docs] + stats.mean_reward = statistics.mean(rewards) + stats.reward_variance = statistics.variance(rewards) if len(rewards) > 1 else 0.0 + scored = sorted( [DocumentScore(entry=d, relevance_score=d.average_reward, visit_count=d.visit_count) - for d in folder_index.documents if d.visit_count > 0], + for d in folder_index.documents if d.visit_count > 0 and not d.pruned], key=lambda x: x.relevance_score, reverse=True, ) top_docs = scored[:self.config.top_k_documents] @@ -74,7 +295,11 @@ def search_meta(self, query: str, folder_index: FolderIndex, verbose: bool = Tru if verbose: for d in top_docs: console.print(f" → {d.entry.filename} (score: {d.relevance_score:.3f}, visits: {d.visit_count})") - return top_docs + if stats.converged: + console.print(f" [dim]{stats.iterations_used}/{max_iter} iters, {stats.pruned_branches} docs pruned[/dim]") + + stats.exploration_end = c if total_visits > 0 else stats.exploration_start + return top_docs, stats def _simulate_document(self, query: str, doc: FolderDocEntry) -> float: prompt = f"Query: {query}\nDocument: {doc.filename} ({doc.total_pages} pages)\nSummary: {doc.summary}\nKeywords: {', '.join(doc.keywords) if doc.keywords else 'none'}\nRelevance?" @@ -88,32 +313,106 @@ def _simulate_document(self, query: str, doc: FolderDocEntry) -> float: # Phase 2: Per-document search # ========================================================================= - def search_document(self, query: str, doc_index: DocumentIndex, verbose: bool = True) -> list[SearchResult]: + def search_document(self, query: str, doc_index: DocumentIndex, verbose: bool = True) -> tuple[list[SearchResult], SearchStats]: if not doc_index.root: - return [] + return [], SearchStats() + root = doc_index.root root.reset_mcts_state() + if self.config.adaptive: + min_iter = self.config.min_iterations + max_iter = self.config.max_iterations + else: + min_iter = self.config.iterations + max_iter = self.config.iterations + + all_nodes = doc_index.get_all_nodes() + stats = SearchStats( + iterations_max=max_iter, + total_nodes=len(all_nodes), + exploration_start=self.config.exploration_start if self.config.adaptive_exploration else self.config.exploration_constant, + ) + + batch_size = self.config.simulation_batch_size + if verbose: - console.print(f"\n [dim]Searching: {doc_index.filename}[/dim]") + mode = f"batch={batch_size}" if batch_size > 1 else "sequential" + console.print(f"\n [dim]Searching: {doc_index.filename} (up to {max_iter} iterations, {mode})[/dim]") - for _ in range(self.config.iterations): - selected = self._select(root) - expanded = self._expand(selected) - target = expanded or selected - try: - reward = self._simulate_section(query, target) - except Exception: - reward = 0.3 - self._backpropagate(target, reward) - if self._should_stop_early(root): + convergence_tracker = {"min_iterations": min_iter} + c = self.config.exploration_constant + i = 0 + converged = False + + while i < max_iter: + c = self._get_exploration_constant(i, max_iter) + + # Select up to batch_size nodes using virtual loss for diversity + batch_targets = [] + remaining = min(batch_size, max_iter - i) + for _ in range(remaining): + selected = self._select(root, exploration_c=c) + expanded = self._expand(selected) + target = expanded or selected + + if target.pruned: + break + + self._apply_virtual_loss(target) + batch_targets.append(target) + + if not batch_targets: break - return self._collect_results(root, doc_index.filename, doc_index.document_id) + # Parallel simulate all targets + rewards = self._simulate_batch(query, batch_targets) + + # Remove virtual loss, then backpropagate real rewards + for target, reward in zip(batch_targets, rewards): + self._remove_virtual_loss(target) + self._backpropagate(target, reward) + + i += len(batch_targets) + + # Periodic pruning + if i >= min_iter and i % 5 == 0: + self._prune_branches(root, i, stats) + + # Convergence check + should_stop, reason = self._check_convergence(root, i, convergence_tracker) + if should_stop: + stats.converged = True + stats.convergence_reason = reason + stats.convergence_iteration = i + stats.iterations_used = i + if verbose: + console.print(f" [dim]Converged at iteration {i}/{max_iter} ({reason})[/dim]") + converged = True + break + + if not converged: + stats.iterations_used = i if i else max_iter + stats.convergence_reason = "max_reached" + + # Final stats + visited = [n for n in all_nodes if n.visit_count > 0] + stats.visited_nodes = len(visited) + stats.coverage_pct = (len(visited) / len(all_nodes) * 100) if all_nodes else 0.0 + stats.exploration_end = c - def search_documents_parallel(self, query: str, doc_indices: list[DocumentIndex], verbose: bool = True) -> list[SearchResult]: + if visited: + rewards = [n.average_reward for n in visited] + stats.mean_reward = statistics.mean(rewards) + stats.reward_variance = statistics.variance(rewards) if len(rewards) > 1 else 0.0 + + results = self._collect_results(root, doc_index.filename, doc_index.document_id) + return results, stats + + def search_documents_parallel(self, query: str, doc_indices: list[DocumentIndex], verbose: bool = True) -> tuple[list[SearchResult], list[SearchStats]]: """Phase 2 parallel. Gracefully handles per-document failures.""" all_results = [] + all_stats = [] if verbose: console.print(f"\n[bold cyan]Phase 2 — Searching {len(doc_indices)} documents (parallel)[/bold cyan]") @@ -122,49 +421,60 @@ def search_documents_parallel(self, query: str, doc_indices: list[DocumentIndex] for future in as_completed(futures): doc_idx = futures[future] try: - results = future.result(timeout=120) # 2 min timeout per doc + results, stats = future.result(timeout=120) if verbose: - console.print(f" {doc_idx.filename}: {len(results)} sections found") + status = f"converged@{stats.convergence_iteration}" if stats.converged else f"{stats.iterations_used}iters" + console.print(f" {doc_idx.filename}: {len(results)} sections ({status}, {stats.pruned_branches} pruned)") all_results.extend(results) + all_stats.append(stats) except TimeoutError: console.print(f" [yellow]{doc_idx.filename}: Timed out — skipping[/yellow]") + all_stats.append(SearchStats(convergence_reason="timeout")) except Exception as e: console.print(f" [red]{doc_idx.filename}: Error — {e}[/red]") + all_stats.append(SearchStats(convergence_reason="error")) all_results.sort(key=lambda r: r.relevance_score, reverse=True) - return all_results[:self.config.top_k_results] + return all_results[:self.config.top_k_results], all_stats - def search_documents_sequential(self, query: str, doc_indices: list[DocumentIndex], verbose: bool = True) -> list[SearchResult]: + def search_documents_sequential(self, query: str, doc_indices: list[DocumentIndex], verbose: bool = True) -> tuple[list[SearchResult], list[SearchStats]]: all_results = [] + all_stats = [] if verbose: console.print(f"\n[bold cyan]Phase 2 — Searching {len(doc_indices)} documents[/bold cyan]") for idx in doc_indices: try: - results = self.search_document(query, idx, verbose) + results, stats = self.search_document(query, idx, verbose) all_results.extend(results) + all_stats.append(stats) except Exception as e: console.print(f" [red]{idx.filename}: Error — {e}[/red]") + all_stats.append(SearchStats(convergence_reason="error")) all_results.sort(key=lambda r: r.relevance_score, reverse=True) - return all_results[:self.config.top_k_results] + return all_results[:self.config.top_k_results], all_stats # ========================================================================= # Core MCTS # ========================================================================= - def _select(self, node: TreeNode) -> TreeNode: + def _select(self, node: TreeNode, exploration_c: float = None) -> TreeNode: + c = exploration_c if exploration_c is not None else self.config.exploration_constant current, depth = node, 0 while not current.is_leaf and depth < self.config.max_depth: - unvisited = [c for c in current.children if c.visit_count == 0] + active_children = [ch for ch in current.children if not ch.pruned] + if not active_children: + break + unvisited = [ch for ch in active_children if ch.visit_count == 0] if unvisited: return current - current = max(current.children, key=lambda c: c.ucb1(self.config.exploration_constant)) + current = max(active_children, key=lambda ch: ch.ucb1(c)) depth += 1 return current def _expand(self, node: TreeNode) -> Optional[TreeNode]: if node.is_leaf: return None - unvisited = [c for c in node.children if c.visit_count == 0] + unvisited = [c for c in node.children if c.visit_count == 0 and not c.pruned] return random.choice(unvisited) if unvisited else None def _simulate_section(self, query: str, node: TreeNode) -> float: @@ -182,14 +492,15 @@ def _backpropagate(self, node: TreeNode, reward: float): current.total_reward += reward current = current.parent - def _should_stop_early(self, root: TreeNode) -> bool: + def _should_stop_early_legacy(self, root: TreeNode) -> bool: + """Original early stopping logic, used when adaptive=False.""" leaves = self._get_all_leaves(root) return len([l for l in leaves if l.visit_count >= 3 and l.average_reward >= self.config.confidence_threshold]) >= self.config.top_k_results def _collect_results(self, root, doc_filename="", doc_id=""): all_nodes = [] def _collect(n): - if n.visit_count > 0: + if n.visit_count > 0 and not n.pruned: all_nodes.append(n) for c in n.children: _collect(c) diff --git a/treerag/models.py b/treerag/models.py index 112f356..755bf48 100644 --- a/treerag/models.py +++ b/treerag/models.py @@ -61,6 +61,7 @@ class TreeNode: visit_count: int = 0 total_reward: float = 0.0 parent: Optional["TreeNode"] = None + pruned: bool = False @property def average_reward(self) -> float: @@ -80,6 +81,7 @@ def ucb1(self, exploration_constant: float = 1.414) -> float: def reset_mcts_state(self): self.visit_count = 0 self.total_reward = 0.0 + self.pruned = False for c in self.children: c.reset_mcts_state() @@ -208,6 +210,7 @@ class FolderDocEntry: visit_count: int = 0 total_reward: float = 0.0 + pruned: bool = False @property def average_reward(self) -> float: @@ -227,6 +230,7 @@ def ucb1(self, parent_visits: int, exploration_constant: float = 1.414) -> float def reset_mcts_state(self): self.visit_count = 0 self.total_reward = 0.0 + self.pruned = False def to_dict(self) -> dict: return { @@ -372,6 +376,40 @@ class DocumentScore: visit_count: int +@dataclass +class SearchStats: + """Detailed statistics about an adaptive MCTS search run.""" + iterations_used: int = 0 + iterations_max: int = 0 + converged: bool = False + convergence_reason: str = "" + convergence_iteration: int = 0 + total_nodes: int = 0 + visited_nodes: int = 0 + coverage_pct: float = 0.0 + pruned_branches: int = 0 + pruned_at_iterations: list[int] = field(default_factory=list) + exploration_start: float = 0.0 + exploration_end: float = 0.0 + mean_reward: float = 0.0 + reward_variance: float = 0.0 + + def to_dict(self) -> dict: + return { + "iterations_used": self.iterations_used, + "iterations_max": self.iterations_max, + "converged": self.converged, + "convergence_reason": self.convergence_reason, + "convergence_iteration": self.convergence_iteration, + "total_nodes": self.total_nodes, + "visited_nodes": self.visited_nodes, + "coverage_pct": round(self.coverage_pct, 1), + "pruned_branches": self.pruned_branches, + "mean_reward": round(self.mean_reward, 4), + "reward_variance": round(self.reward_variance, 4), + } + + @dataclass class QueryResult: query: str @@ -383,3 +421,5 @@ class QueryResult: latency_seconds: float = 0.0 phase1_time: float = 0.0 phase2_time: float = 0.0 + phase1_stats: Optional["SearchStats"] = None + phase2_stats: list["SearchStats"] = field(default_factory=list) diff --git a/treerag/pipeline.py b/treerag/pipeline.py index df2cdc7..fa2e94a 100644 --- a/treerag/pipeline.py +++ b/treerag/pipeline.py @@ -14,7 +14,7 @@ from dataclasses import dataclass, field from .config import TreeRAGConfig -from .models import DocumentIndex, FolderIndex, SearchResult, QueryResult +from .models import DocumentIndex, FolderIndex, SearchResult, QueryResult, SearchStats from .llm_client import LLMClient from .indexer import Indexer from .mcts import MCTSEngine, SYSTEM_PROMPT_DEEP_READ, SYSTEM_PROMPT_ANSWER @@ -210,8 +210,11 @@ def chat( "phase2_time": f"{result.phase2_time:.1f}s", "total_time": f"{total_time:.1f}s", "llm_calls": total_calls, + "mcts_iterations": result.total_mcts_iterations, "cost": f"${self.llm._estimate_cost():.4f}", "docs_searched": result.documents_searched, + "phase1_search_stats": result.phase1_stats.to_dict() if result.phase1_stats else None, + "phase2_search_stats": [s.to_dict() for s in result.phase2_stats] if result.phase2_stats else [], }, ) @@ -238,13 +241,14 @@ def query_folder(self, query: str, folder_name: str, use_vision: bool = True) -> # Phase 1 p1_start = time.time() - doc_scores = self.mcts.search_meta(query, folder_index) + doc_scores, phase1_stats = self.mcts.search_meta(query, folder_index) p1_time = time.time() - p1_start if not doc_scores: return QueryResult( query=query, answer="No relevant documents found in this folder.", latency_seconds=time.time() - start_time, phase1_time=p1_time, + phase1_stats=phase1_stats, ) # Load selected document indices (gracefully skips broken ones) @@ -257,14 +261,15 @@ def query_folder(self, query: str, folder_name: str, use_vision: bool = True) -> answer="Selected documents have missing or corrupt indices. Run 'folder refresh' to repair.", documents_searched=[e.filename for e in selected_entries], latency_seconds=time.time() - start_time, phase1_time=p1_time, + phase1_stats=phase1_stats, ) # Phase 2 p2_start = time.time() if self.config.mcts.parallel_phase2 and len(doc_indices) > 1: - search_results = self.mcts.search_documents_parallel(query, doc_indices) + search_results, phase2_stats = self.mcts.search_documents_parallel(query, doc_indices) else: - search_results = self.mcts.search_documents_sequential(query, doc_indices) + search_results, phase2_stats = self.mcts.search_documents_sequential(query, doc_indices) p2_time = time.time() - p2_start if not search_results: @@ -273,6 +278,7 @@ def query_folder(self, query: str, folder_name: str, use_vision: bool = True) -> documents_searched=[e.filename for e in selected_entries], latency_seconds=time.time() - start_time, phase1_time=p1_time, phase2_time=p2_time, + phase1_stats=phase1_stats, phase2_stats=phase2_stats, ) # Deep Read @@ -283,13 +289,15 @@ def query_folder(self, query: str, folder_name: str, use_vision: bool = True) -> console.print("\n[bold]Generating answer...[/bold]") answer = self._generate_answer(query, search_results) + total_iters = (phase1_stats.iterations_used if phase1_stats else 0) + sum(s.iterations_used for s in phase2_stats) result = QueryResult( query=query, answer=answer, sources=search_results, documents_searched=[e.filename for e in selected_entries], - total_mcts_iterations=self.config.mcts.meta_iterations + self.config.mcts.iterations * len(doc_indices), + total_mcts_iterations=total_iters, total_llm_calls=self.llm.total_calls - llm_before, latency_seconds=time.time() - start_time, phase1_time=p1_time, phase2_time=p2_time, + phase1_stats=phase1_stats, phase2_stats=phase2_stats, ) self._print_result(result) return result @@ -307,13 +315,14 @@ def query_document(self, query: str, doc_index: DocumentIndex, use_vision: bool title="Query", border_style="cyan", )) - search_results = self.mcts.search_document(query, doc_index) + search_results, doc_stats = self.mcts.search_document(query, doc_index) if not search_results: return QueryResult( query=query, answer="No relevant sections found.", documents_searched=[doc_index.filename], latency_seconds=time.time() - start_time, + phase2_stats=[doc_stats], ) console.print("\n[bold]Deep reading...[/bold]") @@ -325,8 +334,10 @@ def query_document(self, query: str, doc_index: DocumentIndex, use_vision: bool result = QueryResult( query=query, answer=answer, sources=search_results, documents_searched=[doc_index.filename], + total_mcts_iterations=doc_stats.iterations_used, total_llm_calls=self.llm.total_calls - llm_before, latency_seconds=time.time() - start_time, + phase2_stats=[doc_stats], ) self._print_result(result) return result @@ -462,10 +473,20 @@ def _print_result(self, result): console.print(table) timing = f"Phase 1: {result.phase1_time:.1f}s | Phase 2: {result.phase2_time:.1f}s | " if result.phase1_time else "" - console.print(f"\n[dim]{timing}Total: {result.latency_seconds:.1f}s | LLM calls: {result.total_llm_calls} | Cost: ~${self.llm._estimate_cost():.4f}[/dim]") + console.print(f"\n[dim]{timing}Total: {result.latency_seconds:.1f}s | LLM calls: {result.total_llm_calls} | Iterations: {result.total_mcts_iterations} | Cost: ~${self.llm._estimate_cost():.4f}[/dim]") if result.documents_searched: console.print(f"[dim]Searched: {', '.join(result.documents_searched)}[/dim]") + # Adaptive search stats + if result.phase1_stats and result.phase1_stats.iterations_used > 0: + s = result.phase1_stats + conv = f"converged@{s.convergence_iteration+1} ({s.convergence_reason})" if s.converged else f"{s.iterations_used} iters" + console.print(f"[dim]Phase 1: {conv}, {s.pruned_branches} docs pruned, coverage: {s.coverage_pct:.0f}%[/dim]") + for i, s in enumerate(result.phase2_stats): + if s.iterations_used > 0: + conv = f"converged@{s.convergence_iteration+1} ({s.convergence_reason})" if s.converged else f"{s.iterations_used} iters" + console.print(f"[dim]Phase 2 doc {i+1}: {conv}, coverage: {s.coverage_pct:.0f}%, pruned: {s.pruned_branches}[/dim]") + def _set_parent_refs(self, node): for c in node.children: c.parent = node