Skip to content

Commit 21cd084

Browse files
committed
feat(copilot): MCTS dynamic strategy engine with reward model and rollout simulation
- Add MCTSConfig, MCTSNode, StrategyRecommendation data structures - Add RewardModel with cosine similarity scoring (R = W1·Match + W2·Safe - W3·Risk) - Add SimulationEngine with 3-level degradation rollout - Add MCTSEngine with PUCT selection, LLM expansion, backpropagation - Integrate MCTS into copilot WebSocket session (feature-flagged, off by default) - Add 11 mcts_* settings to config and rollout LLM provider - Add user-facing docs and 39 unit tests
1 parent 9aeb1e1 commit 21cd084

File tree

10 files changed

+1479
-0
lines changed

10 files changed

+1479
-0
lines changed

backend/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,19 @@ class Settings(BaseSettings):
3636
nls_access_key_id: str = ""
3737
nls_access_key_secret: str = ""
3838

39+
# MCTS 动态策略引擎
40+
mcts_enabled: bool = False
41+
mcts_iterations: int = 8
42+
mcts_branch_factor: int = 3
43+
mcts_rollout_depth: int = 2
44+
mcts_c_puct: float = 1.4
45+
mcts_max_tree_nodes: int = 200
46+
mcts_merge_threshold: float = 0.8
47+
mcts_search_timeout: float = 25.0
48+
mcts_rollout_api_base: str = ""
49+
mcts_rollout_api_key: str = ""
50+
mcts_rollout_model: str = ""
51+
3952
# Copilot — Tavily Web Search
4053
tavily_api_key: str = ""
4154

backend/copilot/mcts_config.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""MCTS 相关配置和共享数据结构。"""
2+
from __future__ import annotations
3+
4+
import uuid
5+
from dataclasses import dataclass, field
6+
from typing import TYPE_CHECKING
7+
8+
if TYPE_CHECKING:
9+
from backend.config import Settings
10+
11+
12+
@dataclass
13+
class MCTSConfig:
14+
"""MCTS 搜索参数配置。"""
15+
enabled: bool = True
16+
iterations: int = 8
17+
branch_factor: int = 3
18+
rollout_depth: int = 2
19+
c_puct: float = 1.4
20+
max_tree_nodes: int = 200
21+
merge_threshold: float = 0.8
22+
search_timeout: float = 25.0
23+
24+
@classmethod
25+
def from_settings(cls, s: Settings) -> MCTSConfig:
26+
return cls(
27+
enabled=s.mcts_enabled,
28+
iterations=s.mcts_iterations,
29+
branch_factor=s.mcts_branch_factor,
30+
rollout_depth=s.mcts_rollout_depth,
31+
c_puct=s.mcts_c_puct,
32+
max_tree_nodes=s.mcts_max_tree_nodes,
33+
merge_threshold=s.mcts_merge_threshold,
34+
search_timeout=s.mcts_search_timeout,
35+
)
36+
37+
38+
@dataclass
39+
class MCTSNode:
40+
"""博弈树节点。"""
41+
id: str
42+
parent_id: str | None = None
43+
children: list[str] = field(default_factory=list)
44+
45+
# 状态
46+
actor: str = "candidate" # "candidate" | "hr"
47+
action: str = ""
48+
action_embedding: list[float] | None = None
49+
50+
# 上下文
51+
topic: str = ""
52+
conversation_snapshot: list[dict] = field(default_factory=list)
53+
54+
# MCTS 统计量
55+
visit_count: int = 0
56+
total_reward: float = 0.0
57+
prior: float = 0.5
58+
59+
# 元数据
60+
strategy_tree_node_id: str | None = None
61+
risk_level: str = "safe"
62+
depth: int = 0
63+
64+
@property
65+
def q_value(self) -> float:
66+
if self.visit_count == 0:
67+
return 0.0
68+
return self.total_reward / self.visit_count
69+
70+
@staticmethod
71+
def make_id() -> str:
72+
return uuid.uuid4().hex[:10]
73+
74+
75+
@dataclass
76+
class StrategyRecommendation:
77+
"""MCTS 搜索结果,推送给前端的推荐。"""
78+
optimal_response_strategy: str = ""
79+
predicted_followups: list[dict] = field(default_factory=list)
80+
danger_zones: list[str] = field(default_factory=list)
81+
win_rate: float = 0.5
82+
best_path: list[dict] = field(default_factory=list)
83+
confidence: float = 0.0
84+
iterations_completed: int = 0

0 commit comments

Comments
 (0)