Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,24 @@ PARALLEL_PHASE2=true
BATCH_SIZE=15
MAX_PAGES=5000

# MCTS — Adaptive
MCTS_ADAPTIVE=true
MCTS_MIN_ITERATIONS=8
MCTS_MAX_ITERATIONS=50
MCTS_META_MIN_ITERATIONS=5
MCTS_META_MAX_ITERATIONS=30
MCTS_CONVERGENCE_WINDOW=4
MCTS_CONVERGENCE_VARIANCE_THRESHOLD=0.01
MCTS_TOP_K_STABLE_ROUNDS=3
MCTS_PRUNING=true
MCTS_PRUNING_MIN_VISITS=3
MCTS_PRUNING_REWARD_THRESHOLD=0.25
MCTS_ADAPTIVE_EXPLORATION=true
MCTS_EXPLORATION_START=2.0
MCTS_EXPLORATION_END=0.5
MCTS_EXPLORATION_DECAY=linear
MCTS_SIMULATION_BATCH_SIZE=4

# Storage
TREERAG_DATA_DIR=.treerag_data
AUTO_REINDEX=true
43 changes: 43 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,32 @@ def test_defaults(self):
assert config.top_k_documents == 3
assert config.parallel_phase2 is True

def test_adaptive_defaults(self):
config = MCTSConfig()
assert config.adaptive is True
assert config.min_iterations == 8
assert config.max_iterations == 50
assert config.meta_min_iterations == 5
assert config.meta_max_iterations == 30
assert config.convergence_window == 4
assert config.convergence_variance_threshold == 0.01
assert config.top_k_stable_rounds == 3
assert config.pruning_enabled is True
assert config.pruning_min_visits == 3
assert config.pruning_reward_threshold == 0.25
assert config.adaptive_exploration is True
assert config.exploration_start == 2.0
assert config.exploration_end == 0.5
assert config.exploration_decay == "linear"
assert config.simulation_batch_size == 4

def test_adaptive_disabled(self):
config = MCTSConfig(adaptive=False)
assert config.adaptive is False
# Other adaptive defaults still set
assert config.min_iterations == 8
assert config.max_iterations == 50


class TestIndexerConfig:
def test_defaults(self):
Expand Down Expand Up @@ -71,3 +97,20 @@ def test_from_env_custom(self):
assert config.mcts.iterations == 40
assert config.indexer.batch_size == 10
assert config.folder.base_dir == "/tmp/test_data"

def test_from_env_adaptive(self):
env = {
"MCTS_ADAPTIVE": "false",
"MCTS_MIN_ITERATIONS": "10",
"MCTS_MAX_ITERATIONS": "60",
"MCTS_PRUNING": "false",
"MCTS_EXPLORATION_START": "1.8",
}
with patch.dict(os.environ, env):
config = TreeRAGConfig.from_env()
assert config.mcts.adaptive is False
assert config.mcts.min_iterations == 10
assert config.mcts.max_iterations == 60
assert config.mcts.pruning_enabled is False
assert config.mcts.exploration_start == 1.8
assert config.mcts.simulation_batch_size == 4 # default unchanged
48 changes: 47 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tempfile
from pathlib import Path

from treerag.models import TreeNode, DocumentIndex
from treerag.models import TreeNode, DocumentIndex, SearchStats


class TestTreeNode:
Expand Down Expand Up @@ -93,6 +93,23 @@ def test_reset_mcts_state(self):
assert child.visit_count == 0
assert child.total_reward == 0.0

def test_pruned_default(self):
node = TreeNode(node_id="0001", title="Test", summary="Test",
start_page=0, end_page=5)
assert node.pruned is False

def test_pruned_resets(self):
node = TreeNode(node_id="0000", title="Root", summary="Root",
start_page=0, end_page=10)
child = TreeNode(node_id="0001", title="Child", summary="Child",
start_page=0, end_page=5, parent=node)
node.children.append(child)
node.pruned = True
child.pruned = True
node.reset_mcts_state()
assert node.pruned is False
assert child.pruned is False


class TestDocumentIndex:
def test_save_and_load(self):
Expand All @@ -115,3 +132,32 @@ def test_save_and_load(self):
assert loaded.root.title == "Doc"
finally:
tmp_path.unlink()


