From c2ef1dd279718b9c56ccc3f540a4dc9c0fca7fcc Mon Sep 17 00:00:00 2001
From: pminervini
Date: Fri, 4 Jun 2021 15:57:23 +0200
Subject: [PATCH 01/25] CQD
---
cqd/__init__.py | 3 +
cqd/base.py | 416 +++++++++++++++++++++++++++++++++++++++
cqd/dataloader.py | 75 +++++++
cqd/discrete.py | 492 ++++++++++++++++++++++++++++++++++++++++++++++
cqd/util.py | 383 ++++++++++++++++++++++++++++++++++++
main.py | 100 +++++++---
6 files changed, 1442 insertions(+), 27 deletions(-)
create mode 100644 cqd/__init__.py
create mode 100644 cqd/base.py
create mode 100644 cqd/dataloader.py
create mode 100644 cqd/discrete.py
create mode 100644 cqd/util.py
diff --git a/cqd/__init__.py b/cqd/__init__.py
new file mode 100644
index 0000000..e50c1db
--- /dev/null
+++ b/cqd/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+
+from cqd.base import CQD
diff --git a/cqd/base.py b/cqd/base.py
new file mode 100644
index 0000000..cce7b83
--- /dev/null
+++ b/cqd/base.py
@@ -0,0 +1,416 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import torch
+import torch.nn as nn
+from torch import optim, Tensor
+import math
+
+from cqd.util import query_to_atoms, create_instructions, top_k_selection
+import cqd.discrete as d2
+
+from typing import Tuple, List, Optional, Dict
+
+
class N3:
    """Weighted N3 (nuclear 3-norm) regularizer for complex embeddings."""

    def __init__(self, weight: float):
        # Regularization strength applied to every factor.
        self.weight = weight

    def forward(self, factors):
        """Return the batch-averaged weighted sum of cubed factor magnitudes."""
        total = sum(self.weight * torch.sum(f.abs() ** 3) for f in factors)
        return total / factors[0].shape[0]
+
+
class CQD(nn.Module):
    """Continuous Query Decomposition model over ComplEx embeddings.

    Answers complex (multi-hop, conjunctive/disjunctive/negated) queries by
    combining per-atom ComplEx scores with a t-norm ('min' or 'prod').
    """

    # Supported t-norm names.
    MIN_NORM = 'min'
    PROD_NORM = 'prod'
    NORMS = {MIN_NORM, PROD_NORM}

    def __init__(self,
                 nentity: int,
                 nrelation: int,
                 rank: int,
                 init_size: float = 1e-3,
                 reg_weight: float = 1e-2,
                 test_batch_size: int = 1,
                 method: str = 'beam',
                 t_norm_name: str = 'prod',
                 k: int = 5,
                 query_name_dict: Optional[Dict] = None,
                 do_sigmoid: bool = False,
                 do_normalize: bool = False):
        """Build entity/relation embedding tables of size 2 * rank (real + imaginary).

        Args:
            nentity: number of entities.
            nrelation: number of relations.
            rank: ComplEx rank; each embedding stores rank real + rank imaginary parts.
            init_size: scale applied to the randomly initialized embeddings.
            reg_weight: weight of the N3 regularizer.
            test_batch_size: requested evaluation batch size (currently overridden below).
            method: inference strategy selector ('co', 'continuous', 'discrete', ...).
            t_norm_name: one of CQD.NORMS.
            k: beam width for top-k candidate selection.
            query_name_dict: maps query-structure tuples to names like '1p', '2i', ...
            do_sigmoid / do_normalize: optional score squashing for discrete inference.
        """
        super(CQD, self).__init__()

        self.rank = rank
        self.nentity = nentity
        self.nrelation = nrelation
        self.method = method
        self.t_norm_name = t_norm_name
        self.k = k
        self.query_name_dict = query_name_dict

        # embeddings[0] holds entities, embeddings[1] holds relations.
        sizes = (nentity, nrelation)
        self.embeddings = nn.ModuleList([nn.Embedding(s, 2 * rank, sparse=True) for s in sizes[:2]])
        # Shrink the default init so training starts near zero scores.
        self.embeddings[0].weight.data *= init_size
        self.embeddings[1].weight.data *= init_size

        self.init_size = init_size
        self.loss_fn = nn.CrossEntropyLoss(reduction='mean')
        self.regularizer = N3(reg_weight)

        self.do_sigmoid = do_sigmoid
        self.do_normalize = do_normalize

        # XXX: get rid of this hack
        # NOTE(review): this silently overrides the test_batch_size argument;
        # the buffer below is sized for batches of up to 1000 queries.
        test_batch_size = 1000
        batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1)
        self.register_buffer('batch_entity_range', batch_entity_range)
+
+ def split(self,
+ lhs_emb: Tensor,
+ rel_emb: Tensor,
+ rhs_emb: Tensor) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
+ lhs = lhs_emb[..., :self.rank], lhs_emb[..., self.rank:]
+ rel = rel_emb[..., :self.rank], rel_emb[..., self.rank:]
+ rhs = rhs_emb[..., :self.rank], rhs_emb[..., self.rank:]
+ return lhs, rel, rhs
+
    def loss(self,
             triples: Tensor) -> Tensor:
        """Training loss for a batch of (subject, relation, object) id triples.

        Sums cross-entropy in both prediction directions — object given
        (s, r) and subject given (r, o) — plus the N3 regularization term.
        """
        (scores_o, scores_s), factors = self.score_candidates(triples)
        # Column 2 holds the gold object ids, column 0 the gold subject ids.
        l_fit = self.loss_fn(scores_o, triples[:, 2]) + self.loss_fn(scores_s, triples[:, 0])
        l_reg = self.regularizer.forward(factors)
        return l_fit + l_reg
+
    def score_candidates(self,
                         triples: Tensor) -> Tuple[Tuple[Tensor, Tensor], Optional[List[Tensor]]]:
        """Score each triple against all candidate objects and subjects.

        Returns ((scores_o, scores_s), factors): scores_* are [batch, nentity]
        logits, and factors feed the N3 regularizer.
        """
        lhs_emb = self.embeddings[0](triples[:, 0])
        rel_emb = self.embeddings[1](triples[:, 1])
        rhs_emb = self.embeddings[0](triples[:, 2])
        # Rank every entity as a candidate.
        to_score = self.embeddings[0].weight
        scores_o, _ = self.score_o(lhs_emb, rel_emb, to_score)
        scores_s, _ = self.score_s(to_score, rel_emb, rhs_emb)
        lhs, rel, rhs = self.split(lhs_emb, rel_emb, rhs_emb)
        factors = self.get_factors(lhs, rel, rhs)
        return (scores_o, scores_s), factors
+
+ def score_o(self,
+ lhs_emb: Tensor,
+ rel_emb: Tensor,
+ rhs_emb: Tensor,
+ return_factors: bool = False) -> Tuple[Tensor, Optional[List[Tensor]]]:
+ lhs, rel, rhs = self.split(lhs_emb, rel_emb, rhs_emb)
+ score_1 = (lhs[0] * rel[0] - lhs[1] * rel[1]) @ rhs[0].transpose(-1, -2)
+ score_2 = (lhs[1] * rel[0] + lhs[0] * rel[1]) @ rhs[1].transpose(-1, -2)
+ factors = self.get_factors(lhs, rel, rhs) if return_factors else None
+ return score_1 + score_2, factors
+
+ def score_s(self,
+ lhs_emb: Tensor,
+ rel_emb: Tensor,
+ rhs_emb: Tensor,
+ return_factors: bool = False) -> Tuple[Tensor, Optional[List[Tensor]]]:
+ lhs, rel, rhs = self.split(lhs_emb, rel_emb, rhs_emb)
+ score_1 = (rhs[0] * rel[0] + rhs[1] * rel[1]) @ lhs[0].transpose(-1, -2)
+ score_2 = (rhs[1] * rel[0] - rhs[0] * rel[1]) @ lhs[1].transpose(-1, -2)
+ factors = self.get_factors(lhs, rel, rhs) if return_factors else None
+ return score_1 + score_2, factors
+
+ def get_factors(self,
+ lhs: Tuple[Tensor, Tensor],
+ rel: Tuple[Tensor, Tensor],
+ rhs: Tuple[Tensor, Tensor]) -> List[Tensor]:
+ factors = []
+ for term in (lhs, rel, rhs):
+ factors.append(torch.sqrt(term[0] ** 2 + term[1] ** 2))
+ return factors
+
+ def get_full_embeddings(self, queries: Tensor) \
+ -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
+ lhs = rel = rhs = None
+ if torch.sum(queries[:, 0]).item() > 0:
+ lhs = self.embeddings[0](queries[:, 0])
+ if torch.sum(queries[:, 1]).item() > 0:
+ rel = self.embeddings[1](queries[:, 1])
+ if torch.sum(queries[:, 2]).item() > 0:
+ rhs = self.embeddings[0](queries[:, 2])
+ return lhs, rel, rhs
+
+ def batch_t_norm(self, scores: Tensor) -> Tensor:
+ if self.t_norm_name == CQD.MIN_NORM:
+ scores = torch.min(scores, dim=1)[0]
+ elif self.t_norm_name == CQD.PROD_NORM:
+ scores = torch.prod(scores, dim=1)
+ else:
+ raise ValueError(f't_norm must be one of {CQD.NORMS}, got {self.t_norm_name}')
+
+ return scores
+
+ def batch_t_conorm(self, scores: Tensor) -> Tensor:
+ if self.t_norm_name == CQD.MIN_NORM:
+ scores = torch.max(scores, dim=1, keepdim=True)[0]
+ elif self.t_norm_name == CQD.PROD_NORM:
+ scores = torch.sum(scores, dim=1, keepdim=True) - torch.prod(scores, dim=1, keepdim=True)
+ else:
+ raise ValueError(f't_norm must be one of {CQD.NORMS}, got {self.t_norm_name}')
+
+ return scores
+
+ def reduce_query_score(self, atom_scores, conjunction_mask, negation_mask):
+ batch_size, num_atoms, *extra_dims = atom_scores.shape
+
+ atom_scores = torch.sigmoid(atom_scores)
+ scores = atom_scores.clone()
+ scores[negation_mask] = 1 - atom_scores[negation_mask]
+
+ disjunctions = scores[~conjunction_mask].reshape(batch_size, -1, *extra_dims)
+ conjunctions = scores[conjunction_mask].reshape(batch_size, -1, *extra_dims)
+
+ if disjunctions.shape[1] > 0:
+ disjunctions = self.batch_t_conorm(disjunctions)
+
+ conjunctions = torch.cat([disjunctions, conjunctions], dim=1)
+
+ t_norm = self.batch_t_norm(conjunctions)
+ return t_norm
+
    def forward(self,
                positive_sample,
                negative_sample,
                subsampling_weight,
                batch_queries_dict: Dict[Tuple, Tensor],
                batch_idxs_dict):
        """Answer a batch of complex queries, grouped by query structure.

        Dispatches on self.method: 'co' runs continuous optimization over
        variable embeddings (CQD-CO), 'continuous' delegates to
        top_k_selection, and 'discrete' delegates to the beam-search
        routines in cqd.discrete.

        Returns (None, scores, None, all_idxs) to mimic the interface of the
        surrounding framework; positive_sample/negative_sample/
        subsampling_weight are unused here.

        NOTE(review): in the 'continuous' and 'discrete' branches, `scores`
        is overwritten on every loop iteration, so only the last query
        structure's scores are returned — presumably callers pass a single
        structure per batch; verify.
        """
        all_idxs = []
        all_scores = []

        scores = None

        for query_structure, queries in batch_queries_dict.items():
            batch_size = queries.shape[0]
            # NOTE: atoms is mutated in place below (variable-id rewriting).
            atoms, num_variables, conjunction_mask, negation_mask = query_to_atoms(query_structure, queries)

            all_idxs.extend(batch_idxs_dict[query_structure])

            # [False, True]
            # Atoms whose tail is the final (most negative) variable.
            target_mask = torch.sum(atoms == -num_variables, dim=-1) > 0

            # Offsets identify variables across different batches
            var_id_offsets = torch.arange(batch_size, device=atoms.device) * num_variables
            var_id_offsets = var_id_offsets.reshape(-1, 1, 1)

            # Replace negative variable IDs with valid identifiers
            vars_mask = atoms < 0
            atoms_offset_vars = -atoms + var_id_offsets

            atoms[vars_mask] = atoms_offset_vars[vars_mask]

            head, rel, tail = atoms[..., 0], atoms[..., 1], atoms[..., 2]
            head_vars_mask = vars_mask[..., 0]

            # Constants are frozen; only variable embeddings are optimized.
            with torch.no_grad():
                h_emb_constants = self.embeddings[0](head)
                r_emb = self.embeddings[1](rel)

            if 'co' in self.method:
                h_emb = h_emb_constants
                if num_variables > 1:
                    # var embedding for ID 0 is unused for ease of implementation
                    var_embs = nn.Embedding((num_variables * batch_size) + 1, self.rank * 2)
                    var_embs.weight.data *= self.init_size

                    var_embs.to(atoms.device)
                    optimizer = optim.Adam(var_embs.parameters(), lr=0.1)
                    prev_loss_value = 1000
                    loss_value = 999
                    i = 0

                    # CQD-CO optimization loop: run until the loss plateaus
                    # or an iteration cap is reached.
                    while i < 1000 and math.fabs(prev_loss_value - loss_value) > 1e-9:
                        prev_loss_value = loss_value

                        h_emb = h_emb_constants.clone()
                        # Fill variable positions with optimizable embeddings
                        h_emb[head_vars_mask] = var_embs(head[head_vars_mask])

                        t_emb = var_embs(tail)
                        scores, factors = self.score_o(h_emb.unsqueeze(-2),
                                                       r_emb.unsqueeze(-2),
                                                       t_emb.unsqueeze(-2),
                                                       return_factors=True)

                        query_score = self.reduce_query_score(scores,
                                                              conjunction_mask,
                                                              negation_mask)

                        # Maximize the query score, regularized by N3.
                        loss = - query_score.mean() + self.regularizer.forward(factors)
                        loss_value = loss.item()

                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        i += 1

                with torch.no_grad():
                    # Select predicates involving target variable only
                    conjunction_mask = conjunction_mask[target_mask].reshape(batch_size, -1)
                    negation_mask = negation_mask[target_mask].reshape(batch_size, -1)

                    target_mask = target_mask.unsqueeze(-1).expand_as(h_emb)
                    emb_size = h_emb.shape[-1]
                    h_emb = h_emb[target_mask].reshape(batch_size, -1, emb_size)
                    r_emb = r_emb[target_mask].reshape(batch_size, -1, emb_size)
                    # Rank every entity as a candidate answer.
                    to_score = self.embeddings[0].weight

                    scores, factors = self.score_o(h_emb, r_emb, to_score)
                    query_score = self.reduce_query_score(scores,
                                                          conjunction_mask,
                                                          negation_mask)
                    all_scores.append(query_score)

                scores = torch.cat(all_scores, dim=0)

            elif 'continuous' in self.method:
                graph_type = self.query_name_dict[query_structure]

                # Instructions are derived from the first batch element; the
                # chain layout is shared across the batch.
                chain_instructions = create_instructions(atoms[0])
                chains = []

                for atom in range(len(atoms[0])):
                    part = atoms[:, atom, :]
                    chain = self.get_full_embeddings(part)
                    chains.append(chain)

                scores = top_k_selection(chains,
                                         chain_instructions,
                                         graph_type,
                                         # score_o takes lhs, rel, rhs
                                         scoring_function=lambda rel_, lhs_, rhs_: self.score_o(lhs_, rel_, rhs_)[0],
                                         forward_emb=lambda lhs_, rel_: self.score_o(lhs_, rel_, self.embeddings[0].weight)[0],
                                         entity_embeddings=self.embeddings[0],
                                         candidates=self.k,
                                         t_norm=self.t_norm_name,
                                         batch_size=1,
                                         scores_normalize='default')

            elif 'discrete' in self.method:
                graph_type = self.query_name_dict[query_structure]

                # Gödel t-norm pair by default ...
                def t_norm(a: Tensor, b: Tensor) -> Tensor:
                    return torch.minimum(a, b)

                def t_conorm(a: Tensor, b: Tensor) -> Tensor:
                    return torch.maximum(a, b)

                def negation(a: Tensor) -> Tensor:
                    return 1.0 - a

                # ... replaced by the product t-norm pair when requested.
                if self.t_norm_name == CQD.PROD_NORM:
                    def t_norm(a: Tensor, b: Tensor) -> Tensor:
                        return a * b

                    def t_conorm(a: Tensor, b: Tensor) -> Tensor:
                        return 1 - ((1 - a) * (1 - b))

                # Min-max normalization per row.
                def normalize(scores_: Tensor) -> Tensor:
                    scores_ = scores_ - scores_.min(1, keepdim=True)[0]
                    scores_ = scores_ / scores_.max(1, keepdim=True)[0]
                    return scores_

                # Parameter names look swapped relative to the score_o call,
                # but ComplEx's score_o is symmetric in its first two
                # arguments (the real/imag products commute), so passing
                # (subject, predicate) into (rel_, lhs_) yields the same score.
                def scoring_function(rel_: Tensor, lhs_: Tensor, rhs_: Tensor) -> Tensor:
                    res, _ = self.score_o(lhs_, rel_, rhs_)
                    if self.do_sigmoid is True:
                        res = torch.sigmoid(res)
                    if self.do_normalize is True:
                        res = normalize(res)
                    return res

                # Dispatch on the query shape name ('1p', '2i', ...).
                if graph_type == "1p":
                    scores = d2.query_1p(entity_embeddings=self.embeddings[0],
                                         predicate_embeddings=self.embeddings[1],
                                         queries=queries,
                                         scoring_function=scoring_function)
                elif graph_type == "2p":
                    scores = d2.query_2p(entity_embeddings=self.embeddings[0],
                                         predicate_embeddings=self.embeddings[1],
                                         queries=queries,
                                         scoring_function=scoring_function,
                                         k=self.k, t_norm=t_norm)
                elif graph_type == "3p":
                    scores = d2.query_3p(entity_embeddings=self.embeddings[0],
                                         predicate_embeddings=self.embeddings[1],
                                         queries=queries,
                                         scoring_function=scoring_function,
                                         k=self.k, t_norm=t_norm)
                elif graph_type == "2i":
                    scores = d2.query_2i(entity_embeddings=self.embeddings[0],
                                         predicate_embeddings=self.embeddings[1],
                                         queries=queries,
                                         scoring_function=scoring_function, t_norm=t_norm)
                elif graph_type == "3i":
                    scores = d2.query_3i(entity_embeddings=self.embeddings[0],
                                         predicate_embeddings=self.embeddings[1],
                                         queries=queries,
                                         scoring_function=scoring_function, t_norm=t_norm)
                elif graph_type == "pi":
                    scores = d2.query_pi(entity_embeddings=self.embeddings[0],
                                         predicate_embeddings=self.embeddings[1],
                                         queries=queries,
                                         scoring_function=scoring_function,
                                         k=self.k, t_norm=t_norm)
                elif graph_type == "ip":
                    scores = d2.query_ip(entity_embeddings=self.embeddings[0],
                                         predicate_embeddings=self.embeddings[1],
                                         queries=queries,
                                         scoring_function=scoring_function,
                                         t_norm=t_norm)
                elif graph_type == "2u-DNF":
                    scores = d2.query_2u_dnf(entity_embeddings=self.embeddings[0],
                                             predicate_embeddings=self.embeddings[1],
                                             queries=queries,
                                             scoring_function=scoring_function,
                                             t_conorm=t_conorm)
                elif graph_type == "up-DNF":
                    scores = d2.query_up_dnf(entity_embeddings=self.embeddings[0],
                                             predicate_embeddings=self.embeddings[1],
                                             queries=queries,
                                             scoring_function=scoring_function,
                                             t_norm=t_norm, t_conorm=t_conorm)
                elif graph_type == "2in":
                    scores = d2.query_2in(entity_embeddings=self.embeddings[0],
                                          predicate_embeddings=self.embeddings[1],
                                          queries=queries,
                                          scoring_function=scoring_function,
                                          t_norm=t_norm, negation=negation)
                elif graph_type == "3in":
                    scores = d2.query_3in(entity_embeddings=self.embeddings[0],
                                          predicate_embeddings=self.embeddings[1],
                                          queries=queries,
                                          scoring_function=scoring_function,
                                          t_norm=t_norm, negation=negation)
                elif graph_type == "pin":
                    scores = d2.query_pin(entity_embeddings=self.embeddings[0],
                                          predicate_embeddings=self.embeddings[1],
                                          queries=queries,
                                          scoring_function=scoring_function,
                                          k=self.k, t_norm=t_norm, negation=negation)
                elif graph_type == "pni":
                    scores = d2.query_pni_v2(entity_embeddings=self.embeddings[0],
                                             predicate_embeddings=self.embeddings[1],
                                             queries=queries,
                                             scoring_function=scoring_function,
                                             k=self.k, t_norm=t_norm, negation=negation)
                elif graph_type == "inp":
                    scores = d2.query_inp(entity_embeddings=self.embeddings[0],
                                          predicate_embeddings=self.embeddings[1],
                                          queries=queries,
                                          scoring_function=scoring_function,
                                          t_norm=t_norm, negation=negation)
                else:
                    raise ValueError(f'Unknown query type: {graph_type}')

        return None, scores, None, all_idxs
