Skip to content

Commit 8257454

Browse files
init
1 parent e92da9c commit 8257454

7 files changed

Lines changed: 995 additions & 1 deletion

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
/data/
2+
3+
14
# Byte-compiled / optimized / DLL files
25
__pycache__/
36
*.py[cod]

README.md

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,21 @@
1-
# complex2
1+
# Implementation of the $Complex^2$ method proposed in "How to Turn Your Knowledge Graph Embeddings into Generative Models"
2+
3+
This repository implements a heterogeneous-graph version of $Complex^2$ that implicitly constrains edge connections within edge type.
4+
5+
Original citation:
6+
```
7+
@misc{loconte2024turnknowledgegraphembeddings,
8+
title={How to Turn Your Knowledge Graph Embeddings into Generative Models},
9+
author={Lorenzo Loconte and Nicola Di Mauro and Robert Peharz and Antonio Vergari},
10+
year={2024},
11+
eprint={2305.15944},
12+
archivePrefix={arXiv},
13+
primaryClass={cs.LG},
14+
url={https://arxiv.org/abs/2305.15944},
15+
}
16+
```
17+
18+
19+
20+
21+

biokg.ipynb

Lines changed: 744 additions & 0 deletions
Large diffs are not rendered by default.

complex2/data/TriplesDataset.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
2+
3+
import torch
4+
from torch.utils.data import Dataset
5+
import sys
6+
import numpy as np
7+
sys.path.append('../')
8+
9+
class TriplesDataset(Dataset):
    """Map-style dataset of (head, tail, relation) knowledge-graph triples.

    Wraps parallel arrays of entity/relation integer ids and yields one
    positive triple per index as long tensors.
    """

    def __init__(self, triples, filter_to_relation=None):
        """
        Parameters
        ----------
        triples : mapping with keys 'head', 'tail', 'relation', each an
            array-like of integer ids of equal length.
        filter_to_relation : optional sequence of relation ids; when given,
            only triples whose relation is in this set are kept.
        """
        self.pos_heads = torch.tensor(triples['head'], dtype=torch.long)
        self.pos_tails = torch.tensor(triples['tail'], dtype=torch.long)
        self.pos_relations = torch.tensor(triples['relation'], dtype=torch.long)

        if filter_to_relation is not None:
            # Boolean mask is equivalent to (and simpler than) the previous
            # isin(...).nonzero(...)[0] integer-index round trip.
            keep = torch.isin(self.pos_relations,
                              torch.tensor(filter_to_relation, dtype=torch.long))
            self.pos_heads = self.pos_heads[keep]
            self.pos_tails = self.pos_tails[keep]
            self.pos_relations = self.pos_relations[keep]

    def __len__(self):
        """Number of stored triples."""
        return len(self.pos_heads)

    def __getitem__(self, idx):
        """Return the (head, tail, relation) id tensors at position idx.

        The stored tensors are built from raw data and never require grad,
        so no detach() is needed here.
        """
        return self.pos_heads[idx], self.pos_tails[idx], self.pos_relations[idx]

