-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgraph_retriever.py
More file actions
276 lines (226 loc) · 10.9 KB
/
graph_retriever.py
File metadata and controls
276 lines (226 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
"""
V4 图检索器 - Embedding Rerank 版本
两层架构:LPM (实体提取) + Embedding Rerank (向量相似度排序)
"""
from typing import List, Dict, Optional
from pathlib import Path
from knowledge_graph import KnowledgeGraphV2
from entity_extractor import QueryExtractor
from llm_client import get_llm_client
from config import STM_CONFIG, SYSTEM_CONFIG, LPM_CONFIG, EMBEDDING_CONFIG
from embedding_manager import EmbeddingManager
import numpy as np
class EvoGraphRetriever:
    """
    EvoGraph V4 retriever - embedding-rerank version.

    Two-layer architecture: LPM (entity extraction) + embedding rerank
    (vector-similarity ordering).

    Retrieval flow:
    1. LPM: extract entities from the query and match them against
       entities already present in the graph.
    2. Collect all notes linked to the resolved entities (with keyword
       and whole-store fallbacks).
    3. Rerank: order candidates by embedding cosine similarity and
       return the top-k.
    """

    def __init__(self, kg: KnowledgeGraphV2, debug_log_path: Optional[Path] = None, embedding_manager=None):
        """
        Args:
            kg: knowledge-graph backend used for entity / note lookups.
            debug_log_path: optional path; when set, debug traces are
                appended to this file (opened lazily, see `_debug_log`).
            embedding_manager: optional shared EmbeddingManager. Passing
                one in (or attaching one to ``kg``) avoids re-loading the
                embedding model per instance / per thread.
        """
        self.kg = kg
        self.query_extractor = QueryExtractor()
        self.llm = get_llm_client()
        # Prefer an explicitly supplied manager, then one attached to the
        # graph; only build a fresh one as a last resort (model loading is
        # expensive and should not be repeated across threads).
        if embedding_manager:
            self.embedding_manager = embedding_manager
        elif getattr(kg, 'embedding_manager', None):
            self.embedding_manager = kg.embedding_manager
        else:
            self.embedding_manager = EmbeddingManager(
                model_path=EMBEDDING_CONFIG["model_path"],
                similarity_threshold=EMBEDDING_CONFIG["similarity_threshold"]
            )
        self.top_k = STM_CONFIG.get("top_k", 5)
        self.debug = SYSTEM_CONFIG.get("debug", False)
        self.debug_log_path = debug_log_path
        self._debug_file = None  # lazily opened by _debug_log; released via close()

    def close(self):
        """Close the lazily opened debug-log file handle, if any.

        Fixes a file-handle leak: `_debug_log` opens the log on first use
        but nothing previously closed it. Safe to call multiple times.
        """
        if self._debug_file is not None:
            self._debug_file.close()
            self._debug_file = None

    def _debug_log(self, msg: str):
        """Append one line to the debug log; no-op when no path is configured."""
        if self.debug_log_path:
            if self._debug_file is None:
                self._debug_file = open(self.debug_log_path, "a", encoding="utf-8")
            self._debug_file.write(msg + "\n")
            self._debug_file.flush()

    def retrieve(self, query: str) -> List[dict]:
        """
        Run the simplified two-layer retrieval.

        Args:
            query: natural-language question.

        Returns:
            List of retrieved note dicts — the reranked top-k when the
            rerank selects any, otherwise the first ``top_k`` candidates
            in collection order; empty list when nothing matches.
        """
        self._debug_log(f"\n{'='*60}")
        self._debug_log(f"[DEBUG] Query: {query}")

        # ========== LPM layer: entity extraction ==========
        existing_entities = self._get_existing_entities()
        self._debug_log(f"[DEBUG] Existing entities count: {len(existing_entities)}")
        self._debug_log(f"[DEBUG] Existing entities: {[e['name'] for e in existing_entities[:10]]}{'...' if len(existing_entities) > 10 else ''}")
        query_entities, query_keywords = self.query_extractor.extract(query, existing_entities)
        self._debug_log(f"[DEBUG] LLM extracted entities: {query_entities}")
        self._debug_log(f"[DEBUG] LLM extracted keywords: {query_keywords}")
        resolved_entities = self._resolve_entities(query_entities)
        self._debug_log(f"[DEBUG] Resolved entities: {resolved_entities}")
        self._debug_log(f"[DEBUG] Resolution rate: {len(resolved_entities)}/{len(query_entities)} = {len(resolved_entities)/len(query_entities)*100:.1f}%" if query_entities else "[DEBUG] No entities to resolve")

        # ========== Collect candidate notes ==========
        all_notes = []
        seen_ids = set()
        # Strategy 1: notes linked to the resolved entities.
        if resolved_entities:
            entity_notes = self._collect_entity_notes(resolved_entities)
            for note in entity_notes:
                if note["id"] not in seen_ids:
                    seen_ids.add(note["id"])
                    all_notes.append(note)
            self._debug_log(f"[DEBUG] Entity notes: {len(entity_notes)}")
        # Strategy 2: supplement with keyword search (query_keywords).
        if query_keywords:
            keyword_notes = self._search_by_keywords(query_keywords)
            added = 0
            for note in keyword_notes:
                if note["id"] not in seen_ids:
                    seen_ids.add(note["id"])
                    all_notes.append(note)
                    added += 1
            self._debug_log(f"[DEBUG] Keyword notes added: {added}")
        # Strategy 3: still no candidates — fall back to a bounded slice
        # of the whole store.
        if not all_notes:
            self._debug_log("[DEBUG] No candidates, falling back to all notes")
            all_notes = self.kg.get_all_notes()[:100]
        self._debug_log(f"[DEBUG] Total candidates: {len(all_notes)}")
        self._debug_log(f"[DEBUG] Candidate IDs: {[n['id'] for n in all_notes[:10]]}{'...' if len(all_notes) > 10 else ''}")
        if not all_notes:
            self._debug_log("[DEBUG] ❌ No notes found, returning empty!")
            return []

        # ========== Embedding rerank ==========
        selected_ids = self._llm_rerank(query, all_notes)
        self._debug_log(f"[DEBUG] LLM rerank selected: {selected_ids}")
        if selected_ids:
            return self.kg.get_notes_by_ids(selected_ids)
        # Fallback: first top_k candidates in collection (time) order.
        return all_notes[:self.top_k]

    def _get_existing_entities(self) -> List[dict]:
        """Return the LPM-layer entity index used to ground extraction.

        Summaries are intentionally NOT truncated — keeping them whole
        improves retrieval accuracy.
        """
        entities = self.kg.get_all_entities()
        return [
            {
                "name": e.get("name", ""),
                "type": e.get("type", ""),
                "summary": e.get("summary", ""),
                "aliases": e.get("aliases", [])
            }
            for e in entities[:LPM_CONFIG.get("max_entities_in_prompt", 50)]
        ]

    def _resolve_entities(self, entity_names: List[str]) -> List[str]:
        """Resolve extracted entity names to their canonical graph names.

        Names with no alias match are dropped (logged, not raised).
        """
        resolved = []
        for name in entity_names:
            canonical = self.kg.resolve_alias(name)
            if canonical:
                resolved.append(canonical)
                self._debug_log(f"[DEBUG] ✓ '{name}' -> '{canonical}'")
            else:
                self._debug_log(f"[DEBUG] ✗ '{name}' -> NOT FOUND")
        return resolved

    def _search_by_keywords(self, keywords: List[str], limit: int = 50) -> List[dict]:
        """Search notes whose text contains any keyword (case-insensitive).

        Args:
            keywords: keywords extracted from the query.
            limit: maximum number of matched notes to return.
        """
        if not keywords:
            return []
        # Linear scan over all notes, filtered by keyword containment.
        all_notes = self.kg.get_all_notes()
        matched = []
        for note in all_notes:
            text = note.get("text", "").lower()
            # Any single keyword hit qualifies the note.
            if any(kw.lower() in text for kw in keywords):
                matched.append(note)
                if len(matched) >= limit:
                    break
        return matched

    def _collect_entity_notes(self, entity_names: List[str]) -> List[dict]:
        """Collect notes linked to the entities (intersection-first strategy).

        Design: intersection first, smallest-set fallback.
        - Single entity: return all of that entity's notes.
        - Multiple entities: take the intersection (notes that mention
          ALL query entities).
        - Empty intersection: fall back to the note set of the entity
          with the fewest links (NOT the union).

        Rationale:
        - The intersection sharply shrinks the candidate set and avoids
          semantic drift.
        - A union would inject heavy noise and dilute precision.
        - The smallest per-entity set has the best signal-to-noise ratio
          and is most likely to contain the relevant information.
        """
        if not entity_names:
            return []
        # Gather each entity's note-id set plus an id -> note lookup.
        entity_note_map = {}   # entity name -> set of note ids
        all_notes_by_id = {}   # note id -> note dict
        for entity_name in entity_names:
            entity_notes = self.kg.get_entity_notes(entity_name, limit=100)
            entity_note_map[entity_name] = set()
            for note in entity_notes:
                entity_note_map[entity_name].add(note["id"])
                all_notes_by_id[note["id"]] = note
        self._debug_log(f"[DEBUG] Entity note counts: {[(e, len(ids)) for e, ids in entity_note_map.items()]}")
        # Single entity: return everything found, in chronological order.
        if len(entity_names) == 1:
            notes = list(all_notes_by_id.values())
            self._debug_log(f"[DEBUG] Single entity, returning {len(notes)} notes")
            notes.sort(key=lambda x: x.get("seq", 0))
            return notes
        # Multiple entities: intersect the id sets.
        note_id_sets = list(entity_note_map.values())
        common_ids = note_id_sets[0].intersection(*note_id_sets[1:])
        self._debug_log(f"[DEBUG] Intersection of {len(entity_names)} entities: {len(common_ids)} notes")
        if common_ids:
            notes = [all_notes_by_id[nid] for nid in common_ids]
            self._debug_log(f"[DEBUG] Using intersection: {len(notes)} notes")
        else:
            # Empty intersection: use the smallest per-entity set (least
            # noisy) rather than the union.
            min_entity = min(entity_note_map.keys(), key=lambda e: len(entity_note_map[e]))
            min_note_ids = entity_note_map[min_entity]
            notes = [all_notes_by_id[nid] for nid in min_note_ids]
            self._debug_log(f"[DEBUG] Intersection empty, fallback to smallest set (entity={min_entity}): {len(notes)} notes")
        # Chronological order.
        notes.sort(key=lambda x: x.get("seq", 0))
        return notes

    def _llm_rerank(self, query: str, candidates: List[dict]) -> List[str]:
        """
        Embedding rerank: order candidates by cosine similarity to the query.

        NOTE: the name is kept for backward compatibility — despite it,
        this uses the embedding model, not the LLM.

        Args:
            query: the user query.
            candidates: candidate note dicts (expects 'id', 'user',
                'assistant' keys).

        Returns:
            IDs of the top-k most similar notes; all IDs unchanged when
            there are at most top_k candidates.
        """
        if len(candidates) <= self.top_k:
            return [c["id"] for c in candidates]
        # Encode the query once, and all candidate texts in ONE batched
        # encode() call — one forward pass per note was the hot spot here.
        # Embeddings are L2-normalized, so a dot product is the cosine.
        query_embedding = self.embedding_manager.model.encode(query, normalize_embeddings=True)
        note_texts = [f"{note.get('user', '')} {note.get('assistant', '')}" for note in candidates]
        note_embeddings = self.embedding_manager.model.encode(note_texts, normalize_embeddings=True)
        scores = np.asarray(note_embeddings) @ np.asarray(query_embedding)
        similarities = sorted(
            zip((note['id'] for note in candidates), (float(s) for s in scores)),
            key=lambda x: x[1],
            reverse=True,
        )
        top_ids = [note_id for note_id, sim in similarities[:self.top_k]]
        self._debug_log(f"[DEBUG] Embedding top similarities: {[(nid, f'{sim:.4f}') for nid, sim in similarities[:self.top_k]]}")
        return top_ids
# Backward-compatibility aliases for the retriever's old class names.
GraphRetriever = EvoGraphRetriever
HybridRetriever = EvoGraphRetriever
if __name__ == "__main__":
    # Smoke test: run a few sample queries against the live knowledge graph.
    kg = KnowledgeGraphV2()
    retriever = EvoGraphRetriever(kg)
    test_queries = [
        "When did Jon lose his job as a banker?",
        "Which city have both Jean and John visited?",
        "What does Jon's dance studio offer?",
    ]
    try:
        for query in test_queries:
            print(f"\n{'='*60}")
            print(f"Query: {query}")
            results = retriever.retrieve(query)
            print(f"Results: {[r['id'] for r in results]}")
    finally:
        # Fix: release the graph connection even when a query raises.
        kg.close()