-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathembedding_manager.py
More file actions
140 lines (113 loc) · 4.71 KB
/
embedding_manager.py
File metadata and controls
140 lines (113 loc) · 4.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Embedding Manager - 实体向量化和相似度匹配
使用本地 sentence-transformers 模型进行实体消歧
"""
from typing import Optional, List, Tuple, Dict
import numpy as np
from sentence_transformers import SentenceTransformer
class EmbeddingManager:
"""
管理实体的向量表示和相似度计算
设计:
- 实体embedding基于 entity.summary (如果有) 或 entity.name
- 使用余弦相似度进行entity matching
- 缓存所有entity的embedding以提升性能
"""
def __init__(self, model_path: str, similarity_threshold: float = 0.7):
"""
Args:
model_path: 本地模型路径
similarity_threshold: 相似度阈值,超过此值认为是同一实体
"""
self.model = SentenceTransformer(model_path)
self.similarity_threshold = similarity_threshold
self.embedding_dim = self.model.get_sentence_embedding_dimension()
# 缓存: entity_name -> embedding
self._embedding_cache: Dict[str, np.ndarray] = {}
def encode_entity(self, name: str, summary: Optional[str] = None) -> np.ndarray:
"""
生成实体的embedding
Args:
name: 实体名称
summary: 实体摘要(更丰富的语义信息)
Returns:
embedding向量
"""
# 优先使用summary(包含更多语义),回退到name
text = summary if summary else name
return self.model.encode(text, normalize_embeddings=True)
def compute_similarity(self, emb1: np.ndarray, emb2: np.ndarray) -> float:
"""计算两个embedding的余弦相似度"""
return float(np.dot(emb1, emb2))
def find_best_match(
self,
query: str,
candidates: List[Dict],
return_score: bool = False
) -> Optional[str] | Tuple[Optional[str], float]:
"""
从候选实体中找到最匹配的实体
Args:
query: 查询字符串(待消歧的实体名)
candidates: 候选实体列表,每个dict包含 {name, summary}
return_score: 是否返回相似度分数
Returns:
最匹配的实体名称(相似度超过阈值)或 None
如果 return_score=True,返回 (name, score) 或 (None, 0.0)
"""
if not candidates:
return (None, 0.0) if return_score else None
# 编码查询
query_emb = self.model.encode(query, normalize_embeddings=True)
# 计算与所有候选的相似度
best_score = -1.0
best_match = None
for candidate in candidates:
name = candidate.get("name", "")
summary = candidate.get("summary", "")
# 使用缓存
if name in self._embedding_cache:
candidate_emb = self._embedding_cache[name]
else:
candidate_emb = self.encode_entity(name, summary)
self._embedding_cache[name] = candidate_emb
score = self.compute_similarity(query_emb, candidate_emb)
if score > best_score:
best_score = score
best_match = name
# 检查是否超过阈值
if best_score >= self.similarity_threshold:
return (best_match, best_score) if return_score else best_match
else:
return (None, 0.0) if return_score else None
def batch_encode(self, texts: List[str]) -> np.ndarray:
"""批量编码文本"""
return self.model.encode(texts, normalize_embeddings=True, batch_size=32)
def cache_entity_embedding(self, name: str, summary: Optional[str] = None) -> None:
"""预先缓存实体的embedding"""
if name not in self._embedding_cache:
emb = self.encode_entity(name, summary)
self._embedding_cache[name] = emb
def cache_all_entities(self, entities: List[Dict]) -> None:
"""批量缓存所有实体的embedding"""
# 收集未缓存的实体
to_cache = []
names = []
for entity in entities:
name = entity.get("name", "")
if name and name not in self._embedding_cache:
summary = entity.get("summary", "")
text = summary if summary else name
to_cache.append(text)
names.append(name)
# 批量编码
if to_cache:
embeddings = self.batch_encode(to_cache)
for name, emb in zip(names, embeddings):
self._embedding_cache[name] = emb
def clear_cache(self) -> None:
"""清空embedding缓存"""
self._embedding_cache.clear()
def get_cache_size(self) -> int:
"""获取缓存的实体数量"""
return len(self._embedding_cache)