diff --git a/cqd/dataloader.py b/cqd/dataloader.py
new file mode 100644
index 0000000..b7cd80e
--- /dev/null
+++ b/cqd/dataloader.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import torch
+
+from torch.utils.data import Dataset
+from util import flatten
+
+
class CQDTrainDataset(Dataset):
    """Training set of (query, answer) pairs, one dataset item per answer.

    `queries` is a list of (query, query_structure) pairs and `answer` maps a
    query to its answer entities.  Every (query, answer) combination becomes
    one item, so __len__ is the total number of such pairs rather than the
    number of distinct queries.  All items carry a uniform subsampling weight.
    """

    def __init__(self, queries, nentity, nrelation, negative_sample_size, answer):
        # queries is a list of (query, query_structure) pairs
        self.len = len(queries)
        self.queries = queries
        self.nentity = nentity
        self.nrelation = nrelation
        self.negative_sample_size = negative_sample_size
        self.answer = answer

        # One (structure, query, answer) entry per known answer.
        self.qa_lst = []
        for q, qs in queries:
            for a in self.answer[q]:
                self.qa_lst.append((qs, q, a))

        self.qa_len = len(self.qa_lst)

    def __len__(self):
        return self.qa_len

    def __getitem__(self, idx):
        """Return (positive, negatives, weight, flattened query, structure) for one item."""
        query_structure, query, tail = self.qa_lst[idx]

        # Uniform weighting: no frequency-based subsampling is applied.
        subsampling_weight = torch.tensor([1.0])

        # Rejection-sample negatives until enough non-answer entities are found.
        negative_sample_list = []
        negative_sample_size = 0
        while negative_sample_size < self.negative_sample_size:
            negative_sample = np.random.randint(self.nentity, size=self.negative_sample_size * 2)
            # NOTE(review): np.in1d expects an array-like second argument; if
            # self.answer[query] is a set, membership is not tested
            # element-wise — confirm the container type against the caller.
            mask = np.in1d(
                negative_sample,
                self.answer[query],
                assume_unique=True,
                invert=True
            )
            negative_sample = negative_sample[mask]
            negative_sample_list.append(negative_sample)
            negative_sample_size += negative_sample.size
        negative_sample = np.concatenate(negative_sample_list)[:self.negative_sample_size]
        negative_sample = torch.from_numpy(negative_sample)
        positive_sample = torch.LongTensor([tail])
        return positive_sample, negative_sample, subsampling_weight, flatten(query), query_structure

    @staticmethod
    def collate_fn(data):
        """Collate a list of __getitem__ outputs into batch tensors/lists."""
        positive_sample = torch.cat([_[0] for _ in data], dim=0)
        negative_sample = torch.stack([_[1] for _ in data], dim=0)
        subsample_weight = torch.cat([_[2] for _ in data], dim=0)
        query = [_[3] for _ in data]
        query_structure = [_[4] for _ in data]
        return positive_sample, negative_sample, subsample_weight, query, query_structure
diff --git a/cqd/discrete.py b/cqd/discrete.py
new file mode 100644
index 0000000..f4d02c1
--- /dev/null
+++ b/cqd/discrete.py
@@ -0,0 +1,492 @@
+# -*- coding: utf-8 -*-
+
+import torch
+from torch import nn, Tensor
+
+from typing import Callable, Tuple, Optional
+
+
def score_candidates(s_emb: Tensor,
                     p_emb: Tensor,
                     candidates_emb: Tensor,
                     k: Optional[int],
                     entity_embeddings: nn.Module,
                     scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
    """Score every candidate entity for a batch of (subject, predicate) pairs.

    When k is None, returns ([B, N] scores, None); otherwise returns the
    top-k scores [B, K] together with the corresponding entity embeddings
    [B, K, E].
    """
    nb_rows = max(s_emb.shape[0], p_emb.shape[0])
    emb_dim = s_emb.shape[1]

    def tile_to_batch(emb: Tensor) -> Tensor:
        # Align a smaller batch with the larger one by repeating each row
        # contiguously (row i of the small batch fills rows i*c .. i*c+c-1).
        if emb.shape[0] < nb_rows:
            copies = nb_rows // emb.shape[0]
            emb = emb.reshape(-1, 1, emb_dim).repeat(1, copies, 1).reshape(-1, emb_dim)
        return emb

    s_emb = tile_to_batch(s_emb)
    p_emb = tile_to_batch(p_emb)

    # [B, N] scores over all candidates.
    all_scores = scoring_function(s_emb, p_emb, candidates_emb)

    if k is None:
        return all_scores, None

    top_k = min(k, candidates_emb.shape[0])
    # [B, K], [B, K]
    best_scores, best_indices = torch.topk(all_scores, k=top_k, dim=1)
    # [B, K, E]
    return best_scores, entity_embeddings(best_indices)
+
+
def query_1p(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor]) -> Tensor:
    """1p: score the single atom (s, p, ?) against every entity."""
    assert queries.shape[1] == 2

    scores, _ = score_candidates(s_emb=entity_embeddings(queries[:, 0]),
                                 p_emb=predicate_embeddings(queries[:, 1]),
                                 candidates_emb=entity_embeddings.weight,
                                 k=None,
                                 entity_embeddings=entity_embeddings,
                                 scoring_function=scoring_function)
    return scores