class TestSearchStats:
def test_defaults(self):
stats = SearchStats()
assert stats.iterations_used == 0
assert stats.converged is False
assert stats.convergence_reason == ""
assert stats.pruned_branches == 0
assert stats.coverage_pct == 0.0

def test_to_dict(self):
stats = SearchStats(
iterations_used=12, iterations_max=50,
converged=True, convergence_reason="top_k_stable",
convergence_iteration=11,
total_nodes=47, visited_nodes=38,
coverage_pct=80.85,
pruned_branches=3,
mean_reward=0.7234, reward_variance=0.0456,
)
d = stats.to_dict()
assert d["iterations_used"] == 12
assert d["converged"] is True
assert d["convergence_reason"] == "top_k_stable"
assert d["coverage_pct"] == 80.8 # rounded
assert d["pruned_branches"] == 3
assert d["mean_reward"] == 0.7234
assert d["reward_variance"] == 0.0456
42 changes: 42 additions & 0 deletions treerag/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,32 @@ class MCTSConfig:
max_depth: int = 10
parallel_phase2: bool = True

# Adaptive MCTS
adaptive: bool = True
min_iterations: int = 8
max_iterations: int = 50
meta_min_iterations: int = 5
meta_max_iterations: int = 30

# Convergence detection
convergence_window: int = 4
convergence_variance_threshold: float = 0.01
top_k_stable_rounds: int = 3

# Branch pruning
pruning_enabled: bool = True
pruning_min_visits: int = 3
pruning_reward_threshold: float = 0.25

# Adaptive exploration (C decay)
adaptive_exploration: bool = True
exploration_start: float = 2.0
exploration_end: float = 0.5
exploration_decay: str = "linear"

# Batch simulation (parallel LLM calls per iteration)
simulation_batch_size: int = 4


@dataclass
class IndexerConfig:
Expand Down Expand Up @@ -80,6 +106,22 @@ def from_env(cls) -> "TreeRAGConfig":
confidence_threshold=float(os.getenv("CONFIDENCE_THRESHOLD", "0.7")),
top_k_documents=int(os.getenv("TOP_K_DOCUMENTS", "3")),
parallel_phase2=os.getenv("PARALLEL_PHASE2", "true").lower() == "true",
adaptive=os.getenv("MCTS_ADAPTIVE", "true").lower() == "true",
min_iterations=int(os.getenv("MCTS_MIN_ITERATIONS", "8")),
max_iterations=int(os.getenv("MCTS_MAX_ITERATIONS", "50")),
meta_min_iterations=int(os.getenv("MCTS_META_MIN_ITERATIONS", "5")),
meta_max_iterations=int(os.getenv("MCTS_META_MAX_ITERATIONS", "30")),
convergence_window=int(os.getenv("MCTS_CONVERGENCE_WINDOW", "4")),
convergence_variance_threshold=float(os.getenv("MCTS_CONVERGENCE_VARIANCE_THRESHOLD", "0.01")),
top_k_stable_rounds=int(os.getenv("MCTS_TOP_K_STABLE_ROUNDS", "3")),
pruning_enabled=os.getenv("MCTS_PRUNING", "true").lower() == "true",
pruning_min_visits=int(os.getenv("MCTS_PRUNING_MIN_VISITS", "3")),
pruning_reward_threshold=float(os.getenv("MCTS_PRUNING_REWARD_THRESHOLD", "0.25")),
adaptive_exploration=os.getenv("MCTS_ADAPTIVE_EXPLORATION", "true").lower() == "true",
exploration_start=float(os.getenv("MCTS_EXPLORATION_START", "2.0")),
exploration_end=float(os.getenv("MCTS_EXPLORATION_END", "0.5")),
exploration_decay=os.getenv("MCTS_EXPLORATION_DECAY", "linear"),
simulation_batch_size=int(os.getenv("MCTS_SIMULATION_BATCH_SIZE", "4")),
),
indexer=IndexerConfig(
batch_size=int(os.getenv("BATCH_SIZE", "15")),
Expand Down
Loading
Loading