|
'''
Heterogeneous version (distinct from the paper's reference implementation) of
"How to Turn Your Knowledge Graph Embeddings into Generative Models"
(https://arxiv.org/pdf/2305.15944).
'''
| 5 | +import numpy as np |
| 6 | +import torch |
| 7 | +import torch_geometric as pyg |
| 8 | +import time |
| 9 | +import sys |
| 10 | + |
| 11 | + |
class ComplEx2(torch.nn.Module):
    '''
    Heterogeneous squared-ComplEx ("ComplEx^2") generative knowledge-graph
    embedding model.

    Each node type gets its own head/tail real+imaginary embedding tables, and
    each edge type (relation) gets one shared real+imaginary relation
    embedding.  The squared ComplEx score is treated as an unnormalized
    probability mass over (head, relation, tail) triples; `forward` returns
    per-triple log-probabilities normalized by an exactly computed partition
    function per edge type.

    Parameters
    ----------
    data : mapping with keys 'edge_reltype', 'num_nodes_dict' and
        'edge_index_dict' (looks like an OGB-style hetero-graph dict —
        TODO confirm against the caller).
    hidden_channels : int
        Embedding dimensionality of each real/imaginary part.
    scale_grad_by_freq : bool
        Forwarded to every torch.nn.Embedding.
    dtype : torch.dtype
        Dtype of all embedding tables.
    dropout : float
        Drop probability for the shared embedding-weight dropout masks
        (see set_dropout_masks); 0 disables them.
    '''

    def __init__(self, data, hidden_channels=512, scale_grad_by_freq=False, dtype=torch.float32, dropout=0.):

        super().__init__()

        self.data = data
        # Map integer relation id -> (head_type, relation_name, tail_type) key.
        # NOTE(review): torch.Tensor.item() takes no arguments, so `.item(0)`
        # only works if the values of data['edge_reltype'] are numpy arrays
        # (or expose numpy's item(index)) — confirm the data format.
        self.rel2type = {v.item(0):k for k,v in data['edge_reltype'].items()}
        self.relations = self.rel2type.keys()
        self.hidden_channels = hidden_channels
        self.dtype = dtype

        embed_kwargs = {'max_norm':None, # this causes an error when not None by in-place modifications
                        'scale_grad_by_freq':scale_grad_by_freq,
                        'dtype':self.dtype}

        # These now represent categorical logit embeddings: one (num_nodes x
        # hidden) table per node type for each of head/tail x real/imag roles.
        self.head_embedding_real_dict = torch.nn.ModuleDict({nodetype:torch.nn.Embedding(num_nodes, embedding_dim=hidden_channels, **embed_kwargs) for nodetype, num_nodes in self.data['num_nodes_dict'].items()})
        self.head_embedding_imag_dict = torch.nn.ModuleDict({nodetype:torch.nn.Embedding(num_nodes, embedding_dim=hidden_channels, **embed_kwargs) for nodetype, num_nodes in self.data['num_nodes_dict'].items()})
        self.tail_embedding_real_dict = torch.nn.ModuleDict({nodetype:torch.nn.Embedding(num_nodes, embedding_dim=hidden_channels, **embed_kwargs) for nodetype, num_nodes in self.data['num_nodes_dict'].items()})
        self.tail_embedding_imag_dict = torch.nn.ModuleDict({nodetype:torch.nn.Embedding(num_nodes, embedding_dim=hidden_channels, **embed_kwargs) for nodetype, num_nodes in self.data['num_nodes_dict'].items()})
        # One embedding row per edge type (relation), shared across node types.
        self.relation_real_embedding = torch.nn.Embedding(len(self.data['edge_index_dict']), embedding_dim=hidden_channels, **embed_kwargs)
        self.relation_imag_embedding = torch.nn.Embedding(len(self.data['edge_index_dict']), embedding_dim=hidden_channels, **embed_kwargs)

        # Integer id per node type, kept as 1-element long tensors.
        self.nodetype2int = {nodetype:torch.tensor([i],dtype=torch.long) for i,nodetype in enumerate(data['num_nodes_dict'].keys())}

        # Weight initialization: log-space init so initial outputs are
        # approximately log-normally distributed (see init_params).
        for key in self.head_embedding_real_dict:
            self.init_params(self.head_embedding_real_dict[key].weight.data)
            self.init_params(self.head_embedding_imag_dict[key].weight.data)
            self.init_params(self.tail_embedding_real_dict[key].weight.data)
            self.init_params(self.tail_embedding_imag_dict[key].weight.data)
        self.init_params(self.relation_real_embedding.weight.data)
        self.init_params(self.relation_imag_embedding.weight.data)

        # Trying to emulate https://github.com/april-tools/gekcs/blob/main/src/kbc/gekc_models.py ; line 579
        # Per edge type, eps is the uniform probability of a (head, tail) pair:
        # 1 / (num possible heads * num possible tails).  Used to smooth
        # phi before taking the log in forward(); the author notes this is
        # what justifies pairing it with log1p on the partition function.
        self.eps_dict = {(h,r,t):1/(self.data['num_nodes_dict'][h] * self.data['num_nodes_dict'][t]) for (h,r,t) in self.data['edge_index_dict'].keys()}

        self.dropout = dropout

    def init_params(self, tensor, init_loc=0, init_scale=10e-3):
        '''
        In-place log-normal initialization of an embedding weight tensor.

        https://github.com/april-tools/gekcs/blob/main/src/kbc/distributions.py ## line 60
        The initial outputs of ComplEx^2 will be approx. normally distributed
        and centered (in log-space).

        NOTE(review): the `init_loc` parameter is immediately overwritten by
        the dimension-dependent expression below, so the argument is dead.
        NOTE(review): `10e-3` is 0.01 — verify this wasn't meant to be 1e-3.
        '''
        # Shift chosen so products of hidden_channels such entries stay
        # centered in log-space.
        init_loc = np.log(tensor.shape[-1]) / 3.0 + 0.5 * (init_scale ** 2)
        t = torch.exp(torch.randn(*tensor.shape, dtype=self.dtype) * init_scale - init_loc)
        tensor.copy_(t.float())

    def partition_function(self, headtype:str, relint: torch.Tensor, tailtype:str) -> torch.Tensor:
        '''
        Exact squared-ComplEx partition function as described in "How to Turn
        Your KGE into generative models", summed over all (head, tail) pairs
        of the given node types, conditioned on one relation.

        Slightly different from the paper's version since we condition on the
        relation: faster and less memory.  Works on D x D Gram matrices
        instead of materializing all N_head x N_tail scores.

        `relint` is a 0-d/1-element long tensor indexing the relation
        embeddings (forward() passes an element of torch.unique(relation)).
        Returns a scalar tensor.
        '''
        # Subject (real) ~ Sr ; Subject (imag) ~ Si ; Predicate ~ Pr/Pi ;
        # Object ~ Or/Oi.
        Sr = self.head_embedding_real_dict[headtype].weight
        Si = self.head_embedding_imag_dict[headtype].weight
        Pr = self.relation_real_embedding(relint).view(1,-1)
        Pi = self.relation_imag_embedding(relint).view(1,-1)
        Or = self.tail_embedding_real_dict[tailtype].weight
        Oi = self.tail_embedding_imag_dict[tailtype].weight

        # Apply the same dropout masks used in score() so the partition
        # function matches the dropped-out scores during training.
        if self.do_node_real is not None:
            Sr = Sr*self.do_node_real[headtype]
            Si = Si*self.do_node_imag[headtype]
            Or = Or*self.do_node_real[tailtype]
            Oi = Oi*self.do_node_imag[tailtype]

        # D x D Gram matrices over the node dimension.
        SrSr = Sr.T@Sr ; OrOr = Or.T@Or
        SiSi = Si.T@Si ; OiOi = Oi.T@Oi
        SrSi = Sr.T@Si ; OrOi = Or.T@Oi
        SiSr = Si.T@Sr ; OiOr = Oi.T@Or
        # SrSi =/= SiSr (these cross-Grams are transposes, not equal)

        # Expansion of sum over (h,t) of (A + B + C - D)^2 where
        # A = <sr,pr,or>, B = <si,pr,oi>, C = <sr,pi,oi>, D = <si,pi,or>.
        A2 = (Pr @ (SrSr * OrOr) @ Pr.T).sum()
        B2 = (Pr @ (SiSi * OiOi) @ Pr.T).sum()
        C2 = (Pi @ (SrSr * OiOi) @ Pi.T).sum()
        D2 = (Pi @ (SiSi * OrOr) @ Pi.T).sum()
        AB = (Pr @ (SrSi * OrOi) @ Pr.T).sum() # AB == BA
        AC = (Pr @ (SrSr * OrOi) @ Pi.T).sum()
        AD = (Pr @ (SrSi * OrOr) @ Pi.T).sum()
        BC = (Pr @ (SiSr * OiOi) @ Pi.T).sum()
        BD = (Pr @ (SiSi * OiOr) @ Pi.T).sum()
        CD = (Pi @ (SrSi * OiOr) @ Pi.T).sum()

        # Signs follow from squaring (A + B + C - D).
        return A2 + B2 + C2 + D2 + 2*AB + 2*AC + 2*BC - 2*AD - 2*BD - 2*CD


    def score(self, head_idx, relation_idx, tail_idx, headtype, tailtype):
        '''
        Squared ComplEx score for a batch of triples of one edge type.

        All index arguments are 1-D long tensors of equal length; headtype /
        tailtype select the embedding tables.  Returns a 1-D tensor of
        non-negative scores (the ComplEx score squared).
        '''
        u_re = self.head_embedding_real_dict[headtype](head_idx)
        u_im = self.head_embedding_imag_dict[headtype](head_idx)
        v_re = self.tail_embedding_real_dict[tailtype](tail_idx)
        v_im = self.tail_embedding_imag_dict[tailtype](tail_idx)
        r_re = self.relation_real_embedding(relation_idx)
        r_im = self.relation_imag_embedding(relation_idx)

        # Same masks as in partition_function, gathered per row.
        if self.do_node_real is not None:
            u_re *= self.do_node_real[headtype][head_idx]
            u_im *= self.do_node_imag[headtype][head_idx]
            v_re *= self.do_node_real[tailtype][tail_idx]
            v_im *= self.do_node_imag[tailtype][tail_idx]

        # Standard ComplEx trilinear score Re(<u, r, conj(v)>), then squared.
        scores = (triple_dot(u_re, r_re, v_re) +
                  triple_dot(u_im, r_re, v_im) +
                  triple_dot(u_re, r_im, v_im) -
                  triple_dot(u_im, r_im, v_re))**2

        return scores

    def set_dropout_masks(self, device):
        '''
        (Re)draw Bernoulli dropout masks over the full embedding tables, or
        clear them (set to None) in eval mode / when dropout == 0.

        One real and one imag mask per node type, each shaped like that
        type's embedding table; the same mask is reused for both the head
        and tail tables of that node type.  Called once per forward pass so
        score() and partition_function() see consistent masks.
        NOTE(review): masks are not rescaled by 1/(1-p) (no inverted
        dropout), so train/eval magnitudes differ — confirm intended.
        '''
        if (self.training) and (self.dropout > 0):

            self.do_node_real = {}
            self.do_node_imag = {}
            for nodetype in self.head_embedding_imag_dict.keys():
                self.do_node_real[nodetype] = 1.*(torch.rand(size=self.head_embedding_imag_dict[nodetype].weight.size(), device=device) > self.dropout)
                self.do_node_imag[nodetype] = 1.*(torch.rand(size=self.head_embedding_imag_dict[nodetype].weight.size(), device=device) > self.dropout)
        else:
            self.do_node_real = None
            self.do_node_imag = None

    def forward(self, head, relation, tail):
        '''
        Per-triple log-probability for a batch of (head, relation, tail)
        index tensors (1-D, equal length; relation holds integer edge-type
        ids).  Triples are grouped by relation; for each group the smoothed
        score log(phi + eps) is normalized by log1p of that edge type's
        partition function.  Returns a 1-D tensor aligned with the input.
        '''
        log_prob = torch.zeros((head.size(0)), dtype=self.dtype, device=head.device)

        self.set_dropout_masks(device=head.device)

        for rel in torch.unique(relation):
            h,r,t = self.rel2type[rel.item()]
            rel_idx = torch.nonzero(relation == rel, as_tuple=True)[0]
            Z = self.partition_function(headtype=h, relint=rel, tailtype=t)
            phi = self.score(head_idx = head[rel_idx],
                             relation_idx = relation[rel_idx],
                             tail_idx = tail[rel_idx],
                             headtype = h,
                             tailtype = t)

            # log1p pairs with the eps smoothing above (see eps_dict note).
            log_prob[rel_idx] = torch.log(phi + self.eps_dict[(h,r,t)]) - torch.log1p(Z)

        return log_prob
| 152 | + |
def triple_dot(x, y, z):
    """Batched trilinear form: sum over the last dim of the elementwise product x*y*z."""
    prod = torch.mul(torch.mul(x, y), z)
    return prod.sum(dim=-1)
| 155 | + |
0 commit comments