-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathknowledge_graph_nolinenode.py
More file actions
430 lines (361 loc) · 15.7 KB
/
knowledge_graph_nolinenode.py
File metadata and controls
430 lines (361 loc) · 15.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
"""
V4 knowledge graph - simplified version (w/o LineNode)
Ablation study: removes the LineNode evolution chain so that LineNode's contribution can be measured.
Simplified design:
- Entity: entity node that keeps only its latest state (current_state, updated_at)
- Note: raw dialogue turn
- Entity -[:MENTIONED_IN]-> Note: direct relation, no evolution chain
"""
import json
from typing import List, Dict, Optional, Set
from neo4j import GraphDatabase
from config import NEO4J_CONFIG, MTM_CONFIG, EMBEDDING_CONFIG
from embedding_manager import EmbeddingManager
class KnowledgeGraphNoLineNode:
    """
    Simplified knowledge graph (w/o LineNode), used as an ablation baseline.

    Graph structure:
    - (:Entity)-[:MENTIONED_IN]->(:Note)   # direct relation, no evolution chain
    - (:Note)-[:NEXT]->(:Note)             # temporal edges
    """

    def __init__(self, debug_log_path=None, embedding_manager=None, neo4j_config=None):
        """
        Connect to Neo4j, load the global summary and optionally set up embeddings.

        Args:
            debug_log_path: optional path of a text file that debug lines are appended to.
            embedding_manager: optional pre-built EmbeddingManager to reuse; when omitted
                and embeddings are enabled, one is built from EMBEDDING_CONFIG.
            neo4j_config: optional dict with "uri", "user", "password" and optionally
                "database"; defaults to NEO4J_CONFIG (lets batch imports pick a database
                per conversation).
        """
        config = neo4j_config if neo4j_config else NEO4J_CONFIG
        self.driver = GraphDatabase.driver(
            config["uri"],
            auth=(config["user"], config["password"])
        )
        self.database = config.get("database", "neo4j")
        self.debug_log_path = debug_log_path
        # lower-cased lookup name -> canonical entity name
        self._alias_map: Dict[str, str] = {}
        try:
            self._global_summary: str = self._load_global_summary()
            self.use_embedding = EMBEDDING_CONFIG.get("use_embedding", True)
            if self.use_embedding:
                if embedding_manager is not None:
                    self.embedding_manager = embedding_manager
                else:
                    self.embedding_manager = EmbeddingManager(
                        model_path=EMBEDDING_CONFIG["model_path"],
                        similarity_threshold=EMBEDDING_CONFIG["similarity_threshold"]
                    )
                self._refresh_embedding_cache()
            else:
                self.embedding_manager = None
        except Exception:
            # Do not leak the driver's connection pool when setup fails half-way.
            self.driver.close()
            raise

    def close(self):
        """Close the underlying Neo4j driver."""
        self.driver.close()

    def _debug(self, message: str) -> None:
        """Append one "[DEBUG-KG] <message>" line to the debug log, if one is configured."""
        if self.debug_log_path:
            with open(self.debug_log_path, "a", encoding="utf-8") as f:
                f.write(f"[DEBUG-KG] {message}\n")

    def _refresh_embedding_cache(self) -> None:
        """Re-embed every entity currently stored in the graph (no-op when embeddings are off)."""
        if not self.use_embedding or not self.embedding_manager:
            return
        entities = self.get_all_entities()
        self.embedding_manager.cache_all_entities(entities)
        self._debug(f"Cached {len(entities)} entity embeddings")

    # ==================== Basic operations ====================

    def run_query(self, query: str, params: Optional[dict] = None) -> List[dict]:
        """Run a Cypher query and return every record as a plain dict."""
        with self.driver.session(database=self.database) as session:
            result = session.run(query, params or {})
            return [dict(record) for record in result]

    def run_write(self, query: str, params: Optional[dict] = None) -> None:
        """Run a Cypher write query, consuming the result so server errors are raised here."""
        with self.driver.session(database=self.database) as session:
            session.run(query, params or {}).consume()

    # ==================== Note operations ====================

    def create_note(self, note: dict) -> None:
        """
        Create (or update) a Note node.

        `note` must contain "id"; "session", "session_date", "user" and "assistant"
        default to ''. The sequence number is the integer after a leading "N" in the
        id (e.g. "N12" -> 12), or 0 for ids without that prefix.
        """
        text = f"{note.get('user', '')} {note.get('assistant', '')}"
        self.run_write("""
            MERGE (n:Note {id: $id})
            SET n.session = $session,
                n.session_date = $session_date,
                n.user = $user,
                n.assistant = $assistant,
                n.text = $text,
                n.seq = $seq
        """, {
            "id": note["id"],
            "session": note.get("session", ""),
            "session_date": note.get("session_date", ""),
            "user": note.get("user", ""),
            "assistant": note.get("assistant", ""),
            "text": text,
            "seq": int(note["id"][1:]) if note["id"].startswith("N") else 0
        })

    def create_temporal_edges(self) -> None:
        """Create NEXT edges between Notes whose seq values are consecutive."""
        self.run_write("""
            MATCH (n1:Note), (n2:Note)
            WHERE n2.seq = n1.seq + 1
            MERGE (n1)-[:NEXT]->(n2)
        """)

    def get_note_by_id(self, note_id: str) -> Optional[dict]:
        """Return the Note with the given id as a dict, or None if it does not exist."""
        result = self.run_query("MATCH (n:Note {id: $id}) RETURN n", {"id": note_id})
        return dict(result[0]["n"]) if result else None

    def get_notes_by_ids(self, note_ids: List[str]) -> List[dict]:
        """Return the Notes whose ids are in `note_ids`, ordered by seq ascending."""
        result = self.run_query("""
            MATCH (n:Note) WHERE n.id IN $ids
            RETURN n ORDER BY n.seq ASC
        """, {"ids": note_ids})
        return [dict(r["n"]) for r in result]

    def get_all_notes(self) -> List[dict]:
        """Return every Note ordered by seq ascending."""
        result = self.run_query("MATCH (n:Note) RETURN n ORDER BY n.seq ASC")
        return [dict(r["n"]) for r in result]

    def get_recent_notes(self, before_seq: int, limit: int = 10) -> List[dict]:
        """Return up to `limit` Notes with seq < `before_seq`, oldest first."""
        result = self.run_query("""
            MATCH (n:Note)
            WHERE n.seq < $seq
            RETURN n
            ORDER BY n.seq DESC
            LIMIT $limit
        """, {"seq": before_seq, "limit": limit})
        # The query takes the newest `limit` rows; reverse them into chronological order.
        return [dict(r["n"]) for r in reversed(result)]

    # ==================== Entity operations ====================

    def create_or_update_entity(
        self,
        name: str,
        display_name: str = "",
        entity_type: str = "",
        summary: str = "",
        heat: int = 1,
        current_state: str = "",
        updated_at: str = ""
    ) -> None:
        """
        Create an entity or update an existing one (simplified version).

        Empty optional arguments leave the stored value untouched (or '' on a new
        node); `heat` is added to the stored heat.

        Args:
            name: canonical entity name (MERGE key).
            display_name: defaults to `name`.
            entity_type: entity type label stored in `e.type`.
            summary: entity summary.
            heat: amount added to the entity's heat counter.
            current_state: description of the entity's current state.
            updated_at: note_id of the last update.
        """
        if not display_name:
            display_name = name
        rows = self.run_query("""
            MERGE (e:Entity {name: $name})
            SET e.display_name = $display_name,
                e.type = CASE WHEN $entity_type <> '' THEN $entity_type ELSE COALESCE(e.type, '') END,
                e.summary = CASE WHEN $summary <> '' THEN $summary ELSE COALESCE(e.summary, '') END,
                e.heat = COALESCE(e.heat, 0) + $heat,
                e.current_state = CASE WHEN $current_state <> '' THEN $current_state ELSE COALESCE(e.current_state, '') END,
                e.updated_at = CASE WHEN $updated_at <> '' THEN $updated_at ELSE COALESCE(e.updated_at, '') END
            RETURN e.summary AS summary
        """, {
            "name": name,
            "display_name": display_name,
            "entity_type": entity_type,
            "summary": summary,
            "heat": heat,
            "current_state": current_state,
            "updated_at": updated_at
        })
        self._alias_map[name.lower()] = name
        if self.use_embedding and self.embedding_manager:
            # Embed the summary actually stored on the node (an empty `summary` argument keeps
            # the previous one), matching what _refresh_embedding_cache would cache.
            stored_summary = rows[0]["summary"] if rows else summary
            self.embedding_manager.cache_entity_embedding(name, stored_summary)

    def update_entity_summary(self, name: str, summary: str) -> None:
        """Overwrite an existing entity's summary and refresh its cached embedding."""
        rows = self.run_query("""
            MATCH (e:Entity {name: $name})
            SET e.summary = $summary
            RETURN e.name AS name
        """, {"name": name, "summary": summary})
        # Keep the embedding cache in sync with the stored summary (only if the entity exists).
        if rows and self.use_embedding and self.embedding_manager:
            self.embedding_manager.cache_entity_embedding(name, summary)

    def get_entity(self, name: str) -> Optional[dict]:
        """Resolve `name` to a canonical entity and return its properties, or None."""
        canonical = self.resolve_alias(name)
        if not canonical:
            return None
        result = self.run_query("MATCH (e:Entity {name: $name}) RETURN e", {"name": canonical})
        return dict(result[0]["e"]) if result else None

    def get_all_entities(self) -> List[dict]:
        """Return every entity, ordered by heat descending."""
        result = self.run_query("""
            MATCH (e:Entity)
            RETURN e
            ORDER BY e.heat DESC
        """)
        return [dict(r["e"]) for r in result]

    def resolve_alias(self, name: str) -> Optional[str]:
        """
        Resolve a (possibly differently-cased or aliased) name to a canonical entity name.

        Order: in-memory cache, then a case-insensitive exact match on `e.name`
        (a byte-exact match wins), then an embedding similarity match when enabled.
        Successful resolutions are cached. Returns None when nothing matches.
        """
        name_lower = name.lower()
        if name_lower in self._alias_map:
            return self._alias_map[name_lower]

        # Strategy 1: exact match. Entity names keep their original casing, so compare
        # case-insensitively; prefer a byte-exact hit when several nodes qualify.
        result = self.run_query("""
            MATCH (e:Entity)
            WHERE e.name = $name OR toLower(e.name) = $name_lower
            RETURN e.name AS name
            ORDER BY CASE WHEN name = $name THEN 0 ELSE 1 END, name
            LIMIT 1
        """, {"name": name, "name_lower": name_lower})
        if result:
            canonical = result[0]["name"]
            self._alias_map[name_lower] = canonical
            return canonical

        # Strategy 2: embedding similarity (only when the exact match missed).
        if self.use_embedding and self.embedding_manager:
            entities = self.get_all_entities()
            if entities:
                match, score = self.embedding_manager.find_best_match(
                    query=name,
                    candidates=[
                        {"name": e.get("name", ""), "summary": e.get("summary", "")}
                        for e in entities
                    ],
                    return_score=True
                )
                if match:
                    self._alias_map[name_lower] = match
                    self._debug(f"Embedding match: '{name}' -> '{match}' (score={score:.4f})")
                    return match

        self._debug(f"No match found for: '{name}'")
        return None

    def _resolve_names(self, names: List[str]) -> List[str]:
        """Resolve each name once, keeping input order and dropping names with no match."""
        resolved = (self.resolve_alias(n) for n in names)
        return [canonical for canonical in resolved if canonical]

    def link_entity_to_note(self, entity_name: str, note_id: str) -> None:
        """
        Create an Entity -[:MENTIONED_IN]-> Note relation (the LineNode replacement).

        `entity_name` is used as-is (no alias resolution); nothing is created if either
        node is missing.
        """
        self.run_write("""
            MATCH (e:Entity {name: $entity})
            MATCH (n:Note {id: $note_id})
            MERGE (e)-[:MENTIONED_IN]->(n)
        """, {"entity": entity_name, "note_id": note_id})

    def get_entity_notes(self, entity_name: str, limit: int = 50) -> List[dict]:
        """Return up to `limit` Notes linked to the entity via MENTIONED_IN, ordered by seq."""
        canonical = self.resolve_alias(entity_name)
        if not canonical:
            return []
        result = self.run_query("""
            MATCH (e:Entity {name: $name})-[:MENTIONED_IN]->(n:Note)
            RETURN DISTINCT n
            ORDER BY n.seq ASC
            LIMIT $limit
        """, {"name": canonical, "limit": limit})
        return [dict(r["n"]) for r in result]

    # ==================== Graph traversal retrieval ====================

    def find_common_connections(self, entity_names: List[str]) -> List[dict]:
        """
        Return nodes directly reachable from both entities (intersection query).

        Only the two-entity case is implemented; fewer than two resolvable names,
        or more than two, yield [].
        """
        if len(entity_names) < 2:
            return []
        canonical_names = self._resolve_names(entity_names)
        if len(canonical_names) < 2:
            return []
        if len(canonical_names) == 2:
            result = self.run_query("""
                MATCH (e1:Entity {name: $name1})-[r1]->(common)<-[r2]-(e2:Entity {name: $name2})
                WHERE common:Entity OR common:Activity OR common:Location OR common:Event
                RETURN common, type(r1) as rel1, type(r2) as rel2
            """, {"name1": canonical_names[0], "name2": canonical_names[1]})
            return [{"common": dict(r["common"]), "rel1": r["rel1"], "rel2": r["rel2"]} for r in result]
        return []

    def graph_traversal(
        self,
        start_entities: List[str],
        max_hops: int = 2,
        limit: int = 50,
        current_seq: Optional[int] = None,
        apply_heat_decay: bool = True
    ) -> List[dict]:
        """
        Retrieve Notes linked to the start entities via MENTIONED_IN (simplified version).

        Each Note scores decay_factor ** (current_seq - seq) when `apply_heat_decay`,
        else 1.0, multiplied by (1 + entity_heat * 0.01); a Note reached through several
        entities keeps its best score. `max_hops` is accepted for interface compatibility
        but unused here (there is no multi-hop expansion).

        Returns up to `limit` Note dicts, highest score first, newer seq first on ties.
        """
        canonical_names = self._resolve_names(start_entities)
        if not canonical_names:
            return []

        if current_seq is None:
            max_seq_result = self.run_query("MATCH (n:Note) RETURN max(n.seq) as max_seq")
            max_seq = max_seq_result[0]["max_seq"] if max_seq_result else None
            # max() returns a single null row on an empty graph; fall back to 100 then too.
            current_seq = max_seq if max_seq is not None else 100

        note_scores = {}
        heat_decay_factor = MTM_CONFIG["heat_decay_factor"]

        if apply_heat_decay:
            direct = self.run_query("""
                MATCH (e:Entity)-[:MENTIONED_IN]->(n:Note)
                WHERE e.name IN $names
                WITH e, n, ($current_seq - n.seq) as age
                RETURN DISTINCT n,
                       1.0 * ($decay_factor ^ age) as score,
                       e.heat as entity_heat
            """, {
                "names": canonical_names,
                "current_seq": current_seq,
                "decay_factor": heat_decay_factor
            })
        else:
            direct = self.run_query("""
                MATCH (e:Entity)-[:MENTIONED_IN]->(n:Note)
                WHERE e.name IN $names
                RETURN DISTINCT n, 1.0 as score, e.heat as entity_heat
            """, {"names": canonical_names})

        for r in direct:
            note_id = r["n"]["id"]
            # The key is always present; a null heat arrives as None, so default explicitly.
            entity_heat = r.get("entity_heat")
            if entity_heat is None:
                entity_heat = 1
            final_score = r["score"] * (1 + entity_heat * 0.01)
            if note_id not in note_scores or note_scores[note_id]["score"] < final_score:
                note_scores[note_id] = {"note": dict(r["n"]), "score": final_score}

        # Highest score first; on equal scores the newer Note (larger seq) wins.
        sorted_notes = sorted(note_scores.values(), key=lambda x: (-x["score"], -x["note"]["seq"]))
        return [item["note"] for item in sorted_notes[:limit]]

    # ==================== Global summary ====================

    def _load_global_summary(self) -> str:
        """Load the global summary from the GlobalState node ('' if absent)."""
        result = self.run_query("""
            MATCH (g:GlobalState {type: 'summary'})
            RETURN g.content as content
        """)
        return result[0]["content"] if result else ""

    def get_global_summary(self) -> str:
        """Return the in-memory copy of the global summary."""
        return self._global_summary

    def set_global_summary(self, summary: str) -> None:
        """Store the global summary in memory and on the GlobalState node."""
        self._global_summary = summary
        self.run_write("""
            MERGE (g:GlobalState {type: 'summary'})
            SET g.content = $content, g.updated_at = timestamp()
        """, {"content": summary})

    # ==================== Statistics ====================

    def get_stats(self) -> dict:
        """Return node/edge counts plus an `entity_<type>` count per non-empty entity type."""
        stats = {}
        count_queries = {
            "note_count": "MATCH (n:Note) RETURN count(n) as c",
            "entity_count": "MATCH (e:Entity) RETURN count(e) as c",
            "mentioned_in_edges": "MATCH ()-[r:MENTIONED_IN]->() RETURN count(r) as c",
            "temporal_edges": "MATCH ()-[r:NEXT]->() RETURN count(r) as c",
        }
        for key, query in count_queries.items():
            result = self.run_query(query)
            stats[key] = result[0]["c"] if result else 0

        result = self.run_query("""
            MATCH (e:Entity)
            RETURN e.type as type, count(e) as count
        """)
        for r in result:
            if r["type"]:
                stats[f"entity_{r['type']}"] = r["count"]
        return stats

    # ==================== Cleanup ====================

    def clear_all(self) -> None:
        """Delete every node and relation in the database and reset in-memory state."""
        self.run_write("MATCH (n) DETACH DELETE n")
        self._alias_map.clear()
        self._global_summary = ""

    def clear_entities_and_lines(self) -> None:
        """Delete all Entity and GlobalState nodes (and their relations), keeping Notes."""
        self.run_write("MATCH (e:Entity) DETACH DELETE e")
        self.run_write("MATCH (g:GlobalState) DELETE g")
        self._alias_map.clear()
        self._global_summary = ""
# Keep the same class-name alias as the baseline so callers can swap implementations.
KnowledgeGraphV2 = KnowledgeGraphNoLineNode
if __name__ == "__main__":
    # Smoke check: connect, print graph statistics, and always release the driver.
    kg = KnowledgeGraphNoLineNode()
    try:
        print("Stats:", kg.get_stats())
    finally:
        kg.close()