-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathalign_embs.py
More file actions
49 lines (39 loc) · 1.37 KB
/
align_embs.py
File metadata and controls
49 lines (39 loc) · 1.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
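"""
Align embedding spaces with an orthogonal (Procrustes) rotation.

Each embedding is a dict mapping a key (e.g. a word) to its vector, e.g.:

    embs = {
        "a": [0, 0, 1, ..., 0],
        "b": [1, 0, 0, ..., 0],
        ...
    }
"""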
import numpy as np
def align_two_embs(emb_to_align, emb_base, common_keys=None):
    """
    Rotate ``emb_to_align`` into the space of ``emb_base`` (orthogonal Procrustes).

    The rotation is fitted only on the anchor keys (``common_keys``) and is
    then applied to every vector of ``emb_to_align``, including keys that are
    absent from ``emb_base``.

    :param emb_to_align: embedding vectors to be aligned (dict: key -> vector)
    :param emb_base: base embedding vectors (dict: key -> vector)
    :param common_keys: keys used to fit the rotation; when None (default),
        the intersection of the keys of both embeddings is used
    :return:
        aligned embeddings of emb_to_align (dict: key -> numpy array)
    :raises ValueError: if there is no anchor key to fit the rotation on
    """
    if common_keys is None:
        common_keys = list(set(emb_to_align).intersection(emb_base))
    else:
        # Materialize so sets/generators are iterated exactly once, in a fixed order.
        common_keys = list(common_keys)
    if not common_keys:
        raise ValueError("emb_to_align and emb_base share no keys to align on")

    # Stack anchor vectors as columns; assuming 1-D vectors this gives
    # A: (dim_to_align, n_keys) and B: (dim_base, n_keys).
    A = np.array([emb_to_align[key] for key in common_keys]).T
    B = np.array([emb_base[key] for key in common_keys]).T

    # Orthogonal Procrustes: with B A^T = U S V^T, R = U V^T minimises ||R A - B||_F.
    M = B.dot(A.T)
    # Thin SVD gives the same R as the full SVD when M is square, and still
    # yields a (semi-)orthogonal map when the two dimensions differ.
    u, _sigma, v_t = np.linalg.svd(M, full_matrices=False)
    rotation_matrix = u.dot(v_t)

    return {k: rotation_matrix.dot(v) for k, v in emb_to_align.items()}
def align_list_of_embs(emb_list, emb_base):
    """
    Align each embedding in ``emb_list`` to ``emb_base``.

    Every rotation is fitted on the same anchor set: the keys present in
    ``emb_base`` and in *all* embeddings of ``emb_list``.

    :param emb_list: list (or any iterable) of embeddings to be aligned
        (dicts: key -> vector)
    :param emb_base: base embedding vectors (dict: key -> vector)
    :return:
        list of aligned embeddings, in the same order as emb_list
        (an empty list when emb_list is empty)
    :raises ValueError: if no key is shared by emb_base and every embedding
    """
    # Materialize once: the input is traversed twice (key intersection, then alignment).
    emb_list = list(emb_list)
    if not emb_list:
        return []

    # set.intersection accepts any iterables; iterating a dict yields its keys.
    common_keys = list(set(emb_base).intersection(*emb_list))
    if not common_keys:
        raise ValueError("no key is shared by emb_base and every embedding in emb_list")

    return [align_two_embs(emb, emb_base, common_keys) for emb in emb_list]