+
+
def query_2p(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             k: int,
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """2p chain s -p1-> ?x -p2-> ?target: beam over the top-k intermediates."""
    anchor_emb = entity_embeddings(queries[:, 0])
    rel1_emb = predicate_embeddings(queries[:, 1])
    rel2_emb = predicate_embeddings(queries[:, 2])

    all_entities = entity_embeddings.weight
    nb_entities = all_entities.shape[0]
    batch_size, emb_size = anchor_emb.shape

    # Hop 1: keep the K best intermediate entities per query. [B, K], [B, K, E]
    hop1_scores, hop1_emb = score_candidates(s_emb=anchor_emb, p_emb=rel1_emb,
                                             candidates_emb=all_entities, k=k,
                                             entity_embeddings=entity_embeddings,
                                             scoring_function=scoring_function)

    # Hop 2: score every entity from each of the K intermediates. [B * K, N]
    hop2_scores, _ = score_candidates(s_emb=hop1_emb.reshape(-1, emb_size), p_emb=rel2_emb,
                                      candidates_emb=all_entities, k=None,
                                      entity_embeddings=entity_embeddings,
                                      scoring_function=scoring_function)

    # Broadcast hop-1 scores over candidates: [B, K] -> [B, K, N]
    hop1_3d = hop1_scores.reshape(batch_size, -1, 1).repeat(1, 1, nb_entities)
    # [B * K, N] -> [B, K, N]
    hop2_3d = hop2_scores.reshape(batch_size, -1, nb_entities)

    # Conjunction along the chain, then best intermediate per candidate.
    combined = t_norm(hop1_3d, hop2_3d)
    return combined.max(dim=1).values
+
+
def query_2pn(entity_embeddings: nn.Module,
              predicate_embeddings: nn.Module,
              queries: Tensor,
              scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
              k: int,
              t_norm: Callable[[Tensor, Tensor], Tensor],
              negation: Callable[[Tensor], Tensor]) -> Tensor:
    """2p chain whose second atom is negated (building block for pni)."""
    anchor_emb = entity_embeddings(queries[:, 0])
    rel1_emb = predicate_embeddings(queries[:, 1])
    rel2_emb = predicate_embeddings(queries[:, 2])

    all_entities = entity_embeddings.weight
    nb_entities = all_entities.shape[0]
    batch_size, emb_size = anchor_emb.shape

    # Hop 1: keep the K best intermediates. [B, K], [B, K, E]
    hop1_scores, hop1_emb = score_candidates(s_emb=anchor_emb, p_emb=rel1_emb,
                                             candidates_emb=all_entities, k=k,
                                             entity_embeddings=entity_embeddings,
                                             scoring_function=scoring_function)

    # Hop 2 scored against every entity. [B * K, N]
    hop2_scores, _ = score_candidates(s_emb=hop1_emb.reshape(-1, emb_size), p_emb=rel2_emb,
                                      candidates_emb=all_entities, k=None,
                                      entity_embeddings=entity_embeddings,
                                      scoring_function=scoring_function)

    # The second atom is negated before the conjunction.
    hop2_scores = negation(hop2_scores)

    hop1_3d = hop1_scores.reshape(batch_size, -1, 1).repeat(1, 1, nb_entities)
    hop2_3d = hop2_scores.reshape(batch_size, -1, nb_entities)

    combined = t_norm(hop1_3d, hop2_3d)
    return combined.max(dim=1).values
+
+
def query_3p(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             k: int,
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """3p chain s -p1-> ?x1 -p2-> ?x2 -p3-> ?target via top-k beam search.

    Keeps the K best candidates for each intermediate variable, scores all
    K*K chains against every entity, combines the three atom scores with the
    t-norm, and max-reduces over chains.

    Bug fix: flattened chains are enumerated in k1-major order — chain j
    corresponds to intermediates (k1, k2) = (j // K, j % K) — so hop-1 scores
    must be expanded with repeat_interleave along dim 1.  The previous
    `.repeat()` tiling paired chain j with hop-1 candidate j % K, i.e. the
    wrong intermediate.
    """
    s_emb = entity_embeddings(queries[:, 0])
    p1_emb = predicate_embeddings(queries[:, 1])
    p2_emb = predicate_embeddings(queries[:, 2])
    p3_emb = predicate_embeddings(queries[:, 3])

    candidates_emb = entity_embeddings.weight
    nb_entities = candidates_emb.shape[0]

    batch_size = s_emb.shape[0]
    emb_size = s_emb.shape[1]

    # [B, K], [B, K, E] — top-k candidates for the first intermediate
    atom1_k_scores_2d, x1_k_emb_3d = score_candidates(s_emb=s_emb, p_emb=p1_emb,
                                                      candidates_emb=candidates_emb, k=k,
                                                      entity_embeddings=entity_embeddings,
                                                      scoring_function=scoring_function)

    # [B * K, E]
    x1_k_emb_2d = x1_k_emb_3d.reshape(-1, emb_size)

    # [B * K, K], [B * K, K, E] — top-k candidates for the second intermediate
    atom2_k_scores_2d, x2_k_emb_3d = score_candidates(s_emb=x1_k_emb_2d, p_emb=p2_emb,
                                                      candidates_emb=candidates_emb, k=k,
                                                      entity_embeddings=entity_embeddings,
                                                      scoring_function=scoring_function)

    # [B * K * K, E]
    x2_k_emb_2d = x2_k_emb_3d.reshape(-1, emb_size)

    # [B * K * K, N] — last hop scored against every entity
    atom3_scores_2d, _ = score_candidates(s_emb=x2_k_emb_2d, p_emb=p3_emb,
                                          candidates_emb=candidates_emb, k=None,
                                          entity_embeddings=entity_embeddings,
                                          scoring_function=scoring_function)

    # [B, K] -> [B, K, N]
    atom1_scores_3d = atom1_k_scores_2d.reshape(batch_size, -1, 1).repeat(1, 1, nb_entities)

    # [B * K, K] -> [B, K * K, N]; flattened index j = k1 * K + k2
    atom2_scores_3d = atom2_k_scores_2d.reshape(batch_size, -1, 1).repeat(1, 1, nb_entities)

    # [B * K * K, N] -> [B, K * K, N]; same k1-major chain order
    atom3_scores_3d = atom3_scores_2d.reshape(batch_size, -1, nb_entities)

    # Align hop-1 scores with the k1-major chain enumeration (k1 = j // K).
    chains_per_hop1 = atom3_scores_3d.shape[1] // atom1_scores_3d.shape[1]
    atom1_scores_3d = atom1_scores_3d.repeat_interleave(chains_per_hop1, dim=1)

    res = t_norm(atom1_scores_3d, atom2_scores_3d)
    res = t_norm(res, atom3_scores_3d)

    # [B, K * K, N] -> [B, N]: best chain per candidate answer.
    res, _ = torch.max(res, dim=1)
    return res
+
+
def query_2i(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """2i: conjunction of two 1p atoms sharing the same target variable."""
    branches = [
        query_1p(entity_embeddings=entity_embeddings,
                 predicate_embeddings=predicate_embeddings,
                 queries=queries[:, lo:lo + 2],
                 scoring_function=scoring_function)
        for lo in (0, 2)
    ]
    return t_norm(branches[0], branches[1])
+
+
def query_3i(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """3i: conjunction of three 1p atoms sharing the same target variable."""
    res = None
    for lo in (0, 2, 4):
        branch = query_1p(entity_embeddings=entity_embeddings,
                          predicate_embeddings=predicate_embeddings,
                          queries=queries[:, lo:lo + 2],
                          scoring_function=scoring_function)
        # Left-associative fold matches t_norm(t_norm(s1, s2), s3).
        res = branch if res is None else t_norm(res, branch)
    return res
+
+
def query_ip(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """ip: intersection of two 1p atoms, followed by one more hop."""
    # [B, N] scores for the intermediate variable.
    inter_scores = query_2i(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                            queries=queries[:, 0:4], scoring_function=scoring_function, t_norm=t_norm)

    # [B, E]
    rel_emb = predicate_embeddings(queries[:, 4])
    batch_size, emb_size = rel_emb.shape

    # [N, E]
    all_entities = entity_embeddings.weight
    nb_entities = all_entities.shape[0]

    # Treat every entity as a possible intermediate: [B * N, E]
    subj_emb = all_entities.reshape(1, nb_entities, emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size)

    # [B * N, N]
    hop_scores, _ = score_candidates(s_emb=subj_emb, p_emb=rel_emb, candidates_emb=all_entities, k=None,
                                     entity_embeddings=entity_embeddings, scoring_function=scoring_function)

    # [B, N, N]: intermediate score broadcast over final candidates.
    inter_3d = inter_scores.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities)
    hop_3d = hop_scores.reshape(batch_size, nb_entities, nb_entities)

    # Conjunction, then best intermediate per final candidate.
    return t_norm(inter_3d, hop_3d).max(dim=1).values
+
+
def query_pi(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             k: int,
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """pi: a 2p chain intersected with a 1p atom on the same target."""
    chain_scores = query_2p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                            queries=queries[:, 0:3], scoring_function=scoring_function, k=k, t_norm=t_norm)
    atom_scores = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                           queries=queries[:, 3:5], scoring_function=scoring_function)
    return t_norm(chain_scores, atom_scores)
+
+
def query_2in(entity_embeddings: nn.Module,
              predicate_embeddings: nn.Module,
              queries: Tensor,
              scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
              t_norm: Callable[[Tensor, Tensor], Tensor],
              negation: Callable[[Tensor], Tensor]) -> Tensor:
    """2in: conjunction of a 1p atom with a negated 1p atom."""
    positive = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 0:2], scoring_function=scoring_function)
    negated = negation(query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                                queries=queries[:, 2:4], scoring_function=scoring_function))
    return t_norm(positive, negated)
+
+
def query_3in(entity_embeddings: nn.Module,
              predicate_embeddings: nn.Module,
              queries: Tensor,
              scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
              t_norm: Callable[[Tensor, Tensor], Tensor],
              negation: Callable[[Tensor], Tensor]) -> Tensor:
    """3in: conjunction of two 1p atoms with a third, negated 1p atom."""
    first = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                     queries=queries[:, 0:2], scoring_function=scoring_function)
    second = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                      queries=queries[:, 2:4], scoring_function=scoring_function)
    third = negation(query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                              queries=queries[:, 4:6], scoring_function=scoring_function))
    return t_norm(t_norm(first, second), third)
+
+
def query_inp(entity_embeddings: nn.Module,
              predicate_embeddings: nn.Module,
              queries: Tensor,
              scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
              t_norm: Callable[[Tensor, Tensor], Tensor],
              negation: Callable[[Tensor], Tensor]) -> Tensor:
    """inp: a 2in intersection followed by one more hop to the target."""
    # [B, N] scores for the intermediate variable.
    inter_scores = query_2in(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                             queries=queries[:, 0:4], scoring_function=scoring_function,
                             t_norm=t_norm, negation=negation)

    # [B, E]; column 4 holds the negation marker and is skipped.
    rel_emb = predicate_embeddings(queries[:, 5])
    batch_size, emb_size = rel_emb.shape

    # [N, E]
    all_entities = entity_embeddings.weight
    nb_entities = all_entities.shape[0]

    # Every entity as a possible intermediate: [B * N, E]
    subj_emb = all_entities.reshape(1, nb_entities, emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size)

    # [B * N, N]
    hop_scores, _ = score_candidates(s_emb=subj_emb, p_emb=rel_emb, candidates_emb=all_entities, k=None,
                                     entity_embeddings=entity_embeddings, scoring_function=scoring_function)

    # [B, N, N]
    inter_3d = inter_scores.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities)
    hop_3d = hop_scores.reshape(batch_size, nb_entities, nb_entities)

    return t_norm(inter_3d, hop_3d).max(dim=1).values
+
+
def query_pin(entity_embeddings: nn.Module,
              predicate_embeddings: nn.Module,
              queries: Tensor,
              scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
              k: int,
              t_norm: Callable[[Tensor, Tensor], Tensor],
              negation: Callable[[Tensor], Tensor]) -> Tensor:
    """pin: a 2p chain intersected with a negated 1p atom."""
    chain_scores = query_2p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                            queries=queries[:, 0:3], scoring_function=scoring_function, k=k, t_norm=t_norm)
    negated = negation(query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                                queries=queries[:, 3:5], scoring_function=scoring_function))
    return t_norm(chain_scores, negated)
+
+
def query_pni_v1(entity_embeddings: nn.Module,
                 predicate_embeddings: nn.Module,
                 queries: Tensor,
                 scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
                 k: int,
                 t_norm: Callable[[Tensor, Tensor], Tensor],
                 negation: Callable[[Tensor], Tensor]) -> Tensor:
    """pni (variant 1): negate the whole 2p chain, then intersect with a 1p atom."""
    chain_scores = negation(
        query_2p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                 queries=queries[:, 0:3], scoring_function=scoring_function, k=k, t_norm=t_norm))
    # Column 3 holds the negation marker; the 1p atom sits in columns 4:6.
    atom_scores = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                           queries=queries[:, 4:6], scoring_function=scoring_function)
    return t_norm(chain_scores, atom_scores)
+
+
def query_pni_v2(entity_embeddings: nn.Module,
                 predicate_embeddings: nn.Module,
                 queries: Tensor,
                 scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
                 k: int,
                 t_norm: Callable[[Tensor, Tensor], Tensor],
                 negation: Callable[[Tensor], Tensor]) -> Tensor:
    """pni (variant 2): 2p chain with its second atom negated, intersected with a 1p atom."""
    chain_scores = query_2pn(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                             queries=queries[:, 0:3], scoring_function=scoring_function, k=k,
                             t_norm=t_norm, negation=negation)
    atom_scores = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                           queries=queries[:, 4:6], scoring_function=scoring_function)
    return t_norm(chain_scores, atom_scores)
+
+
def query_2u_dnf(entity_embeddings: nn.Module,
                 predicate_embeddings: nn.Module,
                 queries: Tensor,
                 scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
                 t_conorm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """2u (DNF form): disjunction of two 1p atoms on the same target."""
    left = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                    queries=queries[:, 0:2], scoring_function=scoring_function)
    right = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                     queries=queries[:, 2:4], scoring_function=scoring_function)
    return t_conorm(left, right)
+
+
def query_up_dnf(entity_embeddings: nn.Module,
                 predicate_embeddings: nn.Module,
                 queries: Tensor,
                 scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
                 t_norm: Callable[[Tensor, Tensor], Tensor],
                 t_conorm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """up (DNF form): a 2u union followed by one more hop to the target."""
    # [B, N] scores for the intermediate variable.
    union_scores = query_2u_dnf(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                                queries=queries[:, 0:4], scoring_function=scoring_function, t_conorm=t_conorm)

    # [B, E]; column 4 holds the union marker and is skipped.
    rel_emb = predicate_embeddings(queries[:, 5])
    batch_size, emb_size = rel_emb.shape

    # [N, E]
    all_entities = entity_embeddings.weight
    nb_entities = all_entities.shape[0]

    # Every entity as a possible intermediate: [B * N, E]
    subj_emb = all_entities.reshape(1, nb_entities, emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size)

    # [B * N, N]
    hop_scores, _ = score_candidates(s_emb=subj_emb, p_emb=rel_emb, candidates_emb=all_entities, k=None,
                                     entity_embeddings=entity_embeddings, scoring_function=scoring_function)

    # [B, N, N]
    union_3d = union_scores.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities)
    hop_3d = hop_scores.reshape(batch_size, nb_entities, nb_entities)

    return t_norm(union_3d, hop_3d).max(dim=1).values
diff --git a/cqd/util.py b/cqd/util.py
new file mode 100644
index 0000000..dee0d78
--- /dev/null
+++ b/cqd/util.py
@@ -0,0 +1,383 @@
+# -*- coding: utf-8 -*-
+
+import torch
+from torch import Tensor
+
+import numpy as np
+
+from tqdm import tqdm
+
+from typing import List, Tuple, Callable
+
+
def flatten_structure(query_structure):
    """Flatten a nested query structure into a flat list of its tokens.

    A query structure is either a token string (e.g. 'e', 'r', 'n', 'u')
    or an arbitrarily nested tuple/list of structures. Tokens are returned
    in left-to-right, depth-first order.

    Args:
        query_structure: a token string or a nested sequence of structures.

    Returns:
        A flat list of token strings.
    """
    # isinstance is the idiomatic (and subclass-safe) string check,
    # replacing the original `type(query_structure) == str`.
    if isinstance(query_structure, str):
        return [query_structure]
    # Depth-first traversal, concatenating the flattened sub-structures.
    return [token
            for element in query_structure
            for token in flatten_structure(element)]
+
+
def query_to_atoms(query_structure, flat_ids):
    """Translate a flat query encoding into a batch of (s, p, o) atoms.

    Existential variables are encoded as negative integers (-1, -2, ...)
    in the subject/object slots; anchor entities and relations come from
    `flat_ids`. Returns the stacked atoms [B, A, 3], the number of
    variables, and per-atom conjunction/negation boolean masks [B, A].
    """
    tokens = flatten_structure(query_structure)
    batch_size, query_length = flat_ids.shape
    assert len(tokens) == query_length

    triples_per_atom = []
    var_counter = 0
    current_subject = flat_ids[:, 0]
    is_conjunction = []
    is_negated = []

    for pos in range(1, query_length):
        token = tokens[pos]
        if token == 'r':
            # Each relation hop introduces a fresh (negative) variable
            # which becomes the subject of the next hop.
            var_counter -= 1
            atom = torch.empty(batch_size, 3,
                               device=flat_ids.device,
                               dtype=torch.long)
            atom[:, 0] = current_subject
            atom[:, 1] = flat_ids[:, pos]
            atom[:, 2] = var_counter
            triples_per_atom.append(atom)
            current_subject = var_counter
            is_conjunction.append(True)
            is_negated.append(False)
        elif token == 'e':
            # A new anchor entity starts a new branch of the query.
            current_subject = flat_ids[:, pos]
            var_counter += 1
        elif token == 'u':
            # A union token marks every atom seen so far as non-conjunctive.
            is_conjunction = [False] * len(is_conjunction)
        elif token == 'n':
            # Negation applies to the most recent atom.
            is_negated[-1] = True

    atoms = torch.stack(triples_per_atom, dim=1)
    num_variables = -var_counter
    conjunction_mask = torch.tensor(is_conjunction).unsqueeze(0).expand(batch_size, -1)
    negation_mask = torch.tensor(is_negated).unsqueeze(0).expand(batch_size, -1)

    return atoms, num_variables, conjunction_mask, negation_mask