complex2/models/ComplEx2.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
'''
2+
Heterogeneous version (distinct from the paper implementation) of
3+
How to Turn Your Knowledge Graph Embeddings into Generative Models (https://arxiv.org/pdf/2305.15944)
4+
'''
5+
import numpy as np
6+
import torch
7+
import torch_geometric as pyg
8+
import time
9+
import sys
10+
11+
12+
class ComplEx2(torch.nn.Module):
    """Heterogeneous-graph variant of the ComplEx^2 generative KGE model.

    Implements the squared-ComplEx circuit of "How to Turn Your Knowledge
    Graph Embeddings into Generative Models" (arXiv:2305.15944), modified so
    the partition function is computed per relation over only the node types
    that relation can connect, implicitly constraining edges to their edge
    type.  forward() returns per-triple log-probabilities.
    """

    def __init__(self, data, hidden_channels=512, scale_grad_by_freq=False, dtype=torch.float32, dropout=0.):
        # data: heterogeneous graph container exposing 'edge_reltype',
        #   'num_nodes_dict' and 'edge_index_dict' (PyG/OGB-style dataset) —
        #   NOTE(review): exact type not visible from this file; confirm.
        # hidden_channels: embedding width shared by every table.
        # scale_grad_by_freq: forwarded to torch.nn.Embedding.
        # dropout: probability used by set_dropout_masks() to zero embedding
        #   weights during training (0. disables masking).

        super().__init__()

        self.data = data
        # Relation integer id -> (head_type, relation, tail_type) key.
        # NOTE(review): `.item(0)` is the numpy signature; torch.Tensor.item()
        # takes no argument — presumably the 'edge_reltype' values are numpy
        # arrays here; verify against the data loader.
        self.rel2type = {v.item(0):k for k,v in data['edge_reltype'].items()}
        self.relations = self.rel2type.keys()
        self.hidden_channels = hidden_channels
        self.dtype = dtype

        embed_kwargs = {'max_norm':None, # this causes an error when not None by in-place modifications
                        'scale_grad_by_freq':scale_grad_by_freq,
                        'dtype':self.dtype}

        # these now represent categorical logit embeddings
        # (separate real/imaginary tables per node type, for head and tail roles)
        self.head_embedding_real_dict = torch.nn.ModuleDict({nodetype:torch.nn.Embedding(num_nodes, embedding_dim=hidden_channels, **embed_kwargs) for nodetype, num_nodes in self.data['num_nodes_dict'].items()})
        self.head_embedding_imag_dict = torch.nn.ModuleDict({nodetype:torch.nn.Embedding(num_nodes, embedding_dim=hidden_channels, **embed_kwargs) for nodetype, num_nodes in self.data['num_nodes_dict'].items()})
        self.tail_embedding_real_dict = torch.nn.ModuleDict({nodetype:torch.nn.Embedding(num_nodes, embedding_dim=hidden_channels, **embed_kwargs) for nodetype, num_nodes in self.data['num_nodes_dict'].items()})
        self.tail_embedding_imag_dict = torch.nn.ModuleDict({nodetype:torch.nn.Embedding(num_nodes, embedding_dim=hidden_channels, **embed_kwargs) for nodetype, num_nodes in self.data['num_nodes_dict'].items()})
        self.relation_real_embedding = torch.nn.Embedding(len(self.data['edge_index_dict']), embedding_dim=hidden_channels, **embed_kwargs)
        self.relation_imag_embedding = torch.nn.Embedding(len(self.data['edge_index_dict']), embedding_dim=hidden_channels, **embed_kwargs)

        # Node-type name -> 1-element long tensor id (kept for lookup use).
        self.nodetype2int = {nodetype:torch.tensor([i],dtype=torch.long) for i,nodetype in enumerate(data['num_nodes_dict'].keys())}

        # weight initialization
        for key in self.head_embedding_real_dict:
            self.init_params(self.head_embedding_real_dict[key].weight.data)
            self.init_params(self.head_embedding_imag_dict[key].weight.data)
            self.init_params(self.tail_embedding_real_dict[key].weight.data)
            self.init_params(self.tail_embedding_imag_dict[key].weight.data)
        self.init_params(self.relation_real_embedding.weight.data)
        self.init_params(self.relation_imag_embedding.weight.data)

        # Trying to emulate https://github.com/april-tools/gekcs/blob/main/src/kbc/gekc_models.py ; line 579
        # number of consistent triples, in this case edge type specific
        # I think this justifies the use of log1p
        # eps smoothing per edge type: 1 / (#possible head entities * #possible tail entities).
        self.eps_dict = {(h,r,t):1/(self.data['num_nodes_dict'][h] * self.data['num_nodes_dict'][t]) for (h,r,t) in self.data['edge_index_dict'].keys()}

        self.dropout = dropout

    def init_params(self, tensor, init_loc=0, init_scale=10e-3):
        # https://github.com/april-tools/gekcs/blob/main/src/kbc/distributions.py ## line 60
        # This initial outputs of ComplEx^2 will be approx. normally distributed and centered (in log-space)
        # Fills `tensor` in place with log-normal draws.
        # NOTE(review): the `init_loc` parameter is dead — it is unconditionally
        # overwritten on the next line; remove or honor it.
        init_loc = np.log(tensor.shape[-1]) / 3.0 + 0.5 * (init_scale ** 2)
        t = torch.exp(torch.randn(*tensor.shape, dtype=self.dtype) * init_scale - init_loc)
        tensor.copy_(t.float())

    def partition_function(self, headtype:str, relint:int, tailtype:str) -> torch.tensor:
        #Squared ComplEx partition function as described in "How to Turn our KGE in generative models"
        #Slightly different version than used in paper, since we are conditioning on relation. Faster and less memory.
        # Returns the scalar sum of squared ComplEx scores over ALL
        # (head, tail) pairs of the given node types for relation `relint`.
        # NOTE(review): reads self.do_node_real/self.do_node_imag, which only
        # exist after set_dropout_masks() has been called (forward() does this).
        #Subject (real) ~ Sr
        Sr = self.head_embedding_real_dict[headtype].weight
        Si = self.head_embedding_imag_dict[headtype].weight
        Pr = self.relation_real_embedding(relint).view(1,-1)
        Pi = self.relation_imag_embedding(relint).view(1,-1)
        Or = self.tail_embedding_real_dict[tailtype].weight
        Oi = self.tail_embedding_imag_dict[tailtype].weight

        if self.do_node_real is not None:
            # Apply the same per-weight dropout masks used in score() so the
            # partition function matches the masked scores.
            Sr = Sr*self.do_node_real[headtype]
            Si = Si*self.do_node_imag[headtype]
            Or = Or*self.do_node_real[tailtype]
            Oi = Oi*self.do_node_imag[tailtype]

        # d x d Gram matrices; summing over entities happens here, so the
        # remaining contractions are O(d^2) instead of O(#entities^2).
        SrSr = Sr.T@Sr ; OrOr = Or.T@Or
        SiSi = Si.T@Si ; OiOi = Oi.T@Oi
        SrSi = Sr.T@Si ; OrOi = Or.T@Oi
        SiSr = Si.T@Sr ; OiOr = Oi.T@Or
        # SrSi =/= SiSr

        # Expansion of (A + B + C - D)^2 summed over all head/tail pairs,
        # where A..D are the four terms of the ComplEx trilinear score.
        A2 = (Pr @ (SrSr * OrOr) @ Pr.T).sum()
        B2 = (Pr @ (SiSi * OiOi) @ Pr.T).sum()
        C2 = (Pi @ (SrSr * OiOi) @ Pi.T).sum()
        D2 = (Pi @ (SiSi * OrOr) @ Pi.T).sum()
        AB = (Pr @ (SrSi * OrOi) @ Pr.T).sum() # AB == BA
        AC = (Pr @ (SrSr * OrOi) @ Pi.T).sum()
        AD = (Pr @ (SrSi * OrOr) @ Pi.T).sum()
        BC = (Pr @ (SiSr * OiOi) @ Pi.T).sum()
        BD = (Pr @ (SiSi * OiOr) @ Pi.T).sum()
        CD = (Pi @ (SrSi * OiOr) @ Pi.T).sum()

        return A2 + B2 + C2 + D2 + 2*AB + 2*AC + 2*BC - 2*AD - 2*BD - 2*CD

    def score(self, head_idx, relation_idx, tail_idx, headtype, tailtype):
        # Squared ComplEx score for a batch of triples of a single edge type.
        # NOTE(review): like partition_function, this depends on
        # set_dropout_masks() having run first.
        u_re = self.head_embedding_real_dict[headtype](head_idx)
        u_im = self.head_embedding_imag_dict[headtype](head_idx)
        v_re = self.tail_embedding_real_dict[tailtype](tail_idx)
        v_im = self.tail_embedding_imag_dict[tailtype](tail_idx)
        r_re = self.relation_real_embedding(relation_idx)
        r_im = self.relation_imag_embedding(relation_idx)

        if self.do_node_real is not None:
            u_re *= self.do_node_real[headtype][head_idx]
            u_im *= self.do_node_imag[headtype][head_idx]
            v_re *= self.do_node_real[tailtype][tail_idx]
            v_im *= self.do_node_imag[tailtype][tail_idx]

        # Standard ComplEx trilinear score, then squared (ComplEx^2).
        scores = (triple_dot(u_re, r_re, v_re) +
                  triple_dot(u_im, r_re, v_im) +
                  triple_dot(u_re, r_im, v_im) -
                  triple_dot(u_im, r_im, v_re))**2

        return scores

    def set_dropout_masks(self, device):
        # Draws fresh Bernoulli keep-masks (one per node-type embedding table)
        # when training with dropout > 0; otherwise disables masking by
        # setting both mask dicts to None.
        if (self.training) and (self.dropout > 0):

            self.do_node_real = {}
            self.do_node_imag = {}
            for nodetype in self.head_embedding_imag_dict.keys():
                self.do_node_real[nodetype] = 1.*(torch.rand(size=self.head_embedding_imag_dict[nodetype].weight.size(), device=device) > self.dropout)
                self.do_node_imag[nodetype] = 1.*(torch.rand(size=self.head_embedding_imag_dict[nodetype].weight.size(), device=device) > self.dropout)
        else:
            self.do_node_real = None
            self.do_node_imag = None

    def forward(self, head, relation, tail):
        '''Return per-triple log-probabilities for (head, relation, tail) id batches.

        Triples are grouped by relation so each group shares one partition
        function restricted to that relation's head/tail node types.
        '''
        log_prob = torch.zeros((head.size(0)), dtype=self.dtype, device=head.device)

        self.set_dropout_masks(device=head.device)

        for rel in torch.unique(relation):
            h,r,t = self.rel2type[rel.item()]
            rel_idx = torch.nonzero(relation == rel, as_tuple=True)[0]
            Z = self.partition_function(headtype=h, relint=rel, tailtype=t)
            phi = self.score(head_idx = head[rel_idx],
                             relation_idx = relation[rel_idx],
                             tail_idx = tail[rel_idx],
                             headtype = h,
                             tailtype = t)

            # log((phi + eps) / (1 + Z)): eps smooths zero scores; log1p pairs
            # with the implicit +1 mass in the normalizer (see eps_dict note).
            log_prob[rel_idx] = torch.log(phi + self.eps_dict[(h,r,t)]) - torch.log1p(Z)

        return log_prob
152+
153+
def triple_dot(x, y, z):
    """Trilinear dot product: sum of the elementwise product of three
    tensors over the last axis."""
    return torch.sum(x * y * z, dim=-1)
155+

environment.yml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
name: complex2
2+
3+
channels:
4+
- pytorch
5+
- nvidia
6+
- pyg
7+
- defaults
8+
- conda-forge
9+
10+
dependencies:
11+
- python
12+
- matplotlib
13+
- pyg
14+
- pytorch
15+
- numpy
16+
- pandas
17+
- torchvision
18+
- seaborn
19+
- scikit-learn
20+
- pytorch-cuda
21+
- h5py
22+
- tensorboard
23+
- statsmodels
24+
- pip
25+
- pip:
26+
- ipykernel
27+
- biopython
28+
- openpyxl
29+
- ogb

setup.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Packaging script: installs the local `complex2` package.
from setuptools import setup, find_packages

# Packages are discovered from the repository root.
_PKG_ROOT = '.'

setup(
    name='complex2',
    version='0.1',
    package_dir={'': _PKG_ROOT},
    packages=find_packages(where=_PKG_ROOT),
)

0 commit comments

Comments
 (0)