Skip to content

Commit 40b6cc1

Browse files
committed
feat(copilot): MCTS dynamic strategy engine with reward model and rollout simulation
- Add MCTSConfig, MCTSNode, StrategyRecommendation data structures
- Add RewardModel with cosine similarity scoring (R = W1·Match + W2·Safe - W3·Risk)
- Add SimulationEngine with 3-level degradation rollout
- Add MCTSEngine with PUCT selection, LLM expansion, backpropagation
- Integrate MCTS into copilot WebSocket session (feature-flagged, off by default)
- Add 11 mcts_* settings to config and rollout LLM provider
- Add user-facing docs and 39 unit tests
1 parent 9aeb1e1 commit 40b6cc1

File tree

12 files changed

+1496
-11
lines changed

12 files changed

+1496
-11
lines changed

backend/config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,19 @@ class Settings(BaseSettings):
3636
nls_access_key_id: str = ""
3737
nls_access_key_secret: str = ""
3838

39+
# MCTS 动态策略引擎
40+
mcts_enabled: bool = False
41+
mcts_iterations: int = 8
42+
mcts_branch_factor: int = 3
43+
mcts_rollout_depth: int = 2
44+
mcts_c_puct: float = 1.4
45+
mcts_max_tree_nodes: int = 200
46+
mcts_merge_threshold: float = 0.8
47+
mcts_search_timeout: float = 25.0
48+
mcts_rollout_api_base: str = ""
49+
mcts_rollout_api_key: str = ""
50+
mcts_rollout_model: str = ""
51+
3952
# Copilot — Tavily Web Search
4053
tavily_api_key: str = ""
4154

backend/copilot/asr_stream.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -96,19 +96,20 @@ def start(self):
9696
on_error=self._on_error,
9797
on_close=self._on_close,
9898
)
99-
ok = self._transcriber.start(
100-
aformat="pcm",
101-
sample_rate=16000,
102-
enable_intermediate_result=True,
103-
enable_punctuation_prediction=True,
104-
enable_inverse_text_normalization=True,
105-
)
106-
if ok:
99+
try:
100+
self._transcriber.start(
101+
aformat="pcm",
102+
sample_rate=16000,
103+
enable_intermediate_result=True,
104+
enable_punctuation_prediction=True,
105+
enable_inverse_text_normalization=True,
106+
)
107107
self._started = True
108108
logger.info("NLS ASR started")
109-
else:
110-
logger.error("NLS ASR failed to start")
111-
return ok
109+
return True
110+
except Exception as e:
111+
logger.error(f"NLS ASR failed to start: {e}")
112+
return False
112113

113114
def send_audio(self, pcm_data: bytes) -> bool:
114115
if self._transcriber and self._started:

backend/copilot/mcts_config.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""MCTS 相关配置和共享数据结构。"""
2+
from __future__ import annotations
3+
4+
import uuid
5+
from dataclasses import dataclass, field
6+
from typing import TYPE_CHECKING
7+
8+
if TYPE_CHECKING:
9+
from backend.config import Settings
10+
11+
12+
@dataclass
class MCTSConfig:
    """Search parameters for the MCTS strategy engine.

    Instances are normally built from the application settings with
    :meth:`from_settings`; the field defaults below apply when the class is
    constructed directly.
    """

    # NOTE(review): this defaults to True, while the visible part of
    # ``Settings`` declares ``mcts_enabled`` as False — confirm the
    # difference is intentional.
    enabled: bool = True
    iterations: int = 8
    branch_factor: int = 3
    rollout_depth: int = 2
    max_expansion_depth: int = 3
    c_puct: float = 1.4
    max_tree_nodes: int = 200
    merge_threshold: float = 0.8
    search_timeout: float = 25.0

    @classmethod
    def from_settings(cls, s: Settings) -> MCTSConfig:
        """Build a config from an application ``Settings`` object.

        Args:
            s: Object exposing the ``mcts_*`` attributes read below.

        Returns:
            A new config populated from ``s``.

        Raises:
            AttributeError: If ``s`` lacks one of the required ``mcts_*``
                attributes.
        """
        return cls(
            enabled=s.mcts_enabled,
            iterations=s.mcts_iterations,
            branch_factor=s.mcts_branch_factor,
            rollout_depth=s.mcts_rollout_depth,
            # Optional: honoured when the settings object provides it,
            # otherwise the dataclass default is kept, so settings objects
            # without this attribute behave exactly as before.
            max_expansion_depth=getattr(
                s, "mcts_max_expansion_depth", cls.max_expansion_depth
            ),
            c_puct=s.mcts_c_puct,
            max_tree_nodes=s.mcts_max_tree_nodes,
            merge_threshold=s.mcts_merge_threshold,
            search_timeout=s.mcts_search_timeout,
        )
37+
38+
39+
@dataclass
class MCTSNode:
    """A node of the game tree explored by the MCTS search.

    Only ``id`` is required; every other field has a default. The list
    fields use ``default_factory`` so each node gets its own list.
    """

    id: str
    parent_id: str | None = None
    # Presumably the ids of child nodes — verify against the engine.
    children: list[str] = field(default_factory=list)

    # State
    actor: str = "candidate"  # "candidate" | "hr"
    action: str = ""
    action_embedding: list[float] | None = None

    # Context
    topic: str = ""
    conversation_snapshot: list[dict] = field(default_factory=list)

    # MCTS statistics
    visit_count: int = 0
    total_reward: float = 0.0
    prior: float = 0.5

    # Metadata
    strategy_tree_node_id: str | None = None
    risk_level: str = "safe"
    depth: int = 0

    @property
    def q_value(self) -> float:
        """Return the mean reward per visit, or 0.0 for an unvisited node."""
        visits = self.visit_count
        # Guard the division: a node with no visits has no average yet.
        return 0.0 if visits == 0 else self.total_reward / visits

    @staticmethod
    def make_id() -> str:
        """Return a fresh 10-character hex id taken from a random UUID4."""
        token = uuid.uuid4().hex
        return token[:10]
74+
75+
76+
@dataclass
class StrategyRecommendation:
    """Result of an MCTS search: the recommendation pushed to the frontend.

    Every field has a default, so an empty recommendation can be created
    with no arguments. The list fields use ``default_factory`` so instances
    never share a list.
    """

    optimal_response_strategy: str = ""
    predicted_followups: list[dict] = field(default_factory=list)
    danger_zones: list[str] = field(default_factory=list)
    win_rate: float = 0.5
    best_path: list[dict] = field(default_factory=list)
    confidence: float = 0.0
    iterations_completed: int = 0

0 commit comments

Comments
 (0)