+
+
def create_instructions(chains):
    """Compile a list of chains into 'hop'/'intersect' instruction strings.

    Each chain is a sequence whose first and last elements are variable or
    anchor identifiers. Consecutive chains are compared by endpoints:
      * prev end == cur start  ->  "hop_{i-1}_{i}"        (path composition)
      * prev end == cur end    ->  "intersect_{i-1}_{i}"  (shared target)
    A chain matching neither case is pushed onto a stack and may later be
    intersected with a chain sharing its end point. Finally, consecutive
    intersect instructions are merged into one multi-way intersect
    (e.g. "intersect_0_1_2").
    """
    instructions = []

    # Endpoints of the previously processed chain.
    prev_start = None
    prev_end = None

    # Chains deferred because they did not connect to their predecessor,
    # stored as ([start, end], chain_index).
    path_stack = []
    start_flag = True
    for chain_ind, chain in enumerate(chains):
        if start_flag:
            # The first chain only seeds prev_end; it emits no instruction.
            prev_end = chain[-1]
            start_flag = False
            continue

        if prev_end == chain[0]:
            # This chain continues the previous one: a hop (projection).
            instructions.append(f"hop_{chain_ind-1}_{chain_ind}")
            prev_end = chain[-1]
            prev_start = chain[0]

        elif prev_end == chain[-1]:

            prev_start = chain[0]
            prev_end = chain[-1]

            # Both chains end at the same node: an intersection.
            instructions.append(f"intersect_{chain_ind-1}_{chain_ind}")
        else:
            # Disconnected chain: remember the previous endpoints and skip.
            # NOTE(review): start_flag is already False here; this
            # reassignment looks redundant -- confirm intent.
            path_stack.append(([prev_start, prev_end],chain_ind-1))
            prev_start = chain[0]
            prev_end = chain[-1]
            start_flag = False
            continue

        if len(path_stack) > 0:

            # Try to resolve the most recently deferred chain against the
            # current one (only their shared end point is checked).
            path_prev_start = path_stack[-1][0][0]
            path_prev_end = path_stack[-1][0][-1]

            if path_prev_end == chain[-1]:

                prev_start = chain[0]
                prev_end = chain[-1]

                instructions.append(f"intersect_{path_stack[-1][1]}_{chain_ind}")
                path_stack.pop()
                continue

    # Merge runs of consecutive intersect instructions into a single
    # multi-way intersect by appending the trailing index.
    ans = []
    for inst in instructions:
        if ans:

            if 'inter' in inst and ('inter' in ans[-1]):
                last_ind = inst.split("_")[-1]
                ans[-1] = ans[-1]+f"_{last_ind}"
            else:
                ans.append(inst)

        else:
            ans.append(inst)

    instructions = ans
    return instructions
+
+
def t_norm_fn(tens_1: Tensor, tens_2: Tensor, t_norm: str = 'min') -> Tensor:
    """Fuzzy conjunction (t-norm) of two score tensors.

    Args:
        tens_1: first score tensor.
        tens_2: second score tensor, broadcastable against the first.
        t_norm: 'min' for the Goedel t-norm, 'prod' for the product t-norm
            (matched by substring, mirroring the callers' norm names).

    Returns:
        Element-wise t-norm of the two tensors.

    Raises:
        ValueError: if `t_norm` names no known t-norm. The original code
            silently returned None here, deferring the failure to the caller.
    """
    if 'min' in t_norm:
        return torch.min(tens_1, tens_2)
    if 'prod' in t_norm:
        return tens_1 * tens_2
    raise ValueError(f"Unknown t-norm: {t_norm!r}")
+
+
def t_conorm_fn(tens_1: Tensor, tens_2: Tensor, t_norm: str = 'min') -> Tensor:
    """Fuzzy disjunction (t-conorm) of two score tensors.

    Args:
        tens_1: first score tensor.
        tens_2: second score tensor, broadcastable against the first.
        t_norm: name of the *dual* t-norm -- 'min' yields the max t-conorm,
            'prod' yields the probabilistic sum a + b - a*b (substring match,
            mirroring `t_norm_fn`).

    Returns:
        Element-wise t-conorm of the two tensors.

    Raises:
        ValueError: if `t_norm` names no known t-norm. The original code
            silently returned None here, deferring the failure to the caller.
    """
    if 'min' in t_norm:
        return torch.max(tens_1, tens_2)
    if 'prod' in t_norm:
        return (tens_1 + tens_2) - (tens_1 * tens_2)
    raise ValueError(f"Unknown t-norm: {t_norm!r}")
+
+
def make_batches(size: int, batch_size: int) -> List[Tuple[int, int]]:
    """Split `size` items into consecutive half-open (start, end) index
    ranges, each spanning at most `batch_size` items; the last range may
    be shorter."""
    return [(start, min(size, start + batch_size))
            for start in range(0, size, batch_size)]
+
+
def get_best_candidates(rel: Tensor,
                        arg1: Tensor,
                        forward_emb: Callable[[Tensor, Tensor], Tensor],
                        entity_embeddings: Callable[[Tensor], Tensor],
                        candidates: int = 5,
                        last_step: bool = False) -> Tuple[Tensor, Tensor]:
    """Score all entities for a batch of (subject, relation) pairs.

    On intermediate steps only the top-`candidates` entities are kept
    (scores [B, K] and embeddings [B, K, E]); on the last step the scores
    and embeddings of every entity are returned instead.
    """
    batch_size, embedding_size = rel.shape[0], rel.shape[1]

    # [B, N] -- score of every entity as the object of (arg1, rel).
    all_scores = forward_emb(arg1, rel)

    if last_step:
        # Keep everything: all N scores plus the full entity embedding table.
        kept_scores = all_scores
        kept_indices = torch.arange(kept_scores.shape[1]).view(1, -1).repeat(kept_scores.shape[0], 1).to(rel.device)
        kept_emb = entity_embeddings(kept_indices)
    else:
        # [B, K] -- retain only the K best-scoring entities per query.
        k = min(candidates, all_scores.shape[1])
        kept_scores, kept_indices = torch.topk(all_scores, k=k, dim=1)
        # [B, K, E]
        kept_emb = entity_embeddings(kept_indices)

        assert kept_emb.shape[0] == batch_size
        assert kept_emb.shape[2] == embedding_size

    return kept_scores, kept_emb
+
+
def top_k_selection(chains,
                    chain_instructions,
                    graph_type,
                    scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
                    forward_emb: Callable[[Tensor, Tensor], Tensor],
                    entity_embeddings: Callable[[Tensor], Tensor],
                    candidates: int = 5,
                    t_norm: str = 'min',
                    batch_size: int = 1,
                    scores_normalize: str = 'default'):
    """Beam-search style inference over a batch of query chains.

    Executes `chain_instructions` (strings like "hop_0_1" or
    "intersect_0_1_2", as produced by `create_instructions`) over the atom
    `chains`, keeping only the top `candidates` intermediate entities per
    step, and returns a [nb_queries, nb_entities] tensor of answer scores.

    NOTE(review): `scores_normalize` is only used for truthiness below, and
    its default is the non-empty string 'default', so sigmoid normalization
    is effectively always on -- confirm this is intended.
    """
    res = None

    # Disjunctive (DNF) query types combine branch scores with a t-conorm;
    # conjunctive ones with a t-norm. The choice may flip per instruction.
    if 'disj' in graph_type:
        objective = t_conorm_fn
    else:
        objective = t_norm_fn

    nb_queries, embedding_size = chains[0][0].shape[0], chains[0][0].shape[1]

    # Accumulates per-batch answer scores across the whole query set.
    scores = None

    batches = make_batches(nb_queries, batch_size)

    for batch in tqdm(batches):

        nb_branches = 1
        nb_ent = 0
        batch_scores = None
        # Maps "lhs_{i}" / "rhs_{i}" to (scores, candidate embeddings) so
        # later instructions can reuse earlier results.
        candidate_cache = {}

        # NOTE(review): this shadows the `batch_size` parameter with the
        # actual size of the current (possibly final, smaller) batch.
        batch_size = batch[1] - batch[0]
        dnf_flag = False
        if 'disj' in graph_type:
            dnf_flag = True

        for inst_ind, inst in enumerate(chain_instructions):
            with torch.no_grad():
                if 'hop' in inst:

                    # "hop_{i}_{j}": compose chain i into chain j.
                    ind_1 = int(inst.split("_")[-2])
                    ind_2 = int(inst.split("_")[-1])

                    indices = [ind_1, ind_2]

                    # Within a hop of a DNF query, atoms are conjunctive.
                    if objective == t_conorm_fn and dnf_flag:
                        objective = t_norm_fn

                    last_hop = False
                    for hop_num, ind in enumerate(indices):
                        # Only the final hop of the final instruction scores
                        # every entity instead of the top-k beam.
                        last_step = (inst_ind == len(chain_instructions) - 1) and last_hop

                        lhs, rel, rhs = chains[ind]

                        if lhs is not None:
                            lhs = lhs[batch[0]:batch[1]]
                        else:
                            # Anchor not given: reuse the candidates produced
                            # by the previous hop as subjects.
                            # print("MTA BRAT")
                            batch_scores, lhs_3d = candidate_cache[f"lhs_{ind}"]
                            lhs = lhs_3d.view(-1, embedding_size)

                        # Repeat the relation once per surviving beam branch.
                        rel = rel[batch[0]:batch[1]]
                        rel = rel.view(-1, 1, embedding_size).repeat(1, nb_branches, 1)
                        rel = rel.view(-1, embedding_size)

                        if f"rhs_{ind}" not in candidate_cache:

                            # print("STTEEE MTA")
                            z_scores, rhs_3d = get_best_candidates(rel, lhs, forward_emb, entity_embeddings, candidates, last_step)

                            # [Num_queries * Candidates^K]
                            z_scores_1d = z_scores.view(-1)
                            if 'disj' in graph_type or scores_normalize:
                                z_scores_1d = torch.sigmoid(z_scores_1d)

                            # B * S
                            nb_sources = rhs_3d.shape[0] * rhs_3d.shape[1]
                            nb_branches = nb_sources // batch_size
                            # Fold the new atom's scores into the running
                            # beam scores (repeated to align branch counts).
                            if not last_step:
                                batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, candidates).view(-1), t_norm)
                            else:
                                nb_ent = rhs_3d.shape[1]
                                batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, nb_ent).view(-1), t_norm)

                            candidate_cache[f"rhs_{ind}"] = (batch_scores, rhs_3d)

                            if not last_hop:
                                # The objects of this hop become the subjects
                                # of the next one.
                                candidate_cache[f"lhs_{indices[hop_num + 1]}"] = (batch_scores, rhs_3d)

                        else:
                            # Already computed by an earlier instruction.
                            batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"]
                            candidate_cache[f"lhs_{ind + 1}"] = (batch_scores, rhs_3d)
                            last_hop = True
                            continue

                        last_hop = True

                elif 'inter' in inst:
                    # "intersect_{i}_{j}" or "intersect_{i}_{j}_{k}":
                    # all listed chains share the same target variable.
                    ind_1 = int(inst.split("_")[-2])
                    ind_2 = int(inst.split("_")[-1])

                    indices = [ind_1, ind_2]

                    # In a DNF query the intersection marker encodes the
                    # disjunction, so flip back to the t-conorm.
                    if objective == t_norm_fn and dnf_flag:
                        objective = t_conorm_fn

                    if len(inst.split("_")) > 3:
                        # Three-way intersection.
                        ind_1 = int(inst.split("_")[-3])
                        ind_2 = int(inst.split("_")[-2])
                        ind_3 = int(inst.split("_")[-1])

                        indices = [ind_1, ind_2, ind_3]

                    for intersection_num, ind in enumerate(indices):
                        last_step = (inst_ind == len(chain_instructions) - 1)  # and ind == indices[0]

                        lhs, rel, rhs = chains[ind]

                        if lhs is not None:
                            lhs = lhs[batch[0]:batch[1]]
                            lhs = lhs.view(-1, 1, embedding_size).repeat(1, nb_branches, 1)
                            lhs = lhs.view(-1, embedding_size)

                        else:
                            # Subject comes from a previous hop's candidates.
                            batch_scores, lhs_3d = candidate_cache[f"lhs_{ind}"]
                            lhs = lhs_3d.view(-1, embedding_size)
                            nb_sources = lhs_3d.shape[0] * lhs_3d.shape[1]
                            nb_branches = nb_sources // batch_size

                        rel = rel[batch[0]:batch[1]]
                        rel = rel.view(-1, 1, embedding_size).repeat(1, nb_branches, 1)
                        rel = rel.view(-1, embedding_size)

                        if intersection_num > 0 and 'disj' in graph_type:
                            # Disjunctive branches after the first are scored
                            # against the candidates fixed by the first branch.
                            batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"]
                            rhs = rhs_3d.view(-1, embedding_size)
                            z_scores = scoring_function(rel, lhs, rhs)

                            z_scores_1d = z_scores.view(-1)
                            if 'disj' in graph_type or scores_normalize:
                                z_scores_1d = torch.sigmoid(z_scores_1d)

                            batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores, t_norm)

                            continue

                        if f"rhs_{ind}" not in candidate_cache or last_step:
                            z_scores, rhs_3d = get_best_candidates(rel, lhs, forward_emb, entity_embeddings, candidates, last_step)

                            # [B * Candidates^K] or [B, S-1, N]
                            z_scores_1d = z_scores.view(-1)
                            if 'disj' in graph_type or scores_normalize:
                                z_scores_1d = torch.sigmoid(z_scores_1d)

                            if not last_step:
                                nb_sources = rhs_3d.shape[0] * rhs_3d.shape[1]
                                nb_branches = nb_sources // batch_size

                            if not last_step:
                                batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, candidates).view(-1), t_norm)
                            else:
                                # On the last step the first branch scores all
                                # entities; later branches align one-to-one.
                                if ind == indices[0]:
                                    nb_ent = rhs_3d.shape[1]
                                else:
                                    nb_ent = 1

                                batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, nb_ent).view(-1), t_norm)
                                nb_ent = rhs_3d.shape[1]

                            candidate_cache[f"rhs_{ind}"] = (batch_scores, rhs_3d)

                            if ind == indices[0] and 'disj' in graph_type:
                                # Share the first branch's candidates with the
                                # remaining disjunctive branches.
                                count = len(indices) - 1
                                iterator = 1
                                while count > 0:
                                    candidate_cache[f"rhs_{indices[intersection_num + iterator]}"] = (
                                        batch_scores, rhs_3d)
                                    iterator += 1
                                    count -= 1

                            if ind == indices[-1]:
                                candidate_cache[f"lhs_{ind + 1}"] = (batch_scores, rhs_3d)
                        else:
                            batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"]
                            candidate_cache[f"rhs_{ind + 1}"] = (batch_scores, rhs_3d)

                            last_hop = True
                            del lhs, rel
                            continue

                        # Free per-branch tensors early to limit peak memory.
                        del lhs, rel, rhs, rhs_3d, z_scores_1d, z_scores

        if batch_scores is not None:
            # [B * entities * S]
            # S == K**(V-1)
            # Maximize over the beam branches to get per-entity scores.
            scores_2d = batch_scores.view(batch_size, -1, nb_ent)
            res, _ = torch.max(scores_2d, dim=1)
            scores = res if scores is None else torch.cat([scores, res])

            del batch_scores, scores_2d, res, candidate_cache

        else:
            assert False, "Batch Scores are empty: an error went uncaught."

    res = scores

    return res
diff --git a/main.py b/main.py
index 6b76007..8067c70 100755
--- a/main.py
+++ b/main.py
@@ -8,19 +8,21 @@
import json
import logging
import os
-import random
import numpy as np
import torch
from torch.utils.data import DataLoader
from models import KGReasoning
+from cqd import CQD
+
from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator
+from cqd.dataloader import CQDTrainDataset
+
from tensorboardX import SummaryWriter
-import time
+
import pickle
from collections import defaultdict
-from tqdm import tqdm
-from util import flatten_query, list2tuple, parse_time, set_global_seed, eval_tuple
+from util import flatten_query, parse_time, set_global_seed, eval_tuple
query_name_dict = {('e',('r',)): '1p',
('e', ('r', 'r')): '2p',
@@ -42,6 +44,7 @@
name_query_dict = {value: key for key, value in query_name_dict.items()}
all_tasks = list(name_query_dict.keys()) # ['1p', '2p', '3p', '2i', '3i', 'ip', 'pi', '2in', '3in', 'inp', 'pin', 'pni', '2u-DNF', '2u-DM', 'up-DNF', 'up-DM']
+
def parse_args(args=None):
parser = argparse.ArgumentParser(
description='Training and Testing Knowledge Graph Embedding Models',
@@ -74,8 +77,17 @@ def parse_args(args=None):
parser.add_argument('--nentity', type=int, default=0, help='DO NOT MANUALLY SET')
parser.add_argument('--nrelation', type=int, default=0, help='DO NOT MANUALLY SET')
- parser.add_argument('--geo', default='vec', type=str, choices=['vec', 'box', 'beta'], help='the reasoning model, vec for GQE, box for Query2box, beta for BetaE')
+ parser.add_argument('--geo', default='vec', type=str, choices=['vec', 'box', 'beta', 'cqd'], help='the reasoning model, vec for GQE, box for Query2box, beta for BetaE, cqd for CQD')
parser.add_argument('--print_on_screen', action='store_true')
+
+ parser.add_argument('--reg_weight', default=1e-3, type=float)
+ parser.add_argument('--optimizer', choices=['adam', 'adagrad'], default='adam')
+ parser.add_argument('--cqd-type', '--cqd', default='co', type=str, choices=['continuous', 'discrete'])
+ parser.add_argument('--cqd-t-norm', default=CQD.PROD_NORM, type=str, choices=CQD.NORMS)
+ parser.add_argument('--cqd-k', default=5, type=int)
+ parser.add_argument('--cqd-sigmoid-scores', '--cqd-sigmoid', action='store_true', default=False)
+ parser.add_argument('--cqd-normalize-scores', '--cqd-normalize', action='store_true', default=False)
+ parser.add_argument('--use-qa-iterator', action='store_true', default=False)
parser.add_argument('--tasks', default='1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up', type=str, help="tasks connected by dot, refer to the BetaE paper for detailed meaning and structure of each task")
parser.add_argument('--seed', default=0, type=int, help="random seed")
@@ -87,6 +99,7 @@ def parse_args(args=None):
return parser.parse_args(args)
+
def save_model(model, optimizer, save_variable_list, args):
'''
Save the parameters of the model and the optimizer,
@@ -104,6 +117,7 @@ def save_model(model, optimizer, save_variable_list, args):
os.path.join(args.save_path, 'checkpoint')
)
+
def set_logger(args):
'''
Write logs to console and log file
@@ -127,6 +141,7 @@ def set_logger(args):
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)
+
def log_metrics(mode, step, metrics):
'''
Print the evaluation logs
@@ -134,6 +149,7 @@ def log_metrics(mode, step, metrics):
for metric in metrics:
logging.info('%s %s at step %d: %f' % (mode, metric, step, metrics[metric]))
+
def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer):
'''
Evaluate queries in dataloader
@@ -141,7 +157,7 @@ def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, m
average_metrics = defaultdict(float)
all_metrics = defaultdict(float)
- metrics = model.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict)
+ metrics = KGReasoning.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict)
num_query_structures = 0
num_queries = 0
for query_structure in metrics:
@@ -161,7 +177,8 @@ def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, m
log_metrics('%s average'%mode, step, average_metrics)
return all_metrics
-
+
+
def load_data(args, tasks):
'''
Load queries and remove queries not in tasks
@@ -193,6 +210,7 @@ def load_data(args, tasks):
return train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers
+
def main(args):
set_global_seed(args.seed)
tasks = args.tasks.split('.')
@@ -263,21 +281,26 @@ def main(args):
else:
train_other_queries[query_structure] = train_queries[query_structure]
train_path_queries = flatten_query(train_path_queries)
+
+ TrainDatasetClass = TrainDataset
+ if args.use_qa_iterator is True:
+ TrainDatasetClass = CQDTrainDataset
+
train_path_iterator = SingledirectionalOneShotIterator(DataLoader(
- TrainDataset(train_path_queries, nentity, nrelation, args.negative_sample_size, train_answers),
+ TrainDatasetClass(train_path_queries, nentity, nrelation, args.negative_sample_size, train_answers),
batch_size=args.batch_size,
shuffle=True,
num_workers=args.cpu_num,
- collate_fn=TrainDataset.collate_fn
+ collate_fn=TrainDatasetClass.collate_fn
))
if len(train_other_queries) > 0:
train_other_queries = flatten_query(train_other_queries)
train_other_iterator = SingledirectionalOneShotIterator(DataLoader(
- TrainDataset(train_other_queries, nentity, nrelation, args.negative_sample_size, train_answers),
+ TrainDatasetClass(train_other_queries, nentity, nrelation, args.negative_sample_size, train_answers),
batch_size=args.batch_size,
shuffle=True,
num_workers=args.cpu_num,
- collate_fn=TrainDataset.collate_fn
+ collate_fn=TrainDatasetClass.collate_fn
))
else:
train_other_iterator = None
@@ -315,18 +338,39 @@ def main(args):
collate_fn=TestDataset.collate_fn
)
- model = KGReasoning(
- nentity=nentity,
- nrelation=nrelation,
- hidden_dim=args.hidden_dim,
- gamma=args.gamma,
- geo=args.geo,
- use_cuda = args.cuda,
- box_mode=eval_tuple(args.box_mode),
- beta_mode = eval_tuple(args.beta_mode),
- test_batch_size=args.test_batch_size,
- query_name_dict = query_name_dict
- )
+ if args.geo == 'cqd':
+ model = CQD(nentity,
+ nrelation,
+ rank=args.hidden_dim,
+ test_batch_size=args.test_batch_size,
+ reg_weight=args.reg_weight,
+ query_name_dict=query_name_dict,
+ method=args.cqd_type,
+ t_norm_name=args.cqd_t_norm,
+ k=args.cqd_k,
+ do_sigmoid=args.cqd_sigmoid_scores,
+ do_normalize=args.cqd_normalize_scores)
+ else:
+ model = KGReasoning(
+ nentity=nentity,
+ nrelation=nrelation,
+ hidden_dim=args.hidden_dim,
+ gamma=args.gamma,
+ geo=args.geo,
+ use_cuda = args.cuda,
+ box_mode=eval_tuple(args.box_mode),
+ beta_mode = eval_tuple(args.beta_mode),
+ test_batch_size=args.test_batch_size,
+ query_name_dict = query_name_dict
+ )
+
+ name_to_optimizer = {
+ 'adam': torch.optim.Adam,
+ 'adagrad': torch.optim.Adagrad
+ }
+
+ assert args.optimizer in name_to_optimizer
+ OptimizerClass = name_to_optimizer[args.optimizer]
logging.info('Model Parameter Configuration:')
num_params = 0
@@ -341,7 +385,7 @@ def main(args):
if args.do_train:
current_learning_rate = args.learning_rate
- optimizer = torch.optim.Adam(
+ optimizer = OptimizerClass(
filter(lambda p: p.requires_grad, model.parameters()),
lr=current_learning_rate
)
@@ -349,7 +393,8 @@ def main(args):
if args.checkpoint_path is not None:
logging.info('Loading checkpoint %s...' % args.checkpoint_path)
- checkpoint = torch.load(os.path.join(args.checkpoint_path, 'checkpoint'))
+ checkpoint = torch.load(os.path.join(args.checkpoint_path, 'checkpoint'),
+ map_location=torch.device('cpu') if not args.cuda else None)
init_step = checkpoint['step']
model.load_state_dict(checkpoint['model_state_dict'])
@@ -396,7 +441,7 @@ def main(args):
if step >= warm_up_steps:
current_learning_rate = current_learning_rate / 5
logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
- optimizer = torch.optim.Adam(
+ optimizer = OptimizerClass(
filter(lambda p: p.requires_grad, model.parameters()),
lr=current_learning_rate
)
@@ -445,5 +490,6 @@ def main(args):
logging.info("Training finished!!")
+
if __name__ == '__main__':
- main(parse_args())
\ No newline at end of file
+ main(parse_args())
From ad5bcb2904c9327c72728cd03dc984ac89126747 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Fri, 4 Jun 2021 16:33:21 +0200
Subject: [PATCH 02/25] CQD
---
cqd/base.py | 33 ---------
cqd/discrete.py | 183 ------------------------------------------------
2 files changed, 216 deletions(-)
diff --git a/cqd/base.py b/cqd/base.py
index cce7b83..4152531 100644
--- a/cqd/base.py
+++ b/cqd/base.py
@@ -306,9 +306,6 @@ def t_norm(a: Tensor, b: Tensor) -> Tensor:
def t_conorm(a: Tensor, b: Tensor) -> Tensor:
return torch.maximum(a, b)
- def negation(a: Tensor) -> Tensor:
- return 1.0 - a
-
if self.t_norm_name == CQD.PROD_NORM:
def t_norm(a: Tensor, b: Tensor) -> Tensor:
return a * b
@@ -380,36 +377,6 @@ def scoring_function(rel_: Tensor, lhs_: Tensor, rhs_: Tensor) -> Tensor:
queries=queries,
scoring_function=scoring_function,
t_norm=t_norm, t_conorm=t_conorm)
- elif graph_type == "2in":
- scores = d2.query_2in(entity_embeddings=self.embeddings[0],
- predicate_embeddings=self.embeddings[1],
- queries=queries,
- scoring_function=scoring_function,
- t_norm=t_norm, negation=negation)
- elif graph_type == "3in":
- scores = d2.query_3in(entity_embeddings=self.embeddings[0],
- predicate_embeddings=self.embeddings[1],
- queries=queries,
- scoring_function=scoring_function,
- t_norm=t_norm, negation=negation)
- elif graph_type == "pin":
- scores = d2.query_pin(entity_embeddings=self.embeddings[0],
- predicate_embeddings=self.embeddings[1],
- queries=queries,
- scoring_function=scoring_function,
- k=self.k, t_norm=t_norm, negation=negation)
- elif graph_type == "pni":
- scores = d2.query_pni_v2(entity_embeddings=self.embeddings[0],
- predicate_embeddings=self.embeddings[1],
- queries=queries,
- scoring_function=scoring_function,
- k=self.k, t_norm=t_norm, negation=negation)
- elif graph_type == "inp":
- scores = d2.query_inp(entity_embeddings=self.embeddings[0],
- predicate_embeddings=self.embeddings[1],
- queries=queries,
- scoring_function=scoring_function,
- t_norm=t_norm, negation=negation)
else:
raise ValueError(f'Unknown query type: {graph_type}')
diff --git a/cqd/discrete.py b/cqd/discrete.py
index f4d02c1..c4f9c88 100644
--- a/cqd/discrete.py
+++ b/cqd/discrete.py
@@ -106,53 +106,6 @@ def query_2p(entity_embeddings: nn.Module,
return res
-def query_2pn(entity_embeddings: nn.Module,
- predicate_embeddings: nn.Module,
- queries: Tensor,
- scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
- k: int,
- t_norm: Callable[[Tensor, Tensor], Tensor],
- negation: Callable[[Tensor], Tensor]) -> Tensor:
-
- s_emb = entity_embeddings(queries[:, 0])
- p1_emb = predicate_embeddings(queries[:, 1])
- p2_emb = predicate_embeddings(queries[:, 2])
-
- candidates_emb = entity_embeddings.weight
- nb_entities = candidates_emb.shape[0]
-
- batch_size = s_emb.shape[0]
- emb_size = s_emb.shape[1]
-
- # [B, K], [B, K, E]
- atom1_k_scores_2d, x1_k_emb_3d = score_candidates(s_emb=s_emb, p_emb=p1_emb,
- candidates_emb=candidates_emb, k=k,
- entity_embeddings=entity_embeddings,
- scoring_function=scoring_function)
-
- # [B * K, E]
- x1_k_emb_2d = x1_k_emb_3d.reshape(-1, emb_size)
-
- # [B * K, N]
- atom2_scores_2d, _ = score_candidates(s_emb=x1_k_emb_2d, p_emb=p2_emb,
- candidates_emb=candidates_emb, k=None,
- entity_embeddings=entity_embeddings,
- scoring_function=scoring_function)
-
- atom2_scores_2d = negation(atom2_scores_2d)
-
- # [B, K] -> [B, K, N]
- atom1_scores_3d = atom1_k_scores_2d.reshape(batch_size, -1, 1).repeat(1, 1, nb_entities)
- # [B * K, N] -> [B, K, N]
- atom2_scores_3d = atom2_scores_2d.reshape(batch_size, -1, nb_entities)
-
- res = t_norm(atom1_scores_3d, atom2_scores_3d)
-
- # [B, K, N] -> [B, N]
- res, _ = torch.max(res, dim=1)
- return res
-
-
def query_3p(entity_embeddings: nn.Module,
predicate_embeddings: nn.Module,
queries: Tensor,
@@ -303,142 +256,6 @@ def query_pi(entity_embeddings: nn.Module,
return res
-def query_2in(entity_embeddings: nn.Module,
- predicate_embeddings: nn.Module,
- queries: Tensor,
- scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
- t_norm: Callable[[Tensor, Tensor], Tensor],
- negation: Callable[[Tensor], Tensor]) -> Tensor:
-
- scores_1 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 0:2], scoring_function=scoring_function)
- scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 2:4], scoring_function=scoring_function)
-
- scores_2 = negation(scores_2)
-
- res = t_norm(scores_1, scores_2)
-
- return res
-
-
-def query_3in(entity_embeddings: nn.Module,
- predicate_embeddings: nn.Module,
- queries: Tensor,
- scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
- t_norm: Callable[[Tensor, Tensor], Tensor],
- negation: Callable[[Tensor], Tensor]) -> Tensor:
-
- scores_1 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 0:2], scoring_function=scoring_function)
- scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 2:4], scoring_function=scoring_function)
- scores_3 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 4:6], scoring_function=scoring_function)
-
- scores_3 = negation(scores_3)
-
- res = t_norm(scores_1, scores_2)
- res = t_norm(res, scores_3)
-
- return res
-
-
-def query_inp(entity_embeddings: nn.Module,
- predicate_embeddings: nn.Module,
- queries: Tensor,
- scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
- t_norm: Callable[[Tensor, Tensor], Tensor],
- negation: Callable[[Tensor], Tensor]) -> Tensor:
-
- # [B, N]
- scores_1 = query_2in(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 0:4], scoring_function=scoring_function, t_norm=t_norm, negation=negation)
-
- # [B, E]
- p_emb = predicate_embeddings(queries[:, 5])
-
- batch_size = p_emb.shape[0]
- emb_size = p_emb.shape[1]
-
- # [N, E]
- e_emb = entity_embeddings.weight
- nb_entities = e_emb.shape[0]
-
- # [B * N, E]
- s_emb = e_emb.reshape(1, nb_entities, emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size)
-
- # [B * N, N]
- scores_2, _ = score_candidates(s_emb=s_emb, p_emb=p_emb, candidates_emb=e_emb, k=None,
- entity_embeddings=entity_embeddings, scoring_function=scoring_function)
-
- # [B, N, N]
- scores_1 = scores_1.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities)
- scores_2 = scores_2.reshape(batch_size, nb_entities, nb_entities)
-
- res = t_norm(scores_1, scores_2)
- res, _ = torch.max(res, dim=1)
-
- return res
-
-
-def query_pin(entity_embeddings: nn.Module,
- predicate_embeddings: nn.Module,
- queries: Tensor,
- scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
- k: int,
- t_norm: Callable[[Tensor, Tensor], Tensor],
- negation: Callable[[Tensor], Tensor]) -> Tensor:
-
- scores_1 = query_2p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 0:3], scoring_function=scoring_function, k=k, t_norm=t_norm)
- scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 3:5], scoring_function=scoring_function)
-
- scores_2 = negation(scores_2)
-
- res = t_norm(scores_1, scores_2)
-
- return res
-
-
-def query_pni_v1(entity_embeddings: nn.Module,
- predicate_embeddings: nn.Module,
- queries: Tensor,
- scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
- k: int,
- t_norm: Callable[[Tensor, Tensor], Tensor],
- negation: Callable[[Tensor], Tensor]) -> Tensor:
-
- scores_1 = query_2p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 0:3], scoring_function=scoring_function, k=k, t_norm=t_norm)
- scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 4:6], scoring_function=scoring_function)
-
- scores_1 = negation(scores_1)
-
- res = t_norm(scores_1, scores_2)
- return res
-
-
-def query_pni_v2(entity_embeddings: nn.Module,
- predicate_embeddings: nn.Module,
- queries: Tensor,
- scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
- k: int,
- t_norm: Callable[[Tensor, Tensor], Tensor],
- negation: Callable[[Tensor], Tensor]) -> Tensor:
-
- scores_1 = query_2pn(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 0:3], scoring_function=scoring_function, k=k,
- t_norm=t_norm, negation=negation)
- scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
- queries=queries[:, 4:6], scoring_function=scoring_function)
-
- res = t_norm(scores_1, scores_2)
- return res
-
-
def query_2u_dnf(entity_embeddings: nn.Module,
predicate_embeddings: nn.Module,
queries: Tensor,
From 09353be1cd452ad625aa5e4bc7a494f93e928b55 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Fri, 4 Jun 2021 16:56:42 +0200
Subject: [PATCH 03/25] update
---
README.md | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/README.md b/README.md
index ee1a5e2..f77716e 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,8 @@
# KGReasoning
-This repo contains several algorithms for multi-hop reasoning on knowledge graphs, including the official Pytorch implementation of [Beta Embeddings for Multi-Hop Logical Reasoning in Knowledge Graphs](https://arxiv.org/abs/2010.11465).
+This repo contains several algorithms for multi-hop reasoning on knowledge graphs, including the official PyTorch implementation of [Beta Embeddings for Multi-Hop Logical Reasoning in Knowledge Graphs](https://arxiv.org/abs/2010.11465) and a PyTorch implementation of [Complex Query Answering with Neural Link Predictors](https://arxiv.org/abs/2011.03459).
**Models**
+- [x] [CQD](https://arxiv.org/abs/2011.03459)
- [x] [BetaE](https://arxiv.org/abs/2010.11465)
- [x] [Query2box](https://arxiv.org/abs/2002.05969)
- [x] [GQE](https://arxiv.org/abs/1806.01445)
From 472268a35a5ce44cac01f10df58d01ff18d6f761 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 5 Jun 2021 17:15:03 +0200
Subject: [PATCH 04/25] update
---
cqd/base.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/cqd/base.py b/cqd/base.py
index 4152531..d36811b 100644
--- a/cqd/base.py
+++ b/cqd/base.py
@@ -68,8 +68,8 @@ def __init__(self,
# XXX: get rid of this hack
test_batch_size = 1000
- batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1)
- self.register_buffer('batch_entity_range', batch_entity_range)
+ self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1)
+ # self.register_buffer('batch_entity_range', batch_entity_range)
def split(self,
lhs_emb: Tensor,
From 6486b657a7c2e6efc70a6951d97790759aa0f47d Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 5 Jun 2021 17:15:17 +0200
Subject: [PATCH 05/25] update
---
README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/README.md b/README.md
index f77716e..5d3fa82 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,5 @@
# KGReasoning
+
This repo contains several algorithms for multi-hop reasoning on knowledge graphs, including the official PyTorch implementation of [Beta Embeddings for Multi-Hop Logical Reasoning in Knowledge Graphs](https://arxiv.org/abs/2010.11465) and a PyTorch implementation of [Complex Query Answering with Neural Link Predictors](https://arxiv.org/abs/2011.03459).
**Models**
From 888f506a7c5833f29905289583f911cf630266c4 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 5 Jun 2021 17:15:29 +0200
Subject: [PATCH 06/25] update
---
CQD.md | 31 +++++++++++++++++++++++++++++++
1 file changed, 31 insertions(+)
create mode 100644 CQD.md
diff --git a/CQD.md b/CQD.md
new file mode 100644
index 0000000..f1187dc
--- /dev/null
+++ b/CQD.md
@@ -0,0 +1,31 @@
+# Continuous Query Decomposition
+
+This repository contains the official implementation for our ICLR 2021 (Oral, Outstanding Paper Award) paper, [**Complex Query Answering with Neural Link Predictors**](https://openreview.net/forum?id=Mos9F9kDwkz).
+
+```bibtex
+@inproceedings{
+ arakelyan2021complex,
+ title={Complex Query Answering with Neural Link Predictors},
+ author={Erik Arakelyan and Daniel Daza and Pasquale Minervini and Michael Cochez},
+ booktitle={International Conference on Learning Representations},
+ year={2021},
+ url={https://openreview.net/forum?id=Mos9F9kDwkz}
+}
+```
+
+In this work, we present CQD, a method that answers complex queries by reusing a pretrained link predictor: each query atom is scored independently, and the atom scores are aggregated via t-norms and t-conorms.
+
+Our code is based on an implementation of ComplEx-N3 available [here](https://github.com/facebookresearch/kbc).
+
+### 1. Download the pre-trained models
+
+```bash
+$ mkdir models/
+$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-betae.tar.gz
+$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-q2b.tar.gz
+$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-237-betae.tar.gz
+$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-237-q2b.tar.gz
+$ wget -c http://data.neuralnoise.com/cqd-models/nell-betae.tar.gz
+$ wget -c http://data.neuralnoise.com/cqd-models/nell-q2b.tar.gz
+$ for z in *.tar.gz; do tar xvfz $z; done
+```
From 244eab67aef70ea3820e53b8eb6161bf4d230780 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 5 Jun 2021 17:18:59 +0200
Subject: [PATCH 07/25] update
---
main.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/main.py b/main.py
index 8067c70..f8d7204 100755
--- a/main.py
+++ b/main.py
@@ -389,7 +389,7 @@ def main(args):
filter(lambda p: p.requires_grad, model.parameters()),
lr=current_learning_rate
)
- warm_up_steps = args.max_steps // 2
+ warm_up_steps = args.max_steps // 2 if args.warm_up_steps is None else args.warm_up_steps
if args.checkpoint_path is not None:
logging.info('Loading checkpoint %s...' % args.checkpoint_path)
From c586f3f0ed66090fef16e50f5c65282ce6e3d51c Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 5 Jun 2021 17:21:38 +0200
Subject: [PATCH 08/25] update
---
main.py | 29 ++++++++++++++++-------------
1 file changed, 16 insertions(+), 13 deletions(-)
diff --git a/main.py b/main.py
index f8d7204..108fbd1 100755
--- a/main.py
+++ b/main.py
@@ -226,19 +226,22 @@ def main(args):
else:
prefix = args.prefix
- print ("overwritting args.save_path")
- args.save_path = os.path.join(prefix, args.data_path.split('/')[-1], args.tasks, args.geo)
- if args.geo in ['box']:
- tmp_str = "g-{}-mode-{}".format(args.gamma, args.box_mode)
- elif args.geo in ['vec']:
- tmp_str = "g-{}".format(args.gamma)
- elif args.geo == 'beta':
- tmp_str = "g-{}-mode-{}".format(args.gamma, args.beta_mode)
-
- if args.checkpoint_path is not None:
- args.save_path = args.checkpoint_path
- else:
- args.save_path = os.path.join(args.save_path, tmp_str, cur_time)
+ if args.save_path is None:
+ print("overwritting args.save_path")
+ args.save_path = os.path.join(prefix, args.data_path.split('/')[-1], args.tasks, args.geo)
+ if args.geo in ['box']:
+ tmp_str = "g-{}-mode-{}".format(args.gamma, args.box_mode)
+ elif args.geo in ['vec']:
+ tmp_str = "g-{}".format(args.gamma)
+ elif args.geo == 'beta':
+ tmp_str = "g-{}-mode-{}".format(args.gamma, args.beta_mode)
+ elif args.geo == 'cqd':
+ tmp_str = "g-cqd"
+
+ if args.checkpoint_path is not None:
+ args.save_path = args.checkpoint_path
+ else:
+ args.save_path = os.path.join(args.save_path, tmp_str, cur_time)
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
From b79a62e592015a39eaa444f47e60f149a5654c8b Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 5 Jun 2021 17:23:22 +0200
Subject: [PATCH 09/25] model.train_step -> KGReasoning.train_step (since it is
static)
---
main.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/main.py b/main.py
index 108fbd1..c129072 100755
--- a/main.py
+++ b/main.py
@@ -430,14 +430,14 @@ def main(args):
if step == 2*args.max_steps//3:
args.valid_steps *= 4
- log = model.train_step(model, optimizer, train_path_iterator, args, step)
+ log = KGReasoning.train_step(model, optimizer, train_path_iterator, args, step)
for metric in log:
writer.add_scalar('path_'+metric, log[metric], step)
if train_other_iterator is not None:
- log = model.train_step(model, optimizer, train_other_iterator, args, step)
+ log = KGReasoning.train_step(model, optimizer, train_other_iterator, args, step)
for metric in log:
writer.add_scalar('other_'+metric, log[metric], step)
- log = model.train_step(model, optimizer, train_path_iterator, args, step)
+ log = KGReasoning.train_step(model, optimizer, train_path_iterator, args, step)
training_logs.append(log)
From 656a181ecbc1b796026f35f034232ab955a76e81 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 5 Jun 2021 17:26:58 +0200
Subject: [PATCH 10/25] update
---
models.py | 36 +++++++++++++++++++-----------------
1 file changed, 19 insertions(+), 17 deletions(-)
diff --git a/models.py b/models.py
index b761009..08ee797 100755
--- a/models.py
+++ b/models.py
@@ -5,20 +5,14 @@
from __future__ import print_function
import logging
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch.utils.data import DataLoader
-from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator
-import random
-import pickle
-import math
import collections
-import itertools
-import time
+
+from cqd import CQD
+
from tqdm import tqdm
-import os
def Identity(x):
return x
@@ -584,16 +578,24 @@ def train_step(model, optimizer, train_iterator, args, step):
negative_sample = negative_sample.cuda()
subsampling_weight = subsampling_weight.cuda()
- positive_logit, negative_logit, subsampling_weight, _ = model(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)
+ if isinstance(model, CQD):
+ input_batch = batch_queries_dict[('e', ('r',))]
+ input_batch = torch.cat((input_batch, positive_sample.unsqueeze(1)), dim=1)
+ loss = model.loss(input_batch)
+
+ positive_sample_loss = negative_sample_loss = loss
+ else:
+ positive_logit, negative_logit, subsampling_weight, _ = model(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict)
+
+ negative_score = F.logsigmoid(-negative_logit).mean(dim=1)
+ positive_score = F.logsigmoid(positive_logit).squeeze(dim=1)
+ positive_sample_loss = - (subsampling_weight * positive_score).sum()
+ negative_sample_loss = - (subsampling_weight * negative_score).sum()
+ positive_sample_loss /= subsampling_weight.sum()
+ negative_sample_loss /= subsampling_weight.sum()
- negative_score = F.logsigmoid(-negative_logit).mean(dim=1)
- positive_score = F.logsigmoid(positive_logit).squeeze(dim=1)
- positive_sample_loss = - (subsampling_weight * positive_score).sum()
- negative_sample_loss = - (subsampling_weight * negative_score).sum()
- positive_sample_loss /= subsampling_weight.sum()
- negative_sample_loss /= subsampling_weight.sum()
+ loss = (positive_sample_loss + negative_sample_loss)/2
- loss = (positive_sample_loss + negative_sample_loss)/2
loss.backward()
optimizer.step()
log = {
From febdf3a5256d472962b4c236b10f107277870a10 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 5 Jun 2021 17:35:09 +0200
Subject: [PATCH 11/25] PEP8
---
main.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/main.py b/main.py
index c129072..52921dd 100755
--- a/main.py
+++ b/main.py
@@ -360,11 +360,11 @@ def main(args):
hidden_dim=args.hidden_dim,
gamma=args.gamma,
geo=args.geo,
- use_cuda = args.cuda,
+ use_cuda=args.cuda,
box_mode=eval_tuple(args.box_mode),
- beta_mode = eval_tuple(args.beta_mode),
+ beta_mode=eval_tuple(args.beta_mode),
test_batch_size=args.test_batch_size,
- query_name_dict = query_name_dict
+ query_name_dict=query_name_dict
)
name_to_optimizer = {
From 027d10eb69f0700374fafe13c570b7fe813203ec Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 5 Jun 2021 17:42:10 +0200
Subject: [PATCH 12/25] update
---
main.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/main.py b/main.py
index 52921dd..c28a9c7 100755
--- a/main.py
+++ b/main.py
@@ -352,7 +352,8 @@ def main(args):
t_norm_name=args.cqd_t_norm,
k=args.cqd_k,
do_sigmoid=args.cqd_sigmoid_scores,
- do_normalize=args.cqd_normalize_scores)
+ do_normalize=args.cqd_normalize_scores,
+ use_cuda=args.cuda)
else:
model = KGReasoning(
nentity=nentity,
From d57f95b0123e486aa3c143a46065c2ddea7044c1 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 5 Jun 2021 18:30:23 +0200
Subject: [PATCH 13/25] update
---
cqd/base.py | 13 ++++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/cqd/base.py b/cqd/base.py
index d36811b..a85d280 100644
--- a/cqd/base.py
+++ b/cqd/base.py
@@ -43,7 +43,8 @@ def __init__(self,
k: int = 5,
query_name_dict: Optional[Dict] = None,
do_sigmoid: bool = False,
- do_normalize: bool = False):
+ do_normalize: bool = False,
+ use_cuda: bool = False):
super(CQD, self).__init__()
self.rank = rank
@@ -67,10 +68,16 @@ def __init__(self,
self.do_normalize = do_normalize
# XXX: get rid of this hack
- test_batch_size = 1000
- self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1)
+ # test_batch_size = 1000
+ # self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1)
# self.register_buffer('batch_entity_range', batch_entity_range)
+ self.use_cuda = use_cuda
+
+ self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1)
+ if self.use_cuda is True:
+ self.batch_entity_range = self.batch_entity_range.cuda()
+
def split(self,
lhs_emb: Tensor,
rel_emb: Tensor,
From 03a2450b68ca3290ff2c3bf6c491124328461b86 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 5 Jun 2021 18:51:43 +0200
Subject: [PATCH 14/25] cleanup
---
cqd/base.py | 6 ------
1 file changed, 6 deletions(-)
diff --git a/cqd/base.py b/cqd/base.py
index a85d280..4dfefb9 100644
--- a/cqd/base.py
+++ b/cqd/base.py
@@ -67,13 +67,7 @@ def __init__(self,
self.do_sigmoid = do_sigmoid
self.do_normalize = do_normalize
- # XXX: get rid of this hack
- # test_batch_size = 1000
- # self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1)
- # self.register_buffer('batch_entity_range', batch_entity_range)
-
self.use_cuda = use_cuda
-
self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1)
if self.use_cuda is True:
self.batch_entity_range = self.batch_entity_range.cuda()
From d52e12805015f9df08e5317038729475d9d904af Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sun, 6 Jun 2021 14:48:51 +0200
Subject: [PATCH 15/25] better defaults
---
cqd/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/cqd/base.py b/cqd/base.py
index 4dfefb9..91b01e7 100644
--- a/cqd/base.py
+++ b/cqd/base.py
@@ -38,7 +38,7 @@ def __init__(self,
init_size: float = 1e-3,
reg_weight: float = 1e-2,
test_batch_size: int = 1,
- method: str = 'beam',
+ method: str = 'discrete',
t_norm_name: str = 'prod',
k: int = 5,
query_name_dict: Optional[Dict] = None,
From 4eeb63684d72f7060c64ff628349ce4f3791e555 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sun, 6 Jun 2021 14:49:10 +0200
Subject: [PATCH 16/25] continuous optimisation requires gradients
---
models.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/models.py b/models.py
index 08ee797..974b6ff 100755
--- a/models.py
+++ b/models.py
@@ -613,7 +613,10 @@ def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_na
total_steps = len(test_dataloader)
logs = collections.defaultdict(list)
- with torch.no_grad():
+ require_grad = isinstance(model, CQD) and model.method == 'continuous'
+
+ # with torch.no_grad():
+ with torch.set_grad_enabled(require_grad):
for negative_sample, queries, queries_unflatten, query_structures in tqdm(test_dataloader, disable=not args.print_on_screen):
batch_queries_dict = collections.defaultdict(list)
batch_idxs_dict = collections.defaultdict(list)
From 5a4590c5eb2bde4030038a0b8f74e2368b8cd54f Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sun, 6 Jun 2021 15:08:20 +0200
Subject: [PATCH 17/25] require -> requires
---
models.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/models.py b/models.py
index 974b6ff..07d4857 100755
--- a/models.py
+++ b/models.py
@@ -613,10 +613,10 @@ def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_na
total_steps = len(test_dataloader)
logs = collections.defaultdict(list)
- require_grad = isinstance(model, CQD) and model.method == 'continuous'
+ requires_grad = isinstance(model, CQD) and model.method == 'continuous'
# with torch.no_grad():
- with torch.set_grad_enabled(require_grad):
+ with torch.set_grad_enabled(requires_grad):
for negative_sample, queries, queries_unflatten, query_structures in tqdm(test_dataloader, disable=not args.print_on_screen):
batch_queries_dict = collections.defaultdict(list)
batch_idxs_dict = collections.defaultdict(list)
From e362fbae2d7ea918253a9e4bc4d9b3e9c73ca048 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Wed, 9 Jun 2021 23:33:55 +0200
Subject: [PATCH 18/25] clearer documentation
---
CQD.md | 9 ++-------
1 file changed, 2 insertions(+), 7 deletions(-)
diff --git a/CQD.md b/CQD.md
index f1187dc..30512c6 100644
--- a/CQD.md
+++ b/CQD.md
@@ -21,11 +21,6 @@ Our code is based on an implementation of ComplEx-N3 available [here](https://gi
```bash
$ mkdir models/
-$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-betae.tar.gz
-$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-q2b.tar.gz
-$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-237-betae.tar.gz
-$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-237-q2b.tar.gz
-$ wget -c http://data.neuralnoise.com/cqd-models/nell-betae.tar.gz
-$ wget -c http://data.neuralnoise.com/cqd-models/nell-q2b.tar.gz
-$ for z in *.tar.gz; do tar xvfz $z; done
+$ for i in "fb15k" "fb15k-237" "nell"; do for j in "betae" "q2b"; do wget -c http://data.neuralnoise.com/kgreasoning-cqd/$i-$j.tar.gz; done; done
+$ for i in *.tar.gz; do tar xvfz $i; done
```
From d7ed58e2dc3337b8d0fb55667c9fa10f7d0b263a Mon Sep 17 00:00:00 2001
From: pminervini
Date: Thu, 17 Jun 2021 14:50:33 +0200
Subject: [PATCH 19/25] shorter and clearer code
---
models.py | 15 ++++-----------
1 file changed, 4 insertions(+), 11 deletions(-)
diff --git a/models.py b/models.py
index 07d4857..4a8fcb7 100755
--- a/models.py
+++ b/models.py
@@ -639,18 +639,11 @@ def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_na
if len(argsort) == args.test_batch_size: # if it is the same shape with test_batch_size, we can reuse batch_entity_range without creating a new one
ranking = ranking.scatter_(1, argsort, model.batch_entity_range) # achieve the ranking of all entities
else: # otherwise, create a new torch Tensor for batch_entity_range
+ scatter_src = torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1)
if args.cuda:
- ranking = ranking.scatter_(1,
- argsort,
- torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0],
- 1).cuda()
- ) # achieve the ranking of all entities
- else:
- ranking = ranking.scatter_(1,
- argsort,
- torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0],
- 1)
- ) # achieve the ranking of all entities
+ scatter_src = scatter_src.cuda()
+ # achieve the ranking of all entities
+ ranking = ranking.scatter_(1, argsort, scatter_src)
for idx, (i, query, query_structure) in enumerate(zip(argsort[:, 0], queries_unflatten, query_structures)):
hard_answer = hard_answers[query]
easy_answer = easy_answers[query]
From 6ab0bb06ad972b3aeab9aa14ba18eeb52f72bf70 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Thu, 17 Jun 2021 20:10:22 +0200
Subject: [PATCH 20/25] removing old unused code
---
cqd/base.py | 27 +------
cqd/util.py | 208 ----------------------------------------------------
main.py | 2 +-
3 files changed, 3 insertions(+), 234 deletions(-)
diff --git a/cqd/base.py b/cqd/base.py
index 91b01e7..12e22c1 100644
--- a/cqd/base.py
+++ b/cqd/base.py
@@ -9,7 +9,7 @@
from torch import optim, Tensor
import math
-from cqd.util import query_to_atoms, create_instructions, top_k_selection
+from cqd.util import query_to_atoms
import cqd.discrete as d2
from typing import Tuple, List, Optional, Dict
@@ -217,7 +217,7 @@ def forward(self,
h_emb_constants = self.embeddings[0](head)
r_emb = self.embeddings[1](rel)
- if 'co' in self.method:
+ if 'continuous' in self.method:
h_emb = h_emb_constants
if num_variables > 1:
# var embedding for ID 0 is unused for ease of implementation
@@ -275,29 +275,6 @@ def forward(self,
scores = torch.cat(all_scores, dim=0)
- elif 'continuous' in self.method:
- graph_type = self.query_name_dict[query_structure]
-
- chain_instructions = create_instructions(atoms[0])
- chains = []
-
- for atom in range(len(atoms[0])):
- part = atoms[:, atom, :]
- chain = self.get_full_embeddings(part)
- chains.append(chain)
-
- scores = top_k_selection(chains,
- chain_instructions,
- graph_type,
- # score_o takes lhs, rel, rhs
- scoring_function=lambda rel_, lhs_, rhs_: self.score_o(lhs_, rel_, rhs_)[0],
- forward_emb=lambda lhs_, rel_: self.score_o(lhs_, rel_, self.embeddings[0].weight)[0],
- entity_embeddings=self.embeddings[0],
- candidates=self.k,
- t_norm=self.t_norm_name,
- batch_size=1,
- scores_normalize='default')
-
elif 'discrete' in self.method:
graph_type = self.query_name_dict[query_structure]
diff --git a/cqd/util.py b/cqd/util.py
index dee0d78..a6bcd4c 100644
--- a/cqd/util.py
+++ b/cqd/util.py
@@ -173,211 +173,3 @@ def get_best_candidates(rel: Tensor,
z_emb = entity_embeddings(z_indices)
return z_scores, z_emb
-
-
-def top_k_selection(chains,
- chain_instructions,
- graph_type,
- scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
- forward_emb: Callable[[Tensor, Tensor], Tensor],
- entity_embeddings: Callable[[Tensor], Tensor],
- candidates: int = 5,
- t_norm: str = 'min',
- batch_size: int = 1,
- scores_normalize: str = 'default'):
- res = None
-
- if 'disj' in graph_type:
- objective = t_conorm_fn
- else:
- objective = t_norm_fn
-
- nb_queries, embedding_size = chains[0][0].shape[0], chains[0][0].shape[1]
-
- scores = None
-
- batches = make_batches(nb_queries, batch_size)
-
- for batch in tqdm(batches):
-
- nb_branches = 1
- nb_ent = 0
- batch_scores = None
- candidate_cache = {}
-
- batch_size = batch[1] - batch[0]
- dnf_flag = False
- if 'disj' in graph_type:
- dnf_flag = True
-
- for inst_ind, inst in enumerate(chain_instructions):
- with torch.no_grad():
- if 'hop' in inst:
-
- ind_1 = int(inst.split("_")[-2])
- ind_2 = int(inst.split("_")[-1])
-
- indices = [ind_1, ind_2]
-
- if objective == t_conorm_fn and dnf_flag:
- objective = t_norm_fn
-
- last_hop = False
- for hop_num, ind in enumerate(indices):
- last_step = (inst_ind == len(chain_instructions) - 1) and last_hop
-
- lhs, rel, rhs = chains[ind]
-
- if lhs is not None:
- lhs = lhs[batch[0]:batch[1]]
- else:
- # print("MTA BRAT")
- batch_scores, lhs_3d = candidate_cache[f"lhs_{ind}"]
- lhs = lhs_3d.view(-1, embedding_size)
-
- rel = rel[batch[0]:batch[1]]
- rel = rel.view(-1, 1, embedding_size).repeat(1, nb_branches, 1)
- rel = rel.view(-1, embedding_size)
-
- if f"rhs_{ind}" not in candidate_cache:
-
- # print("STTEEE MTA")
- z_scores, rhs_3d = get_best_candidates(rel, lhs, forward_emb, entity_embeddings, candidates, last_step)
-
- # [Num_queries * Candidates^K]
- z_scores_1d = z_scores.view(-1)
- if 'disj' in graph_type or scores_normalize:
- z_scores_1d = torch.sigmoid(z_scores_1d)
-
- # B * S
- nb_sources = rhs_3d.shape[0] * rhs_3d.shape[1]
- nb_branches = nb_sources // batch_size
- if not last_step:
- batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, candidates).view(-1), t_norm)
- else:
- nb_ent = rhs_3d.shape[1]
- batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, nb_ent).view(-1), t_norm)
-
- candidate_cache[f"rhs_{ind}"] = (batch_scores, rhs_3d)
-
- if not last_hop:
- candidate_cache[f"lhs_{indices[hop_num + 1]}"] = (batch_scores, rhs_3d)
-
- else:
- batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"]
- candidate_cache[f"lhs_{ind + 1}"] = (batch_scores, rhs_3d)
- last_hop = True
- continue
-
- last_hop = True
-
- elif 'inter' in inst:
- ind_1 = int(inst.split("_")[-2])
- ind_2 = int(inst.split("_")[-1])
-
- indices = [ind_1, ind_2]
-
- if objective == t_norm_fn and dnf_flag:
- objective = t_conorm_fn
-
- if len(inst.split("_")) > 3:
- ind_1 = int(inst.split("_")[-3])
- ind_2 = int(inst.split("_")[-2])
- ind_3 = int(inst.split("_")[-1])
-
- indices = [ind_1, ind_2, ind_3]
-
- for intersection_num, ind in enumerate(indices):
- last_step = (inst_ind == len(chain_instructions) - 1) # and ind == indices[0]
-
- lhs, rel, rhs = chains[ind]
-
- if lhs is not None:
- lhs = lhs[batch[0]:batch[1]]
- lhs = lhs.view(-1, 1, embedding_size).repeat(1, nb_branches, 1)
- lhs = lhs.view(-1, embedding_size)
-
- else:
- batch_scores, lhs_3d = candidate_cache[f"lhs_{ind}"]
- lhs = lhs_3d.view(-1, embedding_size)
- nb_sources = lhs_3d.shape[0] * lhs_3d.shape[1]
- nb_branches = nb_sources // batch_size
-
- rel = rel[batch[0]:batch[1]]
- rel = rel.view(-1, 1, embedding_size).repeat(1, nb_branches, 1)
- rel = rel.view(-1, embedding_size)
-
- if intersection_num > 0 and 'disj' in graph_type:
- batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"]
- rhs = rhs_3d.view(-1, embedding_size)
- z_scores = scoring_function(rel, lhs, rhs)
-
- z_scores_1d = z_scores.view(-1)
- if 'disj' in graph_type or scores_normalize:
- z_scores_1d = torch.sigmoid(z_scores_1d)
-
- batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores, t_norm)
-
- continue
-
- if f"rhs_{ind}" not in candidate_cache or last_step:
- z_scores, rhs_3d = get_best_candidates(rel, lhs, forward_emb, entity_embeddings, candidates, last_step)
-
- # [B * Candidates^K] or [B, S-1, N]
- z_scores_1d = z_scores.view(-1)
- if 'disj' in graph_type or scores_normalize:
- z_scores_1d = torch.sigmoid(z_scores_1d)
-
- if not last_step:
- nb_sources = rhs_3d.shape[0] * rhs_3d.shape[1]
- nb_branches = nb_sources // batch_size
-
- if not last_step:
- batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, candidates).view(-1), t_norm)
- else:
- if ind == indices[0]:
- nb_ent = rhs_3d.shape[1]
- else:
- nb_ent = 1
-
- batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, nb_ent).view(-1), t_norm)
- nb_ent = rhs_3d.shape[1]
-
- candidate_cache[f"rhs_{ind}"] = (batch_scores, rhs_3d)
-
- if ind == indices[0] and 'disj' in graph_type:
- count = len(indices) - 1
- iterator = 1
- while count > 0:
- candidate_cache[f"rhs_{indices[intersection_num + iterator]}"] = (
- batch_scores, rhs_3d)
- iterator += 1
- count -= 1
-
- if ind == indices[-1]:
- candidate_cache[f"lhs_{ind + 1}"] = (batch_scores, rhs_3d)
- else:
- batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"]
- candidate_cache[f"rhs_{ind + 1}"] = (batch_scores, rhs_3d)
-
- last_hop = True
- del lhs, rel
- continue
-
- del lhs, rel, rhs, rhs_3d, z_scores_1d, z_scores
-
- if batch_scores is not None:
- # [B * entites * S ]
- # S == K**(V-1)
- scores_2d = batch_scores.view(batch_size, -1, nb_ent)
- res, _ = torch.max(scores_2d, dim=1)
- scores = res if scores is None else torch.cat([scores, res])
-
- del batch_scores, scores_2d, res, candidate_cache
-
- else:
- assert False, "Batch Scores are empty: an error went uncaught."
-
- res = scores
-
- return res
diff --git a/main.py b/main.py
index c28a9c7..6a3c5d1 100755
--- a/main.py
+++ b/main.py
@@ -82,7 +82,7 @@ def parse_args(args=None):
parser.add_argument('--reg_weight', default=1e-3, type=float)
parser.add_argument('--optimizer', choices=['adam', 'adagrad'], default='adam')
- parser.add_argument('--cqd-type', '--cqd', default='co', type=str, choices=['continuous', 'discrete'])
+ parser.add_argument('--cqd-type', '--cqd', default='discrete', type=str, choices=['continuous', 'discrete'])
parser.add_argument('--cqd-t-norm', default=CQD.PROD_NORM, type=str, choices=CQD.NORMS)
parser.add_argument('--cqd-k', default=5, type=int)
parser.add_argument('--cqd-sigmoid-scores', '--cqd-sigmoid', action='store_true', default=False)
From fce04172e460ddab7bfad6bfb2cf159dd5fdcbbf Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sun, 27 Jun 2021 12:20:09 +0200
Subject: [PATCH 21/25] more efficient answering of ip and up queries
---
cqd/base.py | 4 ++--
cqd/discrete.py | 50 +++++++++++++++++++++++++++++++++----------------
2 files changed, 36 insertions(+), 18 deletions(-)
diff --git a/cqd/base.py b/cqd/base.py
index 12e22c1..6f2ac23 100644
--- a/cqd/base.py
+++ b/cqd/base.py
@@ -342,7 +342,7 @@ def scoring_function(rel_: Tensor, lhs_: Tensor, rhs_: Tensor) -> Tensor:
predicate_embeddings=self.embeddings[1],
queries=queries,
scoring_function=scoring_function,
- t_norm=t_norm)
+ k=self.k, t_norm=t_norm)
elif graph_type == "2u-DNF":
scores = d2.query_2u_dnf(entity_embeddings=self.embeddings[0],
predicate_embeddings=self.embeddings[1],
@@ -354,7 +354,7 @@ def scoring_function(rel_: Tensor, lhs_: Tensor, rhs_: Tensor) -> Tensor:
predicate_embeddings=self.embeddings[1],
queries=queries,
scoring_function=scoring_function,
- t_norm=t_norm, t_conorm=t_conorm)
+ k=self.k, t_norm=t_norm, t_conorm=t_conorm)
else:
raise ValueError(f'Unknown query type: {graph_type}')
diff --git a/cqd/discrete.py b/cqd/discrete.py
index c4f9c88..749ac3c 100644
--- a/cqd/discrete.py
+++ b/cqd/discrete.py
@@ -206,6 +206,7 @@ def query_ip(entity_embeddings: nn.Module,
predicate_embeddings: nn.Module,
queries: Tensor,
scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
+ k: int,
t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
# [B, N]
@@ -222,18 +223,26 @@ def query_ip(entity_embeddings: nn.Module,
e_emb = entity_embeddings.weight
nb_entities = e_emb.shape[0]
- # [B * N, E]
- s_emb = e_emb.reshape(1, nb_entities, emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size)
+ k_ = min(k, nb_entities)
+
+ # [B, K], [B, K]
+ scores_1_k, scores_1_k_indices = torch.topk(scores_1, k=k_, dim=1)
+
+ # [B, K, E]
+ scores_1_k_emb = entity_embeddings(scores_1_k_indices)
+
+ # [B * K, E]
+ scores_1_k_emb_2d = scores_1_k_emb.reshape(batch_size * k_, emb_size)
- # [B * N, N]
- scores_2, _ = score_candidates(s_emb=s_emb, p_emb=p_emb, candidates_emb=e_emb, k=None,
+ # [B * K, N]
+ scores_2, _ = score_candidates(s_emb=scores_1_k_emb_2d, p_emb=p_emb, candidates_emb=e_emb, k=None,
entity_embeddings=entity_embeddings, scoring_function=scoring_function)
- # [B, N, N]
- scores_1 = scores_1.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities)
- scores_2 = scores_2.reshape(batch_size, nb_entities, nb_entities)
+ # [B * K, N]
+ scores_1_k = scores_1_k.reshape(batch_size, k_, 1).repeat(1, 1, nb_entities)
+ scores_2 = scores_2.reshape(batch_size, k_, nb_entities)
- res = t_norm(scores_1, scores_2)
+ res = t_norm(scores_1_k, scores_2)
res, _ = torch.max(res, dim=1)
return res
@@ -276,6 +285,7 @@ def query_up_dnf(entity_embeddings: nn.Module,
predicate_embeddings: nn.Module,
queries: Tensor,
scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
+ k: int,
t_norm: Callable[[Tensor, Tensor], Tensor],
t_conorm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
# [B, N]
@@ -292,18 +302,26 @@ def query_up_dnf(entity_embeddings: nn.Module,
e_emb = entity_embeddings.weight
nb_entities = e_emb.shape[0]
- # [B * N, E]
- s_emb = e_emb.reshape(1, nb_entities, emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size)
+ k_ = min(k, nb_entities)
- # [B * N, N]
- scores_2, _ = score_candidates(s_emb=s_emb, p_emb=p_emb, candidates_emb=e_emb, k=None,
+ # [B, K], [B, K]
+ scores_1_k, scores_1_k_indices = torch.topk(scores_1, k=k_, dim=1)
+
+ # [B, K, E]
+ scores_1_k_emb = entity_embeddings(scores_1_k_indices)
+
+ # [B * K, E]
+ scores_1_k_emb_2d = scores_1_k_emb.reshape(batch_size * k_, emb_size)
+
+ # [B * K, N]
+ scores_2, _ = score_candidates(s_emb=scores_1_k_emb_2d, p_emb=p_emb, candidates_emb=e_emb, k=None,
entity_embeddings=entity_embeddings, scoring_function=scoring_function)
- # [B, N, N]
- scores_1 = scores_1.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities)
- scores_2 = scores_2.reshape(batch_size, nb_entities, nb_entities)
+ # [B * K, N]
+ scores_1_k = scores_1_k.reshape(batch_size, k_, 1).repeat(1, 1, nb_entities)
+ scores_2 = scores_2.reshape(batch_size, k_, nb_entities)
- res = t_norm(scores_1, scores_2)
+ res = t_norm(scores_1_k, scores_2)
res, _ = torch.max(res, dim=1)
return res
From 46e953466b3aa588d21ac7578050d15f27bcf323 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 26 Mar 2022 19:24:50 +0100
Subject: [PATCH 22/25] update
---
CQD.md | 259 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
1 file changed, 258 insertions(+), 1 deletion(-)
diff --git a/CQD.md b/CQD.md
index 30512c6..afa6d34 100644
--- a/CQD.md
+++ b/CQD.md
@@ -17,10 +17,267 @@ In this work we present CQD, a method that reuses a pretrained link predictor to
Our code is based on an implementation of ComplEx-N3 available [here](https://github.com/facebookresearch/kbc).
-### 1. Download the pre-trained models
+## 1. Download the pre-trained models
+
+To download and decompress the pre-trained models, execute the folloing commands:
```bash
$ mkdir models/
$ for i in "fb15k" "fb15k-237" "nell"; do for j in "betae" "q2b"; do wget -c http://data.neuralnoise.com/kgreasoning-cqd/$i-$j.tar.gz; done; done
$ for i in *.tar.gz; do tar xvfz $i; done
```
+
+## 2. Answer the complex queries
+
+One catch is that the query answering process in CQD depends on some hyperparameters, i.e. the "beam size" `k`, the t-norm to use (e.g. `min` or `prod`), and the normalisation function that maps scores to the `[0, 1]` interval; in our experiments, we select these on the validation set. Here are the commands to execute to evaluate CQD on each type of queries:
+
+### 2.1 -- FB15k
+
+1p queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 1p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete
+[..]
+Test 1p MRR at step 99999: 0.891426
+Test 1p HITS1 at step 99999: 0.857939
+Test 1p HITS3 at step 99999: 0.915589
+[..]
+```
+
+2p queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --valid_steps 20 --tasks 2p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-k 64 --cuda
+[..]
+Test 2p HITS3 at step 99999: 0.791121
+[..]
+```
+
+3p queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --valid_steps 20 --tasks 3p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-sigmoid --cqd-k 4 --cuda
+[..]
+Test 3p HITS3 at step 99999: 0.459223
+[..]
+```
+
+2i queries:
+
+```bash
+PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2i --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-k 16 --cuda
+[..]
+Test 2i HITS3 at step 99999: 0.788954
+[..]
+```
+
+3i queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 3i --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-k 16 --cuda
+[..]
+Test 3i HITS3 at step 99999: 0.837378
+[..]
+```
+
+ip queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks ip --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-k 16 --cuda
+[..]
+Test ip HITS3 at step 99999: 0.649221
+[..]
+```
+
+pi queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks pi --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-k 64 --cuda
+[..]
+Test pi HITS3 at step 99999: 0.681604
+[..]
+```
+
+2u queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2u --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm min --cqd-normalize --cqd-k 16 --cuda
+[..]
+Test 2u-DNF HITS3 at step 99999: 0.853601
+[..]
+```
+
+up queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks up --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm min --cqd-sigmoid --cqd-k 16 --cuda
+[..]
+Test up-DNF HITS3 at step 99999: 0.709496
+[..]
+```
+
+### 2.2 -- FB15k-237
+
+1p queries:
+
+```bash
+$ PYTHONPATH=. python3 PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 1p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cuda
+[..]
+Test 1p HITS3 at step 99999: 0.511910
+[..]
+```
+
+2p queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-k 64 --cuda
+[..]
+Test 2p HITS3 at step 99999: 0.286640
+[..]
+```
+
+3p queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 3p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-sigmoid --cqd-k 4 --cuda
+[..]
+Test 3p HITS3 at step 99999: 0.199947
+[..]
+```
+
+2i queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2i --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 16 --cuda
+[..]
+Test 2i HITS3 at step 99999: 0.376709
+[..]
+```
+
+3i queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 3i --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 16 --cuda
+[..]
+Test 3i HITS3 at step 99999: 0.488725
+[..]
+```
+
+ip queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks ip --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-k 16 --cuda
+[..]
+Test ip HITS3 at step 99999: 0.182000
+[..]
+```
+
+pi queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks pi --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 64 --cuda
+[..]
+Test pi HITS3 at step 99999: 0.267872
+[..]
+```
+
+2u queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2u --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm min --cqd-normalize --cqd-k 16 --cuda
+[..]
+Test 2u-DNF HITS3 at step 99999: 0.323751
+[..]
+```
+
+up queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks up --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-sigmoid --cqd-k 16 --cuda
+[..]
+Test up-DNF HITS3 at step 99999: 0.225360
+[..]
+```
+
+### 2.3 -- NELL 995
+
+1p queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 1p --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cuda
+[..]
+Test 1p HITS3 at step 99999: 0.663197
+[..]
+```
+
+2p queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2p --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-k 64 --cuda
+[..]
+Test 2p HITS3 at step 99999: 0.351218
+[..]
+```
+
+3p queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --valid_steps 20 --tasks 3p --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-sigmoid --cqd-k 2 --cuda
+[..]
+Test 3p HITS3 at step 99999: 0.263724
+[..]
+```
+
+2i queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2i --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 16 --cuda
+[..]
+Test 2i HITS3 at step 99999: 0.422821
+[..]
+```
+
+3i queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 3i --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 16 --cuda
+[..]
+Test 3i HITS3 at step 99999: 0.538633
+[..]
+```
+
+ip queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks ip --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-k 16 --cuda
+[..]
+Test ip HITS3 at step 99999: 0.234066
+[..]
+```
+
+pi queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks pi --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 64 --cuda
+[..]
+Test pi HITS3 at step 99999: 0.315222
+[..]
+```
+
+2u queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2u --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm min --cqd-normalize --cqd-k 16 --cuda
+[..]
+Test 2u-DNF HITS3 at step 99999: 0.541287
+[..]
+```
+
+up queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_valid --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks up --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm min --cqd-sigmoid --cqd-k 16 --cuda
+[..]
+Test up-DNF HITS3 at step 99999: 0.290282
+[..]
+```
From fba0ff4121aa225e2929793f434690846cbae22d Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sat, 26 Mar 2022 22:07:15 +0100
Subject: [PATCH 23/25] Fixing issue #2
---
cqd/base.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/cqd/base.py b/cqd/base.py
index 6f2ac23..2120395 100644
--- a/cqd/base.py
+++ b/cqd/base.py
@@ -56,7 +56,7 @@ def __init__(self,
self.query_name_dict = query_name_dict
sizes = (nentity, nrelation)
- self.embeddings = nn.ModuleList([nn.Embedding(s, 2 * rank, sparse=True) for s in sizes[:2]])
+ self.embeddings = nn.ModuleList([nn.Embedding(s, 2 * rank, sparse=False) for s in sizes[:2]])
self.embeddings[0].weight.data *= init_size
self.embeddings[1].weight.data *= init_size
From 71ec9a52b550407cae9153a293fcf2f775173f88 Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sun, 27 Mar 2022 22:03:34 +0200
Subject: [PATCH 24/25] update
---
CQD.md | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/CQD.md b/CQD.md
index afa6d34..3abb9bd 100644
--- a/CQD.md
+++ b/CQD.md
@@ -19,7 +19,7 @@ Our code is based on an implementation of ComplEx-N3 available [here](https://gi
## 1. Download the pre-trained models
-To download and decompress the pre-trained models, execute the folloing commands:
+To download and decompress the pre-trained models, execute the following commands:
```bash
$ mkdir models/
@@ -27,6 +27,14 @@ $ for i in "fb15k" "fb15k-237" "nell"; do for j in "betae" "q2b"; do wget -c htt
$ for i in *.tar.gz; do tar xvfz $i; done
```
+In case you need to re-train the models from scratch, use the following command lines:
+
+```bash
+PYTHONPATH=. python3 main.py --do_train --do_valid --do_test --data_path data/FB15k-237-q2b -n 1 -b 2000 -d 1000 -lr 0.1 --warm_up_steps 100000000 --max_steps 100000 --cpu_num 0 --geo cqd --valid_steps 500 --tasks 1p --print_on_screen --test_batch_size 1000 --optimizer adagrad --reg_weight 0.1 --log_steps 500 --cuda --use-qa-iterator
+PYTHONPATH=. python3 main.py --do_train --do_valid --do_test --data_path data/FB15k-q2b -n 1 -b 5000 -d 1000 -lr 0.1 --warm_up_steps 100000000 --max_steps 100000 --cpu_num 0 --geo cqd --valid_steps 500 --tasks 1p --print_on_screen --test_batch_size 1000 --optimizer adagrad --reg_weight 0.01 --log_steps 500 --cuda --use-qa-iterator
+PYTHONPATH=. python3 main.py --do_train --do_valid --do_test --data_path data/NELL-q2b -n 1 -b 2000 -d 1000 -lr 0.1 --warm_up_steps 100000000 --max_steps 100000 --cpu_num 0 --geo cqd --valid_steps 500 --tasks 1p --print_on_screen --test_batch_size 1000 --optimizer adagrad --reg_weight 0.1 --log_steps 500 --cuda --use-qa-iterator
+```
+
## 2. Answer the complex queries
One catch is that the query answering process in CQD depends on some hyperparameters, i.e. the "beam size" `k`, the t-norm to use (e.g. `min` or `prod`), and the normalisation function that maps scores to the `[0, 1]` interval; in our experiments, we select these on the validation set. Here are the commands to execute to evaluate CQD on each type of queries:
From 81060a2af7153b0b910704397cff8ea7e71b713a Mon Sep 17 00:00:00 2001
From: pminervini
Date: Sun, 27 Mar 2022 22:05:15 +0200
Subject: [PATCH 25/25] update
---
CQD.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/CQD.md b/CQD.md
index 3abb9bd..7177f07 100644
--- a/CQD.md
+++ b/CQD.md
@@ -129,7 +129,7 @@ Test up-DNF HITS3 at step 99999: 0.709496
1p queries:
```bash
-$ PYTHONPATH=. python3 PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 1p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cuda
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 1p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cuda
[..]
Test 1p HITS3 at step 99999: 0.511910
[..]