From c2ef1dd279718b9c56ccc3f540a4dc9c0fca7fcc Mon Sep 17 00:00:00 2001 From: pminervini Date: Fri, 4 Jun 2021 15:57:23 +0200 Subject: [PATCH 01/25] CQD --- cqd/__init__.py | 3 + cqd/base.py | 416 +++++++++++++++++++++++++++++++++++++++ cqd/dataloader.py | 75 +++++++ cqd/discrete.py | 492 ++++++++++++++++++++++++++++++++++++++++++++++ cqd/util.py | 383 ++++++++++++++++++++++++++++++++++++ main.py | 100 +++++++--- 6 files changed, 1442 insertions(+), 27 deletions(-) create mode 100644 cqd/__init__.py create mode 100644 cqd/base.py create mode 100644 cqd/dataloader.py create mode 100644 cqd/discrete.py create mode 100644 cqd/util.py diff --git a/cqd/__init__.py b/cqd/__init__.py new file mode 100644 index 0000000..e50c1db --- /dev/null +++ b/cqd/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +from cqd.base import CQD diff --git a/cqd/base.py b/cqd/base.py new file mode 100644 index 0000000..cce7b83 --- /dev/null +++ b/cqd/base.py @@ -0,0 +1,416 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +from torch import optim, Tensor +import math + +from cqd.util import query_to_atoms, create_instructions, top_k_selection +import cqd.discrete as d2 + +from typing import Tuple, List, Optional, Dict + + +class N3: + def __init__(self, weight: float): + self.weight = weight + + def forward(self, factors): + norm = 0 + for f in factors: + norm += self.weight * torch.sum(torch.abs(f) ** 3) + return norm / factors[0].shape[0] + + +class CQD(nn.Module): + MIN_NORM = 'min' + PROD_NORM = 'prod' + NORMS = {MIN_NORM, PROD_NORM} + + def __init__(self, + nentity: int, + nrelation: int, + rank: int, + init_size: float = 1e-3, + reg_weight: float = 1e-2, + test_batch_size: int = 1, + method: str = 'beam', + t_norm_name: str = 'prod', + k: int = 5, + query_name_dict: Optional[Dict] = None, + do_sigmoid: bool = False, + do_normalize: bool = False): + 
super(CQD, self).__init__() + + self.rank = rank + self.nentity = nentity + self.nrelation = nrelation + self.method = method + self.t_norm_name = t_norm_name + self.k = k + self.query_name_dict = query_name_dict + + sizes = (nentity, nrelation) + self.embeddings = nn.ModuleList([nn.Embedding(s, 2 * rank, sparse=True) for s in sizes[:2]]) + self.embeddings[0].weight.data *= init_size + self.embeddings[1].weight.data *= init_size + + self.init_size = init_size + self.loss_fn = nn.CrossEntropyLoss(reduction='mean') + self.regularizer = N3(reg_weight) + + self.do_sigmoid = do_sigmoid + self.do_normalize = do_normalize + + # XXX: get rid of this hack + test_batch_size = 1000 + batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1) + self.register_buffer('batch_entity_range', batch_entity_range) + + def split(self, + lhs_emb: Tensor, + rel_emb: Tensor, + rhs_emb: Tensor) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]: + lhs = lhs_emb[..., :self.rank], lhs_emb[..., self.rank:] + rel = rel_emb[..., :self.rank], rel_emb[..., self.rank:] + rhs = rhs_emb[..., :self.rank], rhs_emb[..., self.rank:] + return lhs, rel, rhs + + def loss(self, + triples: Tensor) -> Tensor: + (scores_o, scores_s), factors = self.score_candidates(triples) + l_fit = self.loss_fn(scores_o, triples[:, 2]) + self.loss_fn(scores_s, triples[:, 0]) + l_reg = self.regularizer.forward(factors) + return l_fit + l_reg + + def score_candidates(self, + triples: Tensor) -> Tuple[Tuple[Tensor, Tensor], Optional[List[Tensor]]]: + lhs_emb = self.embeddings[0](triples[:, 0]) + rel_emb = self.embeddings[1](triples[:, 1]) + rhs_emb = self.embeddings[0](triples[:, 2]) + to_score = self.embeddings[0].weight + scores_o, _ = self.score_o(lhs_emb, rel_emb, to_score) + scores_s, _ = self.score_s(to_score, rel_emb, rhs_emb) + lhs, rel, rhs = self.split(lhs_emb, rel_emb, rhs_emb) + factors = self.get_factors(lhs, rel, rhs) + return (scores_o, scores_s), factors + + 
def score_o(self, + lhs_emb: Tensor, + rel_emb: Tensor, + rhs_emb: Tensor, + return_factors: bool = False) -> Tuple[Tensor, Optional[List[Tensor]]]: + lhs, rel, rhs = self.split(lhs_emb, rel_emb, rhs_emb) + score_1 = (lhs[0] * rel[0] - lhs[1] * rel[1]) @ rhs[0].transpose(-1, -2) + score_2 = (lhs[1] * rel[0] + lhs[0] * rel[1]) @ rhs[1].transpose(-1, -2) + factors = self.get_factors(lhs, rel, rhs) if return_factors else None + return score_1 + score_2, factors + + def score_s(self, + lhs_emb: Tensor, + rel_emb: Tensor, + rhs_emb: Tensor, + return_factors: bool = False) -> Tuple[Tensor, Optional[List[Tensor]]]: + lhs, rel, rhs = self.split(lhs_emb, rel_emb, rhs_emb) + score_1 = (rhs[0] * rel[0] + rhs[1] * rel[1]) @ lhs[0].transpose(-1, -2) + score_2 = (rhs[1] * rel[0] - rhs[0] * rel[1]) @ lhs[1].transpose(-1, -2) + factors = self.get_factors(lhs, rel, rhs) if return_factors else None + return score_1 + score_2, factors + + def get_factors(self, + lhs: Tuple[Tensor, Tensor], + rel: Tuple[Tensor, Tensor], + rhs: Tuple[Tensor, Tensor]) -> List[Tensor]: + factors = [] + for term in (lhs, rel, rhs): + factors.append(torch.sqrt(term[0] ** 2 + term[1] ** 2)) + return factors + + def get_full_embeddings(self, queries: Tensor) \ + -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]: + lhs = rel = rhs = None + if torch.sum(queries[:, 0]).item() > 0: + lhs = self.embeddings[0](queries[:, 0]) + if torch.sum(queries[:, 1]).item() > 0: + rel = self.embeddings[1](queries[:, 1]) + if torch.sum(queries[:, 2]).item() > 0: + rhs = self.embeddings[0](queries[:, 2]) + return lhs, rel, rhs + + def batch_t_norm(self, scores: Tensor) -> Tensor: + if self.t_norm_name == CQD.MIN_NORM: + scores = torch.min(scores, dim=1)[0] + elif self.t_norm_name == CQD.PROD_NORM: + scores = torch.prod(scores, dim=1) + else: + raise ValueError(f't_norm must be one of {CQD.NORMS}, got {self.t_norm_name}') + + return scores + + def batch_t_conorm(self, scores: Tensor) -> Tensor: + if self.t_norm_name 
== CQD.MIN_NORM: + scores = torch.max(scores, dim=1, keepdim=True)[0] + elif self.t_norm_name == CQD.PROD_NORM: + scores = torch.sum(scores, dim=1, keepdim=True) - torch.prod(scores, dim=1, keepdim=True) + else: + raise ValueError(f't_norm must be one of {CQD.NORMS}, got {self.t_norm_name}') + + return scores + + def reduce_query_score(self, atom_scores, conjunction_mask, negation_mask): + batch_size, num_atoms, *extra_dims = atom_scores.shape + + atom_scores = torch.sigmoid(atom_scores) + scores = atom_scores.clone() + scores[negation_mask] = 1 - atom_scores[negation_mask] + + disjunctions = scores[~conjunction_mask].reshape(batch_size, -1, *extra_dims) + conjunctions = scores[conjunction_mask].reshape(batch_size, -1, *extra_dims) + + if disjunctions.shape[1] > 0: + disjunctions = self.batch_t_conorm(disjunctions) + + conjunctions = torch.cat([disjunctions, conjunctions], dim=1) + + t_norm = self.batch_t_norm(conjunctions) + return t_norm + + def forward(self, + positive_sample, + negative_sample, + subsampling_weight, + batch_queries_dict: Dict[Tuple, Tensor], + batch_idxs_dict): + all_idxs = [] + all_scores = [] + + scores = None + + for query_structure, queries in batch_queries_dict.items(): + batch_size = queries.shape[0] + atoms, num_variables, conjunction_mask, negation_mask = query_to_atoms(query_structure, queries) + + all_idxs.extend(batch_idxs_dict[query_structure]) + + # [False, True] + target_mask = torch.sum(atoms == -num_variables, dim=-1) > 0 + + # Offsets identify variables across different batches + var_id_offsets = torch.arange(batch_size, device=atoms.device) * num_variables + var_id_offsets = var_id_offsets.reshape(-1, 1, 1) + + # Replace negative variable IDs with valid identifiers + vars_mask = atoms < 0 + atoms_offset_vars = -atoms + var_id_offsets + + atoms[vars_mask] = atoms_offset_vars[vars_mask] + + head, rel, tail = atoms[..., 0], atoms[..., 1], atoms[..., 2] + head_vars_mask = vars_mask[..., 0] + + with torch.no_grad(): + 
h_emb_constants = self.embeddings[0](head) + r_emb = self.embeddings[1](rel) + + if 'co' in self.method: + h_emb = h_emb_constants + if num_variables > 1: + # var embedding for ID 0 is unused for ease of implementation + var_embs = nn.Embedding((num_variables * batch_size) + 1, self.rank * 2) + var_embs.weight.data *= self.init_size + + var_embs.to(atoms.device) + optimizer = optim.Adam(var_embs.parameters(), lr=0.1) + prev_loss_value = 1000 + loss_value = 999 + i = 0 + + # CQD-CO optimization loop + while i < 1000 and math.fabs(prev_loss_value - loss_value) > 1e-9: + prev_loss_value = loss_value + + h_emb = h_emb_constants.clone() + # Fill variable positions with optimizable embeddings + h_emb[head_vars_mask] = var_embs(head[head_vars_mask]) + + t_emb = var_embs(tail) + scores, factors = self.score_o(h_emb.unsqueeze(-2), + r_emb.unsqueeze(-2), + t_emb.unsqueeze(-2), + return_factors=True) + + query_score = self.reduce_query_score(scores, + conjunction_mask, + negation_mask) + + loss = - query_score.mean() + self.regularizer.forward(factors) + loss_value = loss.item() + + optimizer.zero_grad() + loss.backward() + optimizer.step() + i += 1 + + with torch.no_grad(): + # Select predicates involving target variable only + conjunction_mask = conjunction_mask[target_mask].reshape(batch_size, -1) + negation_mask = negation_mask[target_mask].reshape(batch_size, -1) + + target_mask = target_mask.unsqueeze(-1).expand_as(h_emb) + emb_size = h_emb.shape[-1] + h_emb = h_emb[target_mask].reshape(batch_size, -1, emb_size) + r_emb = r_emb[target_mask].reshape(batch_size, -1, emb_size) + to_score = self.embeddings[0].weight + + scores, factors = self.score_o(h_emb, r_emb, to_score) + query_score = self.reduce_query_score(scores, + conjunction_mask, + negation_mask) + all_scores.append(query_score) + + scores = torch.cat(all_scores, dim=0) + + elif 'continuous' in self.method: + graph_type = self.query_name_dict[query_structure] + + chain_instructions = create_instructions(atoms[0]) 
+ chains = [] + + for atom in range(len(atoms[0])): + part = atoms[:, atom, :] + chain = self.get_full_embeddings(part) + chains.append(chain) + + scores = top_k_selection(chains, + chain_instructions, + graph_type, + # score_o takes lhs, rel, rhs + scoring_function=lambda rel_, lhs_, rhs_: self.score_o(lhs_, rel_, rhs_)[0], + forward_emb=lambda lhs_, rel_: self.score_o(lhs_, rel_, self.embeddings[0].weight)[0], + entity_embeddings=self.embeddings[0], + candidates=self.k, + t_norm=self.t_norm_name, + batch_size=1, + scores_normalize='default') + + elif 'discrete' in self.method: + graph_type = self.query_name_dict[query_structure] + + def t_norm(a: Tensor, b: Tensor) -> Tensor: + return torch.minimum(a, b) + + def t_conorm(a: Tensor, b: Tensor) -> Tensor: + return torch.maximum(a, b) + + def negation(a: Tensor) -> Tensor: + return 1.0 - a + + if self.t_norm_name == CQD.PROD_NORM: + def t_norm(a: Tensor, b: Tensor) -> Tensor: + return a * b + + def t_conorm(a: Tensor, b: Tensor) -> Tensor: + return 1 - ((1 - a) * (1 - b)) + + def normalize(scores_: Tensor) -> Tensor: + scores_ = scores_ - scores_.min(1, keepdim=True)[0] + scores_ = scores_ / scores_.max(1, keepdim=True)[0] + return scores_ + + def scoring_function(rel_: Tensor, lhs_: Tensor, rhs_: Tensor) -> Tensor: + res, _ = self.score_o(lhs_, rel_, rhs_) + if self.do_sigmoid is True: + res = torch.sigmoid(res) + if self.do_normalize is True: + res = normalize(res) + return res + + if graph_type == "1p": + scores = d2.query_1p(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function) + elif graph_type == "2p": + scores = d2.query_2p(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, + k=self.k, t_norm=t_norm) + elif graph_type == "3p": + scores = d2.query_3p(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + 
queries=queries, + scoring_function=scoring_function, + k=self.k, t_norm=t_norm) + elif graph_type == "2i": + scores = d2.query_2i(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, t_norm=t_norm) + elif graph_type == "3i": + scores = d2.query_3i(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, t_norm=t_norm) + elif graph_type == "pi": + scores = d2.query_pi(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, + k=self.k, t_norm=t_norm) + elif graph_type == "ip": + scores = d2.query_ip(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, + t_norm=t_norm) + elif graph_type == "2u-DNF": + scores = d2.query_2u_dnf(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, + t_conorm=t_conorm) + elif graph_type == "up-DNF": + scores = d2.query_up_dnf(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, + t_norm=t_norm, t_conorm=t_conorm) + elif graph_type == "2in": + scores = d2.query_2in(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, + t_norm=t_norm, negation=negation) + elif graph_type == "3in": + scores = d2.query_3in(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, + t_norm=t_norm, negation=negation) + elif graph_type == "pin": + scores = d2.query_pin(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, + 
k=self.k, t_norm=t_norm, negation=negation) + elif graph_type == "pni": + scores = d2.query_pni_v2(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, + k=self.k, t_norm=t_norm, negation=negation) + elif graph_type == "inp": + scores = d2.query_inp(entity_embeddings=self.embeddings[0], + predicate_embeddings=self.embeddings[1], + queries=queries, + scoring_function=scoring_function, + t_norm=t_norm, negation=negation) + else: + raise ValueError(f'Unknown query type: {graph_type}') + + return None, scores, None, all_idxs diff --git a/cqd/dataloader.py b/cqd/dataloader.py new file mode 100644 index 0000000..b7cd80e --- /dev/null +++ b/cqd/dataloader.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import torch + +from torch.utils.data import Dataset +from util import flatten + + +class CQDTrainDataset(Dataset): + def __init__(self, queries, nentity, nrelation, negative_sample_size, answer): + # queries is a list of (query, query_structure) pairs + self.len = len(queries) + self.queries = queries + self.nentity = nentity + self.nrelation = nrelation + self.negative_sample_size = negative_sample_size + self.answer = answer + + self.qa_lst = [] + for q, qs in queries: + for a in self.answer[q]: + qa_entry = (qs, q, a) + self.qa_lst += [qa_entry] + + self.qa_len = len(self.qa_lst) + + def __len__(self): + # return self.len + return self.qa_len + + def __getitem__(self, idx): + # query = self.queries[idx][0] + query = self.qa_lst[idx][1] + + # query_structure = self.queries[idx][1] + query_structure = self.qa_lst[idx][0] + + # tail = np.random.choice(list(self.answer[query])) + tail = self.qa_lst[idx][2] + + # subsampling_weight = self.count[query] + # subsampling_weight = torch.sqrt(1 / torch.Tensor([subsampling_weight])) + subsampling_weight = 
# --- file: cqd/discrete.py -------------------------------------------------
# -*- coding: utf-8 -*-
#
# Discrete (beam-search style) inference for complex queries: each query
# type ('1p', '2p', '2i', ...) has a dedicated routine that composes atom
# scores with fuzzy-logic connectives (t_norm / t_conorm / negation).

import torch
from torch import nn, Tensor

from typing import Callable, Tuple, Optional


def score_candidates(s_emb: Tensor,
                     p_emb: Tensor,
                     candidates_emb: Tensor,
                     k: Optional[int],
                     entity_embeddings: nn.Module,
                     scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
    """Score every candidate for each (subject, predicate) pair.

    Subject and predicate batches may differ in size; the smaller one is
    tiled up to the larger. If `k` is given, also return the embeddings of
    the top-k candidates per row; otherwise the second result is None.
    """
    batch_size = max(s_emb.shape[0], p_emb.shape[0])
    embedding_size = s_emb.shape[1]

    def tile_to_batch(emb: Tensor) -> Tensor:
        # Repeat each row so both inputs reach `batch_size` rows,
        # preserving the interleaved ordering of the original code.
        if emb.shape[0] < batch_size:
            copies = batch_size // emb.shape[0]
            emb = emb.reshape(-1, 1, embedding_size).repeat(1, copies, 1).reshape(-1, embedding_size)
        return emb

    s_emb, p_emb = tile_to_batch(s_emb), tile_to_batch(p_emb)
    nb_entities = candidates_emb.shape[0]

    x_k_emb_3d = None

    # [B, N] scores over all candidates
    atom_scores_2d = scoring_function(s_emb, p_emb, candidates_emb)
    atom_k_scores_2d = atom_scores_2d

    if k is not None:
        top_k = min(k, nb_entities)
        # [B, K] scores and indices of the best candidates
        atom_k_scores_2d, atom_k_indices = torch.topk(atom_scores_2d, k=top_k, dim=1)
        # [B, K, E] embeddings of those candidates
        x_k_emb_3d = entity_embeddings(atom_k_indices)

    return atom_k_scores_2d, x_k_emb_3d


def query_1p(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor]) -> Tensor:
    """Single-hop query (s, p, ?): score every entity as the answer."""
    assert queries.shape[1] == 2

    s_emb = entity_embeddings(queries[:, 0])
    p_emb = predicate_embeddings(queries[:, 1])

    res, _ = score_candidates(s_emb=s_emb, p_emb=p_emb,
                              candidates_emb=entity_embeddings.weight, k=None,
                              entity_embeddings=entity_embeddings,
                              scoring_function=scoring_function)
    return res


def query_2p(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             k: int,
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """Two-hop path query: beam over top-k intermediates, then max-reduce."""
    s_emb = entity_embeddings(queries[:, 0])
    p1_emb = predicate_embeddings(queries[:, 1])
    p2_emb = predicate_embeddings(queries[:, 2])

    candidates_emb = entity_embeddings.weight
    nb_entities = candidates_emb.shape[0]
    batch_size, emb_size = s_emb.shape[0], s_emb.shape[1]

    # First hop: [B, K] scores and [B, K, E] intermediate embeddings.
    atom1_k_scores_2d, x1_k_emb_3d = score_candidates(s_emb=s_emb, p_emb=p1_emb,
                                                      candidates_emb=candidates_emb, k=k,
                                                      entity_embeddings=entity_embeddings,
                                                      scoring_function=scoring_function)

    # Second hop from each beam element: [B * K, N].
    atom2_scores_2d, _ = score_candidates(s_emb=x1_k_emb_3d.reshape(-1, emb_size), p_emb=p2_emb,
                                          candidates_emb=candidates_emb, k=None,
                                          entity_embeddings=entity_embeddings,
                                          scoring_function=scoring_function)

    # Broadcast hop-1 scores over all final candidates: [B, K, N].
    atom1_scores_3d = atom1_k_scores_2d.reshape(batch_size, -1, 1).repeat(1, 1, nb_entities)
    atom2_scores_3d = atom2_scores_2d.reshape(batch_size, -1, nb_entities)

    joined = t_norm(atom1_scores_3d, atom2_scores_3d)

    # Best beam per candidate: [B, N].
    res, _ = torch.max(joined, dim=1)
    return res


def query_2pn(entity_embeddings: nn.Module,
              predicate_embeddings: nn.Module,
              queries: Tensor,
              scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
              k: int,
              t_norm: Callable[[Tensor, Tensor], Tensor],
              negation: Callable[[Tensor], Tensor]) -> Tensor:
    """Like query_2p, but the second hop's scores are negated."""
    s_emb = entity_embeddings(queries[:, 0])
    p1_emb = predicate_embeddings(queries[:, 1])
    p2_emb = predicate_embeddings(queries[:, 2])

    candidates_emb = entity_embeddings.weight
    nb_entities = candidates_emb.shape[0]
    batch_size, emb_size = s_emb.shape[0], s_emb.shape[1]

    # [B, K], [B, K, E]
    atom1_k_scores_2d, x1_k_emb_3d = score_candidates(s_emb=s_emb, p_emb=p1_emb,
                                                      candidates_emb=candidates_emb, k=k,
                                                      entity_embeddings=entity_embeddings,
                                                      scoring_function=scoring_function)

    # [B * K, N]
    atom2_scores_2d, _ = score_candidates(s_emb=x1_k_emb_3d.reshape(-1, emb_size), p_emb=p2_emb,
                                          candidates_emb=candidates_emb, k=None,
                                          entity_embeddings=entity_embeddings,
                                          scoring_function=scoring_function)

    atom2_scores_2d = negation(atom2_scores_2d)

    atom1_scores_3d = atom1_k_scores_2d.reshape(batch_size, -1, 1).repeat(1, 1, nb_entities)
    atom2_scores_3d = atom2_scores_2d.reshape(batch_size, -1, nb_entities)

    joined = t_norm(atom1_scores_3d, atom2_scores_3d)
    res, _ = torch.max(joined, dim=1)
    return res


def query_3p(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             k: int,
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """Three-hop path query: two beam expansions (K then K*K), max-reduce."""
    s_emb = entity_embeddings(queries[:, 0])
    p1_emb = predicate_embeddings(queries[:, 1])
    p2_emb = predicate_embeddings(queries[:, 2])
    p3_emb = predicate_embeddings(queries[:, 3])

    candidates_emb = entity_embeddings.weight
    nb_entities = candidates_emb.shape[0]
    batch_size, emb_size = s_emb.shape[0], s_emb.shape[1]

    # Hop 1: [B, K], [B, K, E]
    atom1_k_scores_2d, x1_k_emb_3d = score_candidates(s_emb=s_emb, p_emb=p1_emb,
                                                      candidates_emb=candidates_emb, k=k,
                                                      entity_embeddings=entity_embeddings,
                                                      scoring_function=scoring_function)

    # Hop 2: [B * K, K], [B * K, K, E]
    atom2_k_scores_2d, x2_k_emb_3d = score_candidates(s_emb=x1_k_emb_3d.reshape(-1, emb_size), p_emb=p2_emb,
                                                      candidates_emb=candidates_emb, k=k,
                                                      entity_embeddings=entity_embeddings,
                                                      scoring_function=scoring_function)

    # Hop 3: [B * K * K, N]
    atom3_scores_2d, _ = score_candidates(s_emb=x2_k_emb_3d.reshape(-1, emb_size), p_emb=p3_emb,
                                          candidates_emb=candidates_emb, k=None,
                                          entity_embeddings=entity_embeddings,
                                          scoring_function=scoring_function)

    # Align all three score tensors to [B, K * K, N].
    atom1_scores_3d = atom1_k_scores_2d.reshape(batch_size, -1, 1).repeat(1, 1, nb_entities)
    atom2_scores_3d = atom2_k_scores_2d.reshape(batch_size, -1, 1).repeat(1, 1, nb_entities)
    atom3_scores_3d = atom3_scores_2d.reshape(batch_size, -1, nb_entities)

    atom1_scores_3d = atom1_scores_3d.repeat(1, atom3_scores_3d.shape[1] // atom1_scores_3d.shape[1], 1)

    joined = t_norm(atom1_scores_3d, atom2_scores_3d)
    joined = t_norm(joined, atom3_scores_3d)

    res, _ = torch.max(joined, dim=1)
    return res


def query_2i(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """Intersection of two 1p queries via the t-norm."""
    scores_1 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 0:2], scoring_function=scoring_function)
    scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 2:4], scoring_function=scoring_function)
    return t_norm(scores_1, scores_2)


def query_3i(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """Intersection of three 1p queries via the t-norm."""
    scores_1 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 0:2], scoring_function=scoring_function)
    scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 2:4], scoring_function=scoring_function)
    scores_3 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 4:6], scoring_function=scoring_function)
    return t_norm(t_norm(scores_1, scores_2), scores_3)


def query_ip(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """Intersection followed by projection: 2i scores hop through predicate 4."""
    # [B, N] — scores of the intermediate variable.
    scores_1 = query_2i(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 0:4], scoring_function=scoring_function, t_norm=t_norm)

    # [B, E]
    p_emb = predicate_embeddings(queries[:, 4])
    batch_size, emb_size = p_emb.shape[0], p_emb.shape[1]

    # [N, E]
    e_emb = entity_embeddings.weight
    nb_entities = e_emb.shape[0]

    # Every entity acts as a possible intermediate: [B * N, E].
    s_emb = e_emb.reshape(1, nb_entities, emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size)

    # [B * N, N]
    scores_2, _ = score_candidates(s_emb=s_emb, p_emb=p_emb, candidates_emb=e_emb, k=None,
                                   entity_embeddings=entity_embeddings, scoring_function=scoring_function)

    # Join over the intermediate axis, then take the best intermediate.
    scores_1 = scores_1.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities)
    scores_2 = scores_2.reshape(batch_size, nb_entities, nb_entities)

    joined = t_norm(scores_1, scores_2)
    res, _ = torch.max(joined, dim=1)
    return res


def query_pi(entity_embeddings: nn.Module,
             predicate_embeddings: nn.Module,
             queries: Tensor,
             scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
             k: int,
             t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """Projection then intersection: 2p branch joined with a 1p branch."""
    scores_1 = query_2p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 0:3], scoring_function=scoring_function, k=k, t_norm=t_norm)
    scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 3:5], scoring_function=scoring_function)
    return t_norm(scores_1, scores_2)


def query_2in(entity_embeddings: nn.Module,
              predicate_embeddings: nn.Module,
              queries: Tensor,
              scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
              t_norm: Callable[[Tensor, Tensor], Tensor],
              negation: Callable[[Tensor], Tensor]) -> Tensor:
    """Intersection of a 1p branch with a negated 1p branch."""
    scores_1 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 0:2], scoring_function=scoring_function)
    scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 2:4], scoring_function=scoring_function)
    return t_norm(scores_1, negation(scores_2))


def query_3in(entity_embeddings: nn.Module,
              predicate_embeddings: nn.Module,
              queries: Tensor,
              scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
              t_norm: Callable[[Tensor, Tensor], Tensor],
              negation: Callable[[Tensor], Tensor]) -> Tensor:
    """Intersection of two 1p branches with a negated third 1p branch."""
    scores_1 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 0:2], scoring_function=scoring_function)
    scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 2:4], scoring_function=scoring_function)
    scores_3 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 4:6], scoring_function=scoring_function)
    return t_norm(t_norm(scores_1, scores_2), negation(scores_3))


def query_inp(entity_embeddings: nn.Module,
              predicate_embeddings: nn.Module,
              queries: Tensor,
              scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
              t_norm: Callable[[Tensor, Tensor], Tensor],
              negation: Callable[[Tensor], Tensor]) -> Tensor:
    """2in followed by a projection through predicate at column 5."""
    # [B, N] — intermediate variable scores.
    scores_1 = query_2in(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                         queries=queries[:, 0:4], scoring_function=scoring_function,
                         t_norm=t_norm, negation=negation)

    # [B, E]
    p_emb = predicate_embeddings(queries[:, 5])
    batch_size, emb_size = p_emb.shape[0], p_emb.shape[1]

    # [N, E]
    e_emb = entity_embeddings.weight
    nb_entities = e_emb.shape[0]

    # [B * N, E] — all entities as intermediates.
    s_emb = e_emb.reshape(1, nb_entities, emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size)

    # [B * N, N]
    scores_2, _ = score_candidates(s_emb=s_emb, p_emb=p_emb, candidates_emb=e_emb, k=None,
                                   entity_embeddings=entity_embeddings, scoring_function=scoring_function)

    scores_1 = scores_1.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities)
    scores_2 = scores_2.reshape(batch_size, nb_entities, nb_entities)

    joined = t_norm(scores_1, scores_2)
    res, _ = torch.max(joined, dim=1)
    return res


def query_pin(entity_embeddings: nn.Module,
              predicate_embeddings: nn.Module,
              queries: Tensor,
              scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
              k: int,
              t_norm: Callable[[Tensor, Tensor], Tensor],
              negation: Callable[[Tensor], Tensor]) -> Tensor:
    """2p branch intersected with a negated 1p branch."""
    scores_1 = query_2p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 0:3], scoring_function=scoring_function, k=k, t_norm=t_norm)
    scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 3:5], scoring_function=scoring_function)
    return t_norm(scores_1, negation(scores_2))


def query_pni_v1(entity_embeddings: nn.Module,
                 predicate_embeddings: nn.Module,
                 queries: Tensor,
                 scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
                 k: int,
                 t_norm: Callable[[Tensor, Tensor], Tensor],
                 negation: Callable[[Tensor], Tensor]) -> Tensor:
    """pni variant 1: negate the whole 2p branch after projection."""
    scores_1 = query_2p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 0:3], scoring_function=scoring_function, k=k, t_norm=t_norm)
    scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 4:6], scoring_function=scoring_function)
    return t_norm(negation(scores_1), scores_2)


def query_pni_v2(entity_embeddings: nn.Module,
                 predicate_embeddings: nn.Module,
                 queries: Tensor,
                 scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
                 k: int,
                 t_norm: Callable[[Tensor, Tensor], Tensor],
                 negation: Callable[[Tensor], Tensor]) -> Tensor:
    """pni variant 2: negate the second hop inside the path (query_2pn)."""
    scores_1 = query_2pn(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                         queries=queries[:, 0:3], scoring_function=scoring_function, k=k,
                         t_norm=t_norm, negation=negation)
    scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 4:6], scoring_function=scoring_function)
    return t_norm(scores_1, scores_2)


def query_2u_dnf(entity_embeddings: nn.Module,
                 predicate_embeddings: nn.Module,
                 queries: Tensor,
                 scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
                 t_conorm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """Union (DNF) of two 1p queries via the t-conorm."""
    scores_1 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 0:2], scoring_function=scoring_function)
    scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                        queries=queries[:, 2:4], scoring_function=scoring_function)
    return t_conorm(scores_1, scores_2)


def query_up_dnf(entity_embeddings: nn.Module,
                 predicate_embeddings: nn.Module,
                 queries: Tensor,
                 scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor],
                 t_norm: Callable[[Tensor, Tensor], Tensor],
                 t_conorm: Callable[[Tensor, Tensor], Tensor]) -> Tensor:
    """Union then projection: 2u scores hop through predicate at column 5."""
    # [B, N] — union over the two branches.
    scores_1 = query_2u_dnf(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings,
                            queries=queries[:, 0:4], scoring_function=scoring_function, t_conorm=t_conorm)

    # [B, E]
    p_emb = predicate_embeddings(queries[:, 5])
    batch_size, emb_size = p_emb.shape[0], p_emb.shape[1]

    # [N, E]
    e_emb = entity_embeddings.weight
    nb_entities = e_emb.shape[0]

    # [B * N, E] — all entities as intermediates.
    s_emb = e_emb.reshape(1, nb_entities, emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size)

    # [B * N, N]
    scores_2, _ = score_candidates(s_emb=s_emb, p_emb=p_emb, candidates_emb=e_emb, k=None,
                                   entity_embeddings=entity_embeddings, scoring_function=scoring_function)

    scores_1 = scores_1.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities)
    scores_2 = scores_2.reshape(batch_size, nb_entities, nb_entities)

    joined = t_norm(scores_1, scores_2)
    res, _ = torch.max(joined, dim=1)
    return res


# --- file: cqd/util.py (header; definitions continue in the next section) --
# -*- coding: utf-8 -*-
# import torch; from torch import Tensor; import numpy as np
# from tqdm import tqdm; from typing import List, Tuple, Callable
import List, Tuple, Callable + + +def flatten_structure(query_structure): + if type(query_structure) == str: + return [query_structure] + + flat_structure = [] + for element in query_structure: + flat_structure.extend(flatten_structure(element)) + + return flat_structure + + +def query_to_atoms(query_structure, flat_ids): + flat_structure = flatten_structure(query_structure) + batch_size, query_length = flat_ids.shape + assert len(flat_structure) == query_length + + query_triples = [] + variable = 0 + previous = flat_ids[:, 0] + conjunction_mask = [] + negation_mask = [] + + for i in range(1, query_length): + if flat_structure[i] == 'r': + variable -= 1 + triples = torch.empty(batch_size, 3, + device=flat_ids.device, + dtype=torch.long) + triples[:, 0] = previous + triples[:, 1] = flat_ids[:, i] + triples[:, 2] = variable + + query_triples.append(triples) + previous = variable + conjunction_mask.append(True) + negation_mask.append(False) + elif flat_structure[i] == 'e': + previous = flat_ids[:, i] + variable += 1 + elif flat_structure[i] == 'u': + conjunction_mask = [False] * len(conjunction_mask) + elif flat_structure[i] == 'n': + negation_mask[-1] = True + + atoms = torch.stack(query_triples, dim=1) + num_variables = variable * -1 + conjunction_mask = torch.tensor(conjunction_mask).unsqueeze(0).expand(batch_size, -1) + negation_mask = torch.tensor(negation_mask).unsqueeze(0).expand(batch_size, -1) + + return atoms, num_variables, conjunction_mask, negation_mask + + +def create_instructions(chains): + instructions = [] + + prev_start = None + prev_end = None + + path_stack = [] + start_flag = True + for chain_ind, chain in enumerate(chains): + if start_flag: + prev_end = chain[-1] + start_flag = False + continue + + if prev_end == chain[0]: + instructions.append(f"hop_{chain_ind-1}_{chain_ind}") + prev_end = chain[-1] + prev_start = chain[0] + + elif prev_end == chain[-1]: + + prev_start = chain[0] + prev_end = chain[-1] + + 
instructions.append(f"intersect_{chain_ind-1}_{chain_ind}") + else: + path_stack.append(([prev_start, prev_end],chain_ind-1)) + prev_start = chain[0] + prev_end = chain[-1] + start_flag = False + continue + + if len(path_stack) > 0: + + path_prev_start = path_stack[-1][0][0] + path_prev_end = path_stack[-1][0][-1] + + if path_prev_end == chain[-1]: + + prev_start = chain[0] + prev_end = chain[-1] + + instructions.append(f"intersect_{path_stack[-1][1]}_{chain_ind}") + path_stack.pop() + continue + + ans = [] + for inst in instructions: + if ans: + + if 'inter' in inst and ('inter' in ans[-1]): + last_ind = inst.split("_")[-1] + ans[-1] = ans[-1]+f"_{last_ind}" + else: + ans.append(inst) + + else: + ans.append(inst) + + instructions = ans + return instructions + + +def t_norm_fn(tens_1: Tensor, tens_2: Tensor, t_norm: str = 'min') -> Tensor: + if 'min' in t_norm: + return torch.min(tens_1, tens_2) + elif 'prod' in t_norm: + return tens_1 * tens_2 + + +def t_conorm_fn(tens_1: Tensor, tens_2: Tensor, t_norm: str = 'min') -> Tensor: + if 'min' in t_norm: + return torch.max(tens_1, tens_2) + elif 'prod' in t_norm: + return (tens_1 + tens_2) - (tens_1 * tens_2) + + +def make_batches(size: int, batch_size: int) -> List[Tuple[int, int]]: + max_batch = int(np.ceil(size / float(batch_size))) + res = [(i * batch_size, min(size, (i + 1) * batch_size)) for i in range(0, max_batch)] + return res + + +def get_best_candidates(rel: Tensor, + arg1: Tensor, + forward_emb: Callable[[Tensor, Tensor], Tensor], + entity_embeddings: Callable[[Tensor], Tensor], + candidates: int = 5, + last_step: bool = False) -> Tuple[Tensor, Tensor]: + batch_size, embedding_size = rel.shape[0], rel.shape[1] + + # [B, N] + scores = forward_emb(arg1, rel) + + if not last_step: + # [B, K], [B, K] + k = min(candidates, scores.shape[1]) + z_scores, z_indices = torch.topk(scores, k=k, dim=1) + # [B, K, E] + z_emb = entity_embeddings(z_indices) + + # XXX: move before return + assert z_emb.shape[0] == batch_size 
+ assert z_emb.shape[2] == embedding_size + else: + z_scores = scores + + z_indices = torch.arange(z_scores.shape[1]).view(1, -1).repeat(z_scores.shape[0], 1).to(rel.device) + z_emb = entity_embeddings(z_indices) + + return z_scores, z_emb + + +def top_k_selection(chains, + chain_instructions, + graph_type, + scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor], + forward_emb: Callable[[Tensor, Tensor], Tensor], + entity_embeddings: Callable[[Tensor], Tensor], + candidates: int = 5, + t_norm: str = 'min', + batch_size: int = 1, + scores_normalize: str = 'default'): + res = None + + if 'disj' in graph_type: + objective = t_conorm_fn + else: + objective = t_norm_fn + + nb_queries, embedding_size = chains[0][0].shape[0], chains[0][0].shape[1] + + scores = None + + batches = make_batches(nb_queries, batch_size) + + for batch in tqdm(batches): + + nb_branches = 1 + nb_ent = 0 + batch_scores = None + candidate_cache = {} + + batch_size = batch[1] - batch[0] + dnf_flag = False + if 'disj' in graph_type: + dnf_flag = True + + for inst_ind, inst in enumerate(chain_instructions): + with torch.no_grad(): + if 'hop' in inst: + + ind_1 = int(inst.split("_")[-2]) + ind_2 = int(inst.split("_")[-1]) + + indices = [ind_1, ind_2] + + if objective == t_conorm_fn and dnf_flag: + objective = t_norm_fn + + last_hop = False + for hop_num, ind in enumerate(indices): + last_step = (inst_ind == len(chain_instructions) - 1) and last_hop + + lhs, rel, rhs = chains[ind] + + if lhs is not None: + lhs = lhs[batch[0]:batch[1]] + else: + # print("MTA BRAT") + batch_scores, lhs_3d = candidate_cache[f"lhs_{ind}"] + lhs = lhs_3d.view(-1, embedding_size) + + rel = rel[batch[0]:batch[1]] + rel = rel.view(-1, 1, embedding_size).repeat(1, nb_branches, 1) + rel = rel.view(-1, embedding_size) + + if f"rhs_{ind}" not in candidate_cache: + + # print("STTEEE MTA") + z_scores, rhs_3d = get_best_candidates(rel, lhs, forward_emb, entity_embeddings, candidates, last_step) + + # [Num_queries * 
Candidates^K] + z_scores_1d = z_scores.view(-1) + if 'disj' in graph_type or scores_normalize: + z_scores_1d = torch.sigmoid(z_scores_1d) + + # B * S + nb_sources = rhs_3d.shape[0] * rhs_3d.shape[1] + nb_branches = nb_sources // batch_size + if not last_step: + batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, candidates).view(-1), t_norm) + else: + nb_ent = rhs_3d.shape[1] + batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, nb_ent).view(-1), t_norm) + + candidate_cache[f"rhs_{ind}"] = (batch_scores, rhs_3d) + + if not last_hop: + candidate_cache[f"lhs_{indices[hop_num + 1]}"] = (batch_scores, rhs_3d) + + else: + batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"] + candidate_cache[f"lhs_{ind + 1}"] = (batch_scores, rhs_3d) + last_hop = True + continue + + last_hop = True + + elif 'inter' in inst: + ind_1 = int(inst.split("_")[-2]) + ind_2 = int(inst.split("_")[-1]) + + indices = [ind_1, ind_2] + + if objective == t_norm_fn and dnf_flag: + objective = t_conorm_fn + + if len(inst.split("_")) > 3: + ind_1 = int(inst.split("_")[-3]) + ind_2 = int(inst.split("_")[-2]) + ind_3 = int(inst.split("_")[-1]) + + indices = [ind_1, ind_2, ind_3] + + for intersection_num, ind in enumerate(indices): + last_step = (inst_ind == len(chain_instructions) - 1) # and ind == indices[0] + + lhs, rel, rhs = chains[ind] + + if lhs is not None: + lhs = lhs[batch[0]:batch[1]] + lhs = lhs.view(-1, 1, embedding_size).repeat(1, nb_branches, 1) + lhs = lhs.view(-1, embedding_size) + + else: + batch_scores, lhs_3d = candidate_cache[f"lhs_{ind}"] + lhs = lhs_3d.view(-1, embedding_size) + nb_sources = lhs_3d.shape[0] * lhs_3d.shape[1] + nb_branches = nb_sources // batch_size + + rel = rel[batch[0]:batch[1]] + rel = rel.view(-1, 1, embedding_size).repeat(1, nb_branches, 1) + rel = rel.view(-1, embedding_size) + + if intersection_num > 0 and 'disj' in graph_type: + 
batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"] + rhs = rhs_3d.view(-1, embedding_size) + z_scores = scoring_function(rel, lhs, rhs) + + z_scores_1d = z_scores.view(-1) + if 'disj' in graph_type or scores_normalize: + z_scores_1d = torch.sigmoid(z_scores_1d) + + batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores, t_norm) + + continue + + if f"rhs_{ind}" not in candidate_cache or last_step: + z_scores, rhs_3d = get_best_candidates(rel, lhs, forward_emb, entity_embeddings, candidates, last_step) + + # [B * Candidates^K] or [B, S-1, N] + z_scores_1d = z_scores.view(-1) + if 'disj' in graph_type or scores_normalize: + z_scores_1d = torch.sigmoid(z_scores_1d) + + if not last_step: + nb_sources = rhs_3d.shape[0] * rhs_3d.shape[1] + nb_branches = nb_sources // batch_size + + if not last_step: + batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, candidates).view(-1), t_norm) + else: + if ind == indices[0]: + nb_ent = rhs_3d.shape[1] + else: + nb_ent = 1 + + batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, nb_ent).view(-1), t_norm) + nb_ent = rhs_3d.shape[1] + + candidate_cache[f"rhs_{ind}"] = (batch_scores, rhs_3d) + + if ind == indices[0] and 'disj' in graph_type: + count = len(indices) - 1 + iterator = 1 + while count > 0: + candidate_cache[f"rhs_{indices[intersection_num + iterator]}"] = ( + batch_scores, rhs_3d) + iterator += 1 + count -= 1 + + if ind == indices[-1]: + candidate_cache[f"lhs_{ind + 1}"] = (batch_scores, rhs_3d) + else: + batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"] + candidate_cache[f"rhs_{ind + 1}"] = (batch_scores, rhs_3d) + + last_hop = True + del lhs, rel + continue + + del lhs, rel, rhs, rhs_3d, z_scores_1d, z_scores + + if batch_scores is not None: + # [B * entites * S ] + # S == K**(V-1) + scores_2d = batch_scores.view(batch_size, -1, nb_ent) + res, _ = 
torch.max(scores_2d, dim=1) + scores = res if scores is None else torch.cat([scores, res]) + + del batch_scores, scores_2d, res, candidate_cache + + else: + assert False, "Batch Scores are empty: an error went uncaught." + + res = scores + + return res diff --git a/main.py b/main.py index 6b76007..8067c70 100755 --- a/main.py +++ b/main.py @@ -8,19 +8,21 @@ import json import logging import os -import random import numpy as np import torch from torch.utils.data import DataLoader from models import KGReasoning +from cqd import CQD + from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator +from cqd.dataloader import CQDTrainDataset + from tensorboardX import SummaryWriter -import time + import pickle from collections import defaultdict -from tqdm import tqdm -from util import flatten_query, list2tuple, parse_time, set_global_seed, eval_tuple +from util import flatten_query, parse_time, set_global_seed, eval_tuple query_name_dict = {('e',('r',)): '1p', ('e', ('r', 'r')): '2p', @@ -42,6 +44,7 @@ name_query_dict = {value: key for key, value in query_name_dict.items()} all_tasks = list(name_query_dict.keys()) # ['1p', '2p', '3p', '2i', '3i', 'ip', 'pi', '2in', '3in', 'inp', 'pin', 'pni', '2u-DNF', '2u-DM', 'up-DNF', 'up-DM'] + def parse_args(args=None): parser = argparse.ArgumentParser( description='Training and Testing Knowledge Graph Embedding Models', @@ -74,8 +77,17 @@ def parse_args(args=None): parser.add_argument('--nentity', type=int, default=0, help='DO NOT MANUALLY SET') parser.add_argument('--nrelation', type=int, default=0, help='DO NOT MANUALLY SET') - parser.add_argument('--geo', default='vec', type=str, choices=['vec', 'box', 'beta'], help='the reasoning model, vec for GQE, box for Query2box, beta for BetaE') + parser.add_argument('--geo', default='vec', type=str, choices=['vec', 'box', 'beta', 'cqd'], help='the reasoning model, vec for GQE, box for Query2box, beta for BetaE, cqd for CQD') parser.add_argument('--print_on_screen', 
action='store_true') + + parser.add_argument('--reg_weight', default=1e-3, type=float) + parser.add_argument('--optimizer', choices=['adam', 'adagrad'], default='adam') + parser.add_argument('--cqd-type', '--cqd', default='co', type=str, choices=['continuous', 'discrete']) + parser.add_argument('--cqd-t-norm', default=CQD.PROD_NORM, type=str, choices=CQD.NORMS) + parser.add_argument('--cqd-k', default=5, type=int) + parser.add_argument('--cqd-sigmoid-scores', '--cqd-sigmoid', action='store_true', default=False) + parser.add_argument('--cqd-normalize-scores', '--cqd-normalize', action='store_true', default=False) + parser.add_argument('--use-qa-iterator', action='store_true', default=False) parser.add_argument('--tasks', default='1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up', type=str, help="tasks connected by dot, refer to the BetaE paper for detailed meaning and structure of each task") parser.add_argument('--seed', default=0, type=int, help="random seed") @@ -87,6 +99,7 @@ def parse_args(args=None): return parser.parse_args(args) + def save_model(model, optimizer, save_variable_list, args): ''' Save the parameters of the model and the optimizer, @@ -104,6 +117,7 @@ def save_model(model, optimizer, save_variable_list, args): os.path.join(args.save_path, 'checkpoint') ) + def set_logger(args): ''' Write logs to console and log file @@ -127,6 +141,7 @@ def set_logger(args): console.setFormatter(formatter) logging.getLogger('').addHandler(console) + def log_metrics(mode, step, metrics): ''' Print the evaluation logs @@ -134,6 +149,7 @@ def log_metrics(mode, step, metrics): for metric in metrics: logging.info('%s %s at step %d: %f' % (mode, metric, step, metrics[metric])) + def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer): ''' Evaluate queries in dataloader @@ -141,7 +157,7 @@ def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, m average_metrics = defaultdict(float) all_metrics = 
defaultdict(float) - metrics = model.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict) + metrics = KGReasoning.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict) num_query_structures = 0 num_queries = 0 for query_structure in metrics: @@ -161,7 +177,8 @@ def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, m log_metrics('%s average'%mode, step, average_metrics) return all_metrics - + + def load_data(args, tasks): ''' Load queries and remove queries not in tasks @@ -193,6 +210,7 @@ def load_data(args, tasks): return train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers + def main(args): set_global_seed(args.seed) tasks = args.tasks.split('.') @@ -263,21 +281,26 @@ def main(args): else: train_other_queries[query_structure] = train_queries[query_structure] train_path_queries = flatten_query(train_path_queries) + + TrainDatasetClass = TrainDataset + if args.use_qa_iterator is True: + TrainDatasetClass = CQDTrainDataset + train_path_iterator = SingledirectionalOneShotIterator(DataLoader( - TrainDataset(train_path_queries, nentity, nrelation, args.negative_sample_size, train_answers), + TrainDatasetClass(train_path_queries, nentity, nrelation, args.negative_sample_size, train_answers), batch_size=args.batch_size, shuffle=True, num_workers=args.cpu_num, - collate_fn=TrainDataset.collate_fn + collate_fn=TrainDatasetClass.collate_fn )) if len(train_other_queries) > 0: train_other_queries = flatten_query(train_other_queries) train_other_iterator = SingledirectionalOneShotIterator(DataLoader( - TrainDataset(train_other_queries, nentity, nrelation, args.negative_sample_size, train_answers), + TrainDatasetClass(train_other_queries, nentity, nrelation, args.negative_sample_size, train_answers), batch_size=args.batch_size, shuffle=True, num_workers=args.cpu_num, - collate_fn=TrainDataset.collate_fn + 
collate_fn=TrainDatasetClass.collate_fn )) else: train_other_iterator = None @@ -315,18 +338,39 @@ def main(args): collate_fn=TestDataset.collate_fn ) - model = KGReasoning( - nentity=nentity, - nrelation=nrelation, - hidden_dim=args.hidden_dim, - gamma=args.gamma, - geo=args.geo, - use_cuda = args.cuda, - box_mode=eval_tuple(args.box_mode), - beta_mode = eval_tuple(args.beta_mode), - test_batch_size=args.test_batch_size, - query_name_dict = query_name_dict - ) + if args.geo == 'cqd': + model = CQD(nentity, + nrelation, + rank=args.hidden_dim, + test_batch_size=args.test_batch_size, + reg_weight=args.reg_weight, + query_name_dict=query_name_dict, + method=args.cqd_type, + t_norm_name=args.cqd_t_norm, + k=args.cqd_k, + do_sigmoid=args.cqd_sigmoid_scores, + do_normalize=args.cqd_normalize_scores) + else: + model = KGReasoning( + nentity=nentity, + nrelation=nrelation, + hidden_dim=args.hidden_dim, + gamma=args.gamma, + geo=args.geo, + use_cuda = args.cuda, + box_mode=eval_tuple(args.box_mode), + beta_mode = eval_tuple(args.beta_mode), + test_batch_size=args.test_batch_size, + query_name_dict = query_name_dict + ) + + name_to_optimizer = { + 'adam': torch.optim.Adam, + 'adagrad': torch.optim.Adagrad + } + + assert args.optimizer in name_to_optimizer + OptimizerClass = name_to_optimizer[args.optimizer] logging.info('Model Parameter Configuration:') num_params = 0 @@ -341,7 +385,7 @@ def main(args): if args.do_train: current_learning_rate = args.learning_rate - optimizer = torch.optim.Adam( + optimizer = OptimizerClass( filter(lambda p: p.requires_grad, model.parameters()), lr=current_learning_rate ) @@ -349,7 +393,8 @@ def main(args): if args.checkpoint_path is not None: logging.info('Loading checkpoint %s...' 
% args.checkpoint_path) - checkpoint = torch.load(os.path.join(args.checkpoint_path, 'checkpoint')) + checkpoint = torch.load(os.path.join(args.checkpoint_path, 'checkpoint'), + map_location=torch.device('cpu') if not args.cuda else None) init_step = checkpoint['step'] model.load_state_dict(checkpoint['model_state_dict']) @@ -396,7 +441,7 @@ def main(args): if step >= warm_up_steps: current_learning_rate = current_learning_rate / 5 logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step)) - optimizer = torch.optim.Adam( + optimizer = OptimizerClass( filter(lambda p: p.requires_grad, model.parameters()), lr=current_learning_rate ) @@ -445,5 +490,6 @@ def main(args): logging.info("Training finished!!") + if __name__ == '__main__': - main(parse_args()) \ No newline at end of file + main(parse_args()) From ad5bcb2904c9327c72728cd03dc984ac89126747 Mon Sep 17 00:00:00 2001 From: pminervini Date: Fri, 4 Jun 2021 16:33:21 +0200 Subject: [PATCH 02/25] CQD --- cqd/base.py | 33 --------- cqd/discrete.py | 183 ------------------------------------------------ 2 files changed, 216 deletions(-) diff --git a/cqd/base.py b/cqd/base.py index cce7b83..4152531 100644 --- a/cqd/base.py +++ b/cqd/base.py @@ -306,9 +306,6 @@ def t_norm(a: Tensor, b: Tensor) -> Tensor: def t_conorm(a: Tensor, b: Tensor) -> Tensor: return torch.maximum(a, b) - def negation(a: Tensor) -> Tensor: - return 1.0 - a - if self.t_norm_name == CQD.PROD_NORM: def t_norm(a: Tensor, b: Tensor) -> Tensor: return a * b @@ -380,36 +377,6 @@ def scoring_function(rel_: Tensor, lhs_: Tensor, rhs_: Tensor) -> Tensor: queries=queries, scoring_function=scoring_function, t_norm=t_norm, t_conorm=t_conorm) - elif graph_type == "2in": - scores = d2.query_2in(entity_embeddings=self.embeddings[0], - predicate_embeddings=self.embeddings[1], - queries=queries, - scoring_function=scoring_function, - t_norm=t_norm, negation=negation) - elif graph_type == "3in": - scores = 
d2.query_3in(entity_embeddings=self.embeddings[0], - predicate_embeddings=self.embeddings[1], - queries=queries, - scoring_function=scoring_function, - t_norm=t_norm, negation=negation) - elif graph_type == "pin": - scores = d2.query_pin(entity_embeddings=self.embeddings[0], - predicate_embeddings=self.embeddings[1], - queries=queries, - scoring_function=scoring_function, - k=self.k, t_norm=t_norm, negation=negation) - elif graph_type == "pni": - scores = d2.query_pni_v2(entity_embeddings=self.embeddings[0], - predicate_embeddings=self.embeddings[1], - queries=queries, - scoring_function=scoring_function, - k=self.k, t_norm=t_norm, negation=negation) - elif graph_type == "inp": - scores = d2.query_inp(entity_embeddings=self.embeddings[0], - predicate_embeddings=self.embeddings[1], - queries=queries, - scoring_function=scoring_function, - t_norm=t_norm, negation=negation) else: raise ValueError(f'Unknown query type: {graph_type}') diff --git a/cqd/discrete.py b/cqd/discrete.py index f4d02c1..c4f9c88 100644 --- a/cqd/discrete.py +++ b/cqd/discrete.py @@ -106,53 +106,6 @@ def query_2p(entity_embeddings: nn.Module, return res -def query_2pn(entity_embeddings: nn.Module, - predicate_embeddings: nn.Module, - queries: Tensor, - scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor], - k: int, - t_norm: Callable[[Tensor, Tensor], Tensor], - negation: Callable[[Tensor], Tensor]) -> Tensor: - - s_emb = entity_embeddings(queries[:, 0]) - p1_emb = predicate_embeddings(queries[:, 1]) - p2_emb = predicate_embeddings(queries[:, 2]) - - candidates_emb = entity_embeddings.weight - nb_entities = candidates_emb.shape[0] - - batch_size = s_emb.shape[0] - emb_size = s_emb.shape[1] - - # [B, K], [B, K, E] - atom1_k_scores_2d, x1_k_emb_3d = score_candidates(s_emb=s_emb, p_emb=p1_emb, - candidates_emb=candidates_emb, k=k, - entity_embeddings=entity_embeddings, - scoring_function=scoring_function) - - # [B * K, E] - x1_k_emb_2d = x1_k_emb_3d.reshape(-1, emb_size) - - # [B * K, N] - 
atom2_scores_2d, _ = score_candidates(s_emb=x1_k_emb_2d, p_emb=p2_emb, - candidates_emb=candidates_emb, k=None, - entity_embeddings=entity_embeddings, - scoring_function=scoring_function) - - atom2_scores_2d = negation(atom2_scores_2d) - - # [B, K] -> [B, K, N] - atom1_scores_3d = atom1_k_scores_2d.reshape(batch_size, -1, 1).repeat(1, 1, nb_entities) - # [B * K, N] -> [B, K, N] - atom2_scores_3d = atom2_scores_2d.reshape(batch_size, -1, nb_entities) - - res = t_norm(atom1_scores_3d, atom2_scores_3d) - - # [B, K, N] -> [B, N] - res, _ = torch.max(res, dim=1) - return res - - def query_3p(entity_embeddings: nn.Module, predicate_embeddings: nn.Module, queries: Tensor, @@ -303,142 +256,6 @@ def query_pi(entity_embeddings: nn.Module, return res -def query_2in(entity_embeddings: nn.Module, - predicate_embeddings: nn.Module, - queries: Tensor, - scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor], - t_norm: Callable[[Tensor, Tensor], Tensor], - negation: Callable[[Tensor], Tensor]) -> Tensor: - - scores_1 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 0:2], scoring_function=scoring_function) - scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 2:4], scoring_function=scoring_function) - - scores_2 = negation(scores_2) - - res = t_norm(scores_1, scores_2) - - return res - - -def query_3in(entity_embeddings: nn.Module, - predicate_embeddings: nn.Module, - queries: Tensor, - scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor], - t_norm: Callable[[Tensor, Tensor], Tensor], - negation: Callable[[Tensor], Tensor]) -> Tensor: - - scores_1 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 0:2], scoring_function=scoring_function) - scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 2:4], 
scoring_function=scoring_function) - scores_3 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 4:6], scoring_function=scoring_function) - - scores_3 = negation(scores_3) - - res = t_norm(scores_1, scores_2) - res = t_norm(res, scores_3) - - return res - - -def query_inp(entity_embeddings: nn.Module, - predicate_embeddings: nn.Module, - queries: Tensor, - scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor], - t_norm: Callable[[Tensor, Tensor], Tensor], - negation: Callable[[Tensor], Tensor]) -> Tensor: - - # [B, N] - scores_1 = query_2in(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 0:4], scoring_function=scoring_function, t_norm=t_norm, negation=negation) - - # [B, E] - p_emb = predicate_embeddings(queries[:, 5]) - - batch_size = p_emb.shape[0] - emb_size = p_emb.shape[1] - - # [N, E] - e_emb = entity_embeddings.weight - nb_entities = e_emb.shape[0] - - # [B * N, E] - s_emb = e_emb.reshape(1, nb_entities, emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size) - - # [B * N, N] - scores_2, _ = score_candidates(s_emb=s_emb, p_emb=p_emb, candidates_emb=e_emb, k=None, - entity_embeddings=entity_embeddings, scoring_function=scoring_function) - - # [B, N, N] - scores_1 = scores_1.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities) - scores_2 = scores_2.reshape(batch_size, nb_entities, nb_entities) - - res = t_norm(scores_1, scores_2) - res, _ = torch.max(res, dim=1) - - return res - - -def query_pin(entity_embeddings: nn.Module, - predicate_embeddings: nn.Module, - queries: Tensor, - scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor], - k: int, - t_norm: Callable[[Tensor, Tensor], Tensor], - negation: Callable[[Tensor], Tensor]) -> Tensor: - - scores_1 = query_2p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 0:3], scoring_function=scoring_function, k=k, t_norm=t_norm) - 
scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 3:5], scoring_function=scoring_function) - - scores_2 = negation(scores_2) - - res = t_norm(scores_1, scores_2) - - return res - - -def query_pni_v1(entity_embeddings: nn.Module, - predicate_embeddings: nn.Module, - queries: Tensor, - scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor], - k: int, - t_norm: Callable[[Tensor, Tensor], Tensor], - negation: Callable[[Tensor], Tensor]) -> Tensor: - - scores_1 = query_2p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 0:3], scoring_function=scoring_function, k=k, t_norm=t_norm) - scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 4:6], scoring_function=scoring_function) - - scores_1 = negation(scores_1) - - res = t_norm(scores_1, scores_2) - return res - - -def query_pni_v2(entity_embeddings: nn.Module, - predicate_embeddings: nn.Module, - queries: Tensor, - scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor], - k: int, - t_norm: Callable[[Tensor, Tensor], Tensor], - negation: Callable[[Tensor], Tensor]) -> Tensor: - - scores_1 = query_2pn(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 0:3], scoring_function=scoring_function, k=k, - t_norm=t_norm, negation=negation) - scores_2 = query_1p(entity_embeddings=entity_embeddings, predicate_embeddings=predicate_embeddings, - queries=queries[:, 4:6], scoring_function=scoring_function) - - res = t_norm(scores_1, scores_2) - return res - - def query_2u_dnf(entity_embeddings: nn.Module, predicate_embeddings: nn.Module, queries: Tensor, From 09353be1cd452ad625aa5e4bc7a494f93e928b55 Mon Sep 17 00:00:00 2001 From: pminervini Date: Fri, 4 Jun 2021 16:56:42 +0200 Subject: [PATCH 03/25] update --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/README.md b/README.md index ee1a5e2..f77716e 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,8 @@ # KGReasoning -This repo contains several algorithms for multi-hop reasoning on knowledge graphs, including the official Pytorch implementation of [Beta Embeddings for Multi-Hop Logical Reasoning in Knowledge Graphs](https://arxiv.org/abs/2010.11465). +This repo contains several algorithms for multi-hop reasoning on knowledge graphs, including the official PyTorch implementation of [Beta Embeddings for Multi-Hop Logical Reasoning in Knowledge Graphs](https://arxiv.org/abs/2010.11465) and a PyTorch implementation of [Complex Query Answering with Neural Link Predictors](https://arxiv.org/abs/2011.03459). **Models** +- [x] [CQD](https://arxiv.org/abs/2011.03459) - [x] [BetaE](https://arxiv.org/abs/2010.11465) - [x] [Query2box](https://arxiv.org/abs/2002.05969) - [x] [GQE](https://arxiv.org/abs/1806.01445) From 472268a35a5ce44cac01f10df58d01ff18d6f761 Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 5 Jun 2021 17:15:03 +0200 Subject: [PATCH 04/25] update --- cqd/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cqd/base.py b/cqd/base.py index 4152531..d36811b 100644 --- a/cqd/base.py +++ b/cqd/base.py @@ -68,8 +68,8 @@ def __init__(self, # XXX: get rid of this hack test_batch_size = 1000 - batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1) - self.register_buffer('batch_entity_range', batch_entity_range) + self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1) + # self.register_buffer('batch_entity_range', batch_entity_range) def split(self, lhs_emb: Tensor, From 6486b657a7c2e6efc70a6951d97790759aa0f47d Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 5 Jun 2021 17:15:17 +0200 Subject: [PATCH 05/25] update --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index f77716e..5d3fa82 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 
@@ # KGReasoning + This repo contains several algorithms for multi-hop reasoning on knowledge graphs, including the official PyTorch implementation of [Beta Embeddings for Multi-Hop Logical Reasoning in Knowledge Graphs](https://arxiv.org/abs/2010.11465) and a PyTorch implementation of [Complex Query Answering with Neural Link Predictors](https://arxiv.org/abs/2011.03459). **Models** From 888f506a7c5833f29905289583f911cf630266c4 Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 5 Jun 2021 17:15:29 +0200 Subject: [PATCH 06/25] update --- CQD.md | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 CQD.md diff --git a/CQD.md b/CQD.md new file mode 100644 index 0000000..f1187dc --- /dev/null +++ b/CQD.md @@ -0,0 +1,31 @@ +# Continuous Query Decomposition + +This repository contains the official implementation for our ICLR 2021 (Oral, Outstanding Paper Award) paper, [**Complex Query Answering with Neural Link Predictors**](https://openreview.net/forum?id=Mos9F9kDwkz). + +```bibtex +@inproceedings{ + arakelyan2021complex, + title={Complex Query Answering with Neural Link Predictors}, + author={Erik Arakelyan and Daniel Daza and Pasquale Minervini and Michael Cochez}, + booktitle={International Conference on Learning Representations}, + year={2021}, + url={https://openreview.net/forum?id=Mos9F9kDwkz} +} +``` + +In this work we present CQD, a method that reuses a pretrained link predictor to answer complex queries, by scoring atom predicates independently and aggregating the scores via t-norms and t-conorms. + +Our code is based on an implementation of ComplEx-N3 available [here](https://github.com/facebookresearch/kbc). + +### 1. 
Download the pre-trained models + +```bash +$ mkdir models/ +$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-betae.tar.gz +$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-q2b.tar.gz +$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-237-betae.tar.gz +$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-237-q2b.tar.gz +$ wget -c http://data.neuralnoise.com/cqd-models/nell-betae.tar.gz +$ wget -c http://data.neuralnoise.com/cqd-models/nell-q2b.tar.gz +$ for z in *.tar.gz; do tar xvfz $z; done +``` From 244eab67aef70ea3820e53b8eb6161bf4d230780 Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 5 Jun 2021 17:18:59 +0200 Subject: [PATCH 07/25] update --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index 8067c70..f8d7204 100755 --- a/main.py +++ b/main.py @@ -389,7 +389,7 @@ def main(args): filter(lambda p: p.requires_grad, model.parameters()), lr=current_learning_rate ) - warm_up_steps = args.max_steps // 2 + warm_up_steps = args.max_steps // 2 if args.warm_up_steps is None else args.warm_up_steps if args.checkpoint_path is not None: logging.info('Loading checkpoint %s...' 
% args.checkpoint_path) From c586f3f0ed66090fef16e50f5c65282ce6e3d51c Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 5 Jun 2021 17:21:38 +0200 Subject: [PATCH 08/25] update --- main.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index f8d7204..108fbd1 100755 --- a/main.py +++ b/main.py @@ -226,19 +226,22 @@ def main(args): else: prefix = args.prefix - print ("overwritting args.save_path") - args.save_path = os.path.join(prefix, args.data_path.split('/')[-1], args.tasks, args.geo) - if args.geo in ['box']: - tmp_str = "g-{}-mode-{}".format(args.gamma, args.box_mode) - elif args.geo in ['vec']: - tmp_str = "g-{}".format(args.gamma) - elif args.geo == 'beta': - tmp_str = "g-{}-mode-{}".format(args.gamma, args.beta_mode) - - if args.checkpoint_path is not None: - args.save_path = args.checkpoint_path - else: - args.save_path = os.path.join(args.save_path, tmp_str, cur_time) + if args.save_path is None: + print("overwritting args.save_path") + args.save_path = os.path.join(prefix, args.data_path.split('/')[-1], args.tasks, args.geo) + if args.geo in ['box']: + tmp_str = "g-{}-mode-{}".format(args.gamma, args.box_mode) + elif args.geo in ['vec']: + tmp_str = "g-{}".format(args.gamma) + elif args.geo == 'beta': + tmp_str = "g-{}-mode-{}".format(args.gamma, args.beta_mode) + elif args.geo == 'cqd': + tmp_str = "g-cqd" + + if args.checkpoint_path is not None: + args.save_path = args.checkpoint_path + else: + args.save_path = os.path.join(args.save_path, tmp_str, cur_time) if not os.path.exists(args.save_path): os.makedirs(args.save_path) From b79a62e592015a39eaa444f47e60f149a5654c8b Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 5 Jun 2021 17:23:22 +0200 Subject: [PATCH 09/25] model.train_step -> KGReasoning.train_step (since it is static) --- main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 108fbd1..c129072 100755 --- a/main.py +++ 
b/main.py @@ -430,14 +430,14 @@ def main(args): if step == 2*args.max_steps//3: args.valid_steps *= 4 - log = model.train_step(model, optimizer, train_path_iterator, args, step) + log = KGReasoning.train_step(model, optimizer, train_path_iterator, args, step) for metric in log: writer.add_scalar('path_'+metric, log[metric], step) if train_other_iterator is not None: - log = model.train_step(model, optimizer, train_other_iterator, args, step) + log = KGReasoning.train_step(model, optimizer, train_other_iterator, args, step) for metric in log: writer.add_scalar('other_'+metric, log[metric], step) - log = model.train_step(model, optimizer, train_path_iterator, args, step) + log = KGReasoning.train_step(model, optimizer, train_path_iterator, args, step) training_logs.append(log) From 656a181ecbc1b796026f35f034232ab955a76e81 Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 5 Jun 2021 17:26:58 +0200 Subject: [PATCH 10/25] update --- models.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/models.py b/models.py index b761009..08ee797 100755 --- a/models.py +++ b/models.py @@ -5,20 +5,14 @@ from __future__ import print_function import logging -import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from torch.utils.data import DataLoader -from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator -import random -import pickle -import math import collections -import itertools -import time + +from cqd import CQD + from tqdm import tqdm -import os def Identity(x): return x @@ -584,16 +578,24 @@ def train_step(model, optimizer, train_iterator, args, step): negative_sample = negative_sample.cuda() subsampling_weight = subsampling_weight.cuda() - positive_logit, negative_logit, subsampling_weight, _ = model(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) + if isinstance(model, CQD): + input_batch = 
batch_queries_dict[('e', ('r',))] + input_batch = torch.cat((input_batch, positive_sample.unsqueeze(1)), dim=1) + loss = model.loss(input_batch) + + positive_sample_loss = negative_sample_loss = loss + else: + positive_logit, negative_logit, subsampling_weight, _ = model(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) + + negative_score = F.logsigmoid(-negative_logit).mean(dim=1) + positive_score = F.logsigmoid(positive_logit).squeeze(dim=1) + positive_sample_loss = - (subsampling_weight * positive_score).sum() + negative_sample_loss = - (subsampling_weight * negative_score).sum() + positive_sample_loss /= subsampling_weight.sum() + negative_sample_loss /= subsampling_weight.sum() - negative_score = F.logsigmoid(-negative_logit).mean(dim=1) - positive_score = F.logsigmoid(positive_logit).squeeze(dim=1) - positive_sample_loss = - (subsampling_weight * positive_score).sum() - negative_sample_loss = - (subsampling_weight * negative_score).sum() - positive_sample_loss /= subsampling_weight.sum() - negative_sample_loss /= subsampling_weight.sum() + loss = (positive_sample_loss + negative_sample_loss)/2 - loss = (positive_sample_loss + negative_sample_loss)/2 loss.backward() optimizer.step() log = { From febdf3a5256d472962b4c236b10f107277870a10 Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 5 Jun 2021 17:35:09 +0200 Subject: [PATCH 11/25] PEP8 --- main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index c129072..52921dd 100755 --- a/main.py +++ b/main.py @@ -360,11 +360,11 @@ def main(args): hidden_dim=args.hidden_dim, gamma=args.gamma, geo=args.geo, - use_cuda = args.cuda, + use_cuda=args.cuda, box_mode=eval_tuple(args.box_mode), - beta_mode = eval_tuple(args.beta_mode), + beta_mode=eval_tuple(args.beta_mode), test_batch_size=args.test_batch_size, - query_name_dict = query_name_dict + query_name_dict=query_name_dict ) name_to_optimizer = { From 
027d10eb69f0700374fafe13c570b7fe813203ec Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 5 Jun 2021 17:42:10 +0200 Subject: [PATCH 12/25] update --- main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 52921dd..c28a9c7 100755 --- a/main.py +++ b/main.py @@ -352,7 +352,8 @@ def main(args): t_norm_name=args.cqd_t_norm, k=args.cqd_k, do_sigmoid=args.cqd_sigmoid_scores, - do_normalize=args.cqd_normalize_scores) + do_normalize=args.cqd_normalize_scores, + use_cuda=args.cuda) else: model = KGReasoning( nentity=nentity, From d57f95b0123e486aa3c143a46065c2ddea7044c1 Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 5 Jun 2021 18:30:23 +0200 Subject: [PATCH 13/25] update --- cqd/base.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/cqd/base.py b/cqd/base.py index d36811b..a85d280 100644 --- a/cqd/base.py +++ b/cqd/base.py @@ -43,7 +43,8 @@ def __init__(self, k: int = 5, query_name_dict: Optional[Dict] = None, do_sigmoid: bool = False, - do_normalize: bool = False): + do_normalize: bool = False, + use_cuda: bool = False): super(CQD, self).__init__() self.rank = rank @@ -67,10 +68,16 @@ def __init__(self, self.do_normalize = do_normalize # XXX: get rid of this hack - test_batch_size = 1000 - self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1) + # test_batch_size = 1000 + # self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1) # self.register_buffer('batch_entity_range', batch_entity_range) + self.use_cuda = use_cuda + + self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1) + if self.use_cuda is True: + self.batch_entity_range = self.batch_entity_range.cuda() + def split(self, lhs_emb: Tensor, rel_emb: Tensor, From 03a2450b68ca3290ff2c3bf6c491124328461b86 Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 5 Jun 2021 18:51:43 +0200 Subject: [PATCH 14/25] cleanup --- 
cqd/base.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/cqd/base.py b/cqd/base.py index a85d280..4dfefb9 100644 --- a/cqd/base.py +++ b/cqd/base.py @@ -67,13 +67,7 @@ def __init__(self, self.do_sigmoid = do_sigmoid self.do_normalize = do_normalize - # XXX: get rid of this hack - # test_batch_size = 1000 - # self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1) - # self.register_buffer('batch_entity_range', batch_entity_range) - self.use_cuda = use_cuda - self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1) if self.use_cuda is True: self.batch_entity_range = self.batch_entity_range.cuda() From d52e12805015f9df08e5317038729475d9d904af Mon Sep 17 00:00:00 2001 From: pminervini Date: Sun, 6 Jun 2021 14:48:51 +0200 Subject: [PATCH 15/25] better defaults --- cqd/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cqd/base.py b/cqd/base.py index 4dfefb9..91b01e7 100644 --- a/cqd/base.py +++ b/cqd/base.py @@ -38,7 +38,7 @@ def __init__(self, init_size: float = 1e-3, reg_weight: float = 1e-2, test_batch_size: int = 1, - method: str = 'beam', + method: str = 'discrete', t_norm_name: str = 'prod', k: int = 5, query_name_dict: Optional[Dict] = None, From 4eeb63684d72f7060c64ff628349ce4f3791e555 Mon Sep 17 00:00:00 2001 From: pminervini Date: Sun, 6 Jun 2021 14:49:10 +0200 Subject: [PATCH 16/25] continuous optimisation requires gradients --- models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/models.py b/models.py index 08ee797..974b6ff 100755 --- a/models.py +++ b/models.py @@ -613,7 +613,10 @@ def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_na total_steps = len(test_dataloader) logs = collections.defaultdict(list) - with torch.no_grad(): + require_grad = isinstance(model, CQD) and model.method == 'continuous' + + # with torch.no_grad(): + with torch.set_grad_enabled(require_grad): for negative_sample, queries, 
queries_unflatten, query_structures in tqdm(test_dataloader, disable=not args.print_on_screen): batch_queries_dict = collections.defaultdict(list) batch_idxs_dict = collections.defaultdict(list) From 5a4590c5eb2bde4030038a0b8f74e2368b8cd54f Mon Sep 17 00:00:00 2001 From: pminervini Date: Sun, 6 Jun 2021 15:08:20 +0200 Subject: [PATCH 17/25] require -> requires --- models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models.py b/models.py index 974b6ff..07d4857 100755 --- a/models.py +++ b/models.py @@ -613,10 +613,10 @@ def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_na total_steps = len(test_dataloader) logs = collections.defaultdict(list) - require_grad = isinstance(model, CQD) and model.method == 'continuous' + requires_grad = isinstance(model, CQD) and model.method == 'continuous' # with torch.no_grad(): - with torch.set_grad_enabled(require_grad): + with torch.set_grad_enabled(requires_grad): for negative_sample, queries, queries_unflatten, query_structures in tqdm(test_dataloader, disable=not args.print_on_screen): batch_queries_dict = collections.defaultdict(list) batch_idxs_dict = collections.defaultdict(list) From e362fbae2d7ea918253a9e4bc4d9b3e9c73ca048 Mon Sep 17 00:00:00 2001 From: pminervini Date: Wed, 9 Jun 2021 23:33:55 +0200 Subject: [PATCH 18/25] clearer documentation --- CQD.md | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/CQD.md b/CQD.md index f1187dc..30512c6 100644 --- a/CQD.md +++ b/CQD.md @@ -21,11 +21,6 @@ Our code is based on an implementation of ComplEx-N3 available [here](https://gi ```bash $ mkdir models/ -$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-betae.tar.gz -$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-q2b.tar.gz -$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-237-betae.tar.gz -$ wget -c http://data.neuralnoise.com/cqd-models/fb15k-237-q2b.tar.gz -$ wget -c http://data.neuralnoise.com/cqd-models/nell-betae.tar.gz -$ 
wget -c http://data.neuralnoise.com/cqd-models/nell-q2b.tar.gz -$ for z in *.tar.gz; do tar xvfz $z; done +$ for i in "fb15k" "fb15k-237" "nell"; do for j in "betae" "q2b"; do wget -c http://data.neuralnoise.com/kgreasoning-cqd/$i-$j.tar.gz; done; done +$ for i in *.tar.gz; do tar xvfz $i; done ``` From d7ed58e2dc3337b8d0fb55667c9fa10f7d0b263a Mon Sep 17 00:00:00 2001 From: pminervini Date: Thu, 17 Jun 2021 14:50:33 +0200 Subject: [PATCH 19/25] shorter and clearer code --- models.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/models.py b/models.py index 07d4857..4a8fcb7 100755 --- a/models.py +++ b/models.py @@ -639,18 +639,11 @@ def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_na if len(argsort) == args.test_batch_size: # if it is the same shape with test_batch_size, we can reuse batch_entity_range without creating a new one ranking = ranking.scatter_(1, argsort, model.batch_entity_range) # achieve the ranking of all entities else: # otherwise, create a new torch Tensor for batch_entity_range + scatter_src = torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1) if args.cuda: - ranking = ranking.scatter_(1, - argsort, - torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], - 1).cuda() - ) # achieve the ranking of all entities - else: - ranking = ranking.scatter_(1, - argsort, - torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], - 1) - ) # achieve the ranking of all entities + scatter_src = scatter_src.cuda() + # achieve the ranking of all entities + ranking = ranking.scatter_(1, argsort, scatter_src) for idx, (i, query, query_structure) in enumerate(zip(argsort[:, 0], queries_unflatten, query_structures)): hard_answer = hard_answers[query] easy_answer = easy_answers[query] From 6ab0bb06ad972b3aeab9aa14ba18eeb52f72bf70 Mon Sep 17 00:00:00 2001 From: pminervini Date: Thu, 17 Jun 2021 20:10:22 +0200 Subject: [PATCH 20/25] removing old unused code --- 
cqd/base.py | 27 +------ cqd/util.py | 208 ---------------------------------------------------- main.py | 2 +- 3 files changed, 3 insertions(+), 234 deletions(-) diff --git a/cqd/base.py b/cqd/base.py index 91b01e7..12e22c1 100644 --- a/cqd/base.py +++ b/cqd/base.py @@ -9,7 +9,7 @@ from torch import optim, Tensor import math -from cqd.util import query_to_atoms, create_instructions, top_k_selection +from cqd.util import query_to_atoms import cqd.discrete as d2 from typing import Tuple, List, Optional, Dict @@ -217,7 +217,7 @@ def forward(self, h_emb_constants = self.embeddings[0](head) r_emb = self.embeddings[1](rel) - if 'co' in self.method: + if 'continuous' in self.method: h_emb = h_emb_constants if num_variables > 1: # var embedding for ID 0 is unused for ease of implementation @@ -275,29 +275,6 @@ def forward(self, scores = torch.cat(all_scores, dim=0) - elif 'continuous' in self.method: - graph_type = self.query_name_dict[query_structure] - - chain_instructions = create_instructions(atoms[0]) - chains = [] - - for atom in range(len(atoms[0])): - part = atoms[:, atom, :] - chain = self.get_full_embeddings(part) - chains.append(chain) - - scores = top_k_selection(chains, - chain_instructions, - graph_type, - # score_o takes lhs, rel, rhs - scoring_function=lambda rel_, lhs_, rhs_: self.score_o(lhs_, rel_, rhs_)[0], - forward_emb=lambda lhs_, rel_: self.score_o(lhs_, rel_, self.embeddings[0].weight)[0], - entity_embeddings=self.embeddings[0], - candidates=self.k, - t_norm=self.t_norm_name, - batch_size=1, - scores_normalize='default') - elif 'discrete' in self.method: graph_type = self.query_name_dict[query_structure] diff --git a/cqd/util.py b/cqd/util.py index dee0d78..a6bcd4c 100644 --- a/cqd/util.py +++ b/cqd/util.py @@ -173,211 +173,3 @@ def get_best_candidates(rel: Tensor, z_emb = entity_embeddings(z_indices) return z_scores, z_emb - - -def top_k_selection(chains, - chain_instructions, - graph_type, - scoring_function: Callable[[Tensor, Tensor, Tensor], 
Tensor], - forward_emb: Callable[[Tensor, Tensor], Tensor], - entity_embeddings: Callable[[Tensor], Tensor], - candidates: int = 5, - t_norm: str = 'min', - batch_size: int = 1, - scores_normalize: str = 'default'): - res = None - - if 'disj' in graph_type: - objective = t_conorm_fn - else: - objective = t_norm_fn - - nb_queries, embedding_size = chains[0][0].shape[0], chains[0][0].shape[1] - - scores = None - - batches = make_batches(nb_queries, batch_size) - - for batch in tqdm(batches): - - nb_branches = 1 - nb_ent = 0 - batch_scores = None - candidate_cache = {} - - batch_size = batch[1] - batch[0] - dnf_flag = False - if 'disj' in graph_type: - dnf_flag = True - - for inst_ind, inst in enumerate(chain_instructions): - with torch.no_grad(): - if 'hop' in inst: - - ind_1 = int(inst.split("_")[-2]) - ind_2 = int(inst.split("_")[-1]) - - indices = [ind_1, ind_2] - - if objective == t_conorm_fn and dnf_flag: - objective = t_norm_fn - - last_hop = False - for hop_num, ind in enumerate(indices): - last_step = (inst_ind == len(chain_instructions) - 1) and last_hop - - lhs, rel, rhs = chains[ind] - - if lhs is not None: - lhs = lhs[batch[0]:batch[1]] - else: - # print("MTA BRAT") - batch_scores, lhs_3d = candidate_cache[f"lhs_{ind}"] - lhs = lhs_3d.view(-1, embedding_size) - - rel = rel[batch[0]:batch[1]] - rel = rel.view(-1, 1, embedding_size).repeat(1, nb_branches, 1) - rel = rel.view(-1, embedding_size) - - if f"rhs_{ind}" not in candidate_cache: - - # print("STTEEE MTA") - z_scores, rhs_3d = get_best_candidates(rel, lhs, forward_emb, entity_embeddings, candidates, last_step) - - # [Num_queries * Candidates^K] - z_scores_1d = z_scores.view(-1) - if 'disj' in graph_type or scores_normalize: - z_scores_1d = torch.sigmoid(z_scores_1d) - - # B * S - nb_sources = rhs_3d.shape[0] * rhs_3d.shape[1] - nb_branches = nb_sources // batch_size - if not last_step: - batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, 
candidates).view(-1), t_norm) - else: - nb_ent = rhs_3d.shape[1] - batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, nb_ent).view(-1), t_norm) - - candidate_cache[f"rhs_{ind}"] = (batch_scores, rhs_3d) - - if not last_hop: - candidate_cache[f"lhs_{indices[hop_num + 1]}"] = (batch_scores, rhs_3d) - - else: - batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"] - candidate_cache[f"lhs_{ind + 1}"] = (batch_scores, rhs_3d) - last_hop = True - continue - - last_hop = True - - elif 'inter' in inst: - ind_1 = int(inst.split("_")[-2]) - ind_2 = int(inst.split("_")[-1]) - - indices = [ind_1, ind_2] - - if objective == t_norm_fn and dnf_flag: - objective = t_conorm_fn - - if len(inst.split("_")) > 3: - ind_1 = int(inst.split("_")[-3]) - ind_2 = int(inst.split("_")[-2]) - ind_3 = int(inst.split("_")[-1]) - - indices = [ind_1, ind_2, ind_3] - - for intersection_num, ind in enumerate(indices): - last_step = (inst_ind == len(chain_instructions) - 1) # and ind == indices[0] - - lhs, rel, rhs = chains[ind] - - if lhs is not None: - lhs = lhs[batch[0]:batch[1]] - lhs = lhs.view(-1, 1, embedding_size).repeat(1, nb_branches, 1) - lhs = lhs.view(-1, embedding_size) - - else: - batch_scores, lhs_3d = candidate_cache[f"lhs_{ind}"] - lhs = lhs_3d.view(-1, embedding_size) - nb_sources = lhs_3d.shape[0] * lhs_3d.shape[1] - nb_branches = nb_sources // batch_size - - rel = rel[batch[0]:batch[1]] - rel = rel.view(-1, 1, embedding_size).repeat(1, nb_branches, 1) - rel = rel.view(-1, embedding_size) - - if intersection_num > 0 and 'disj' in graph_type: - batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"] - rhs = rhs_3d.view(-1, embedding_size) - z_scores = scoring_function(rel, lhs, rhs) - - z_scores_1d = z_scores.view(-1) - if 'disj' in graph_type or scores_normalize: - z_scores_1d = torch.sigmoid(z_scores_1d) - - batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores, t_norm) - - continue - 
- if f"rhs_{ind}" not in candidate_cache or last_step: - z_scores, rhs_3d = get_best_candidates(rel, lhs, forward_emb, entity_embeddings, candidates, last_step) - - # [B * Candidates^K] or [B, S-1, N] - z_scores_1d = z_scores.view(-1) - if 'disj' in graph_type or scores_normalize: - z_scores_1d = torch.sigmoid(z_scores_1d) - - if not last_step: - nb_sources = rhs_3d.shape[0] * rhs_3d.shape[1] - nb_branches = nb_sources // batch_size - - if not last_step: - batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, candidates).view(-1), t_norm) - else: - if ind == indices[0]: - nb_ent = rhs_3d.shape[1] - else: - nb_ent = 1 - - batch_scores = z_scores_1d if batch_scores is None else objective(z_scores_1d, batch_scores.view(-1, 1).repeat(1, nb_ent).view(-1), t_norm) - nb_ent = rhs_3d.shape[1] - - candidate_cache[f"rhs_{ind}"] = (batch_scores, rhs_3d) - - if ind == indices[0] and 'disj' in graph_type: - count = len(indices) - 1 - iterator = 1 - while count > 0: - candidate_cache[f"rhs_{indices[intersection_num + iterator]}"] = ( - batch_scores, rhs_3d) - iterator += 1 - count -= 1 - - if ind == indices[-1]: - candidate_cache[f"lhs_{ind + 1}"] = (batch_scores, rhs_3d) - else: - batch_scores, rhs_3d = candidate_cache[f"rhs_{ind}"] - candidate_cache[f"rhs_{ind + 1}"] = (batch_scores, rhs_3d) - - last_hop = True - del lhs, rel - continue - - del lhs, rel, rhs, rhs_3d, z_scores_1d, z_scores - - if batch_scores is not None: - # [B * entites * S ] - # S == K**(V-1) - scores_2d = batch_scores.view(batch_size, -1, nb_ent) - res, _ = torch.max(scores_2d, dim=1) - scores = res if scores is None else torch.cat([scores, res]) - - del batch_scores, scores_2d, res, candidate_cache - - else: - assert False, "Batch Scores are empty: an error went uncaught." 
- - res = scores - - return res diff --git a/main.py b/main.py index c28a9c7..6a3c5d1 100755 --- a/main.py +++ b/main.py @@ -82,7 +82,7 @@ def parse_args(args=None): parser.add_argument('--reg_weight', default=1e-3, type=float) parser.add_argument('--optimizer', choices=['adam', 'adagrad'], default='adam') - parser.add_argument('--cqd-type', '--cqd', default='co', type=str, choices=['continuous', 'discrete']) + parser.add_argument('--cqd-type', '--cqd', default='discrete', type=str, choices=['continuous', 'discrete']) parser.add_argument('--cqd-t-norm', default=CQD.PROD_NORM, type=str, choices=CQD.NORMS) parser.add_argument('--cqd-k', default=5, type=int) parser.add_argument('--cqd-sigmoid-scores', '--cqd-sigmoid', action='store_true', default=False) From fce04172e460ddab7bfad6bfb2cf159dd5fdcbbf Mon Sep 17 00:00:00 2001 From: pminervini Date: Sun, 27 Jun 2021 12:20:09 +0200 Subject: [PATCH 21/25] more efficient answering of ip and up queries --- cqd/base.py | 4 ++-- cqd/discrete.py | 50 +++++++++++++++++++++++++++++++++---------------- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/cqd/base.py b/cqd/base.py index 12e22c1..6f2ac23 100644 --- a/cqd/base.py +++ b/cqd/base.py @@ -342,7 +342,7 @@ def scoring_function(rel_: Tensor, lhs_: Tensor, rhs_: Tensor) -> Tensor: predicate_embeddings=self.embeddings[1], queries=queries, scoring_function=scoring_function, - t_norm=t_norm) + k=self.k, t_norm=t_norm) elif graph_type == "2u-DNF": scores = d2.query_2u_dnf(entity_embeddings=self.embeddings[0], predicate_embeddings=self.embeddings[1], @@ -354,7 +354,7 @@ def scoring_function(rel_: Tensor, lhs_: Tensor, rhs_: Tensor) -> Tensor: predicate_embeddings=self.embeddings[1], queries=queries, scoring_function=scoring_function, - t_norm=t_norm, t_conorm=t_conorm) + k=self.k, t_norm=t_norm, t_conorm=t_conorm) else: raise ValueError(f'Unknown query type: {graph_type}') diff --git a/cqd/discrete.py b/cqd/discrete.py index c4f9c88..749ac3c 100644 --- a/cqd/discrete.py 
+++ b/cqd/discrete.py @@ -206,6 +206,7 @@ def query_ip(entity_embeddings: nn.Module, predicate_embeddings: nn.Module, queries: Tensor, scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor], + k: int, t_norm: Callable[[Tensor, Tensor], Tensor]) -> Tensor: # [B, N] @@ -222,18 +223,26 @@ def query_ip(entity_embeddings: nn.Module, e_emb = entity_embeddings.weight nb_entities = e_emb.shape[0] - # [B * N, E] - s_emb = e_emb.reshape(1, nb_entities, emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size) + k_ = min(k, nb_entities) + + # [B, K], [B, K] + scores_1_k, scores_1_k_indices = torch.topk(scores_1, k=k_, dim=1) + + # [B, K, E] + scores_1_k_emb = entity_embeddings(scores_1_k_indices) + + # [B * K, E] + scores_1_k_emb_2d = scores_1_k_emb.reshape(batch_size * k_, emb_size) - # [B * N, N] - scores_2, _ = score_candidates(s_emb=s_emb, p_emb=p_emb, candidates_emb=e_emb, k=None, + # [B * K, N] + scores_2, _ = score_candidates(s_emb=scores_1_k_emb_2d, p_emb=p_emb, candidates_emb=e_emb, k=None, entity_embeddings=entity_embeddings, scoring_function=scoring_function) - # [B, N, N] - scores_1 = scores_1.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities) - scores_2 = scores_2.reshape(batch_size, nb_entities, nb_entities) + # [B * K, N] + scores_1_k = scores_1_k.reshape(batch_size, k_, 1).repeat(1, 1, nb_entities) + scores_2 = scores_2.reshape(batch_size, k_, nb_entities) - res = t_norm(scores_1, scores_2) + res = t_norm(scores_1_k, scores_2) res, _ = torch.max(res, dim=1) return res @@ -276,6 +285,7 @@ def query_up_dnf(entity_embeddings: nn.Module, predicate_embeddings: nn.Module, queries: Tensor, scoring_function: Callable[[Tensor, Tensor, Tensor], Tensor], + k: int, t_norm: Callable[[Tensor, Tensor], Tensor], t_conorm: Callable[[Tensor, Tensor], Tensor]) -> Tensor: # [B, N] @@ -292,18 +302,26 @@ def query_up_dnf(entity_embeddings: nn.Module, e_emb = entity_embeddings.weight nb_entities = e_emb.shape[0] - # [B * N, E] - s_emb = e_emb.reshape(1, nb_entities, 
emb_size).repeat(batch_size, 1, 1).reshape(-1, emb_size) + k_ = min(k, nb_entities) - # [B * N, N] - scores_2, _ = score_candidates(s_emb=s_emb, p_emb=p_emb, candidates_emb=e_emb, k=None, + # [B, K], [B, K] + scores_1_k, scores_1_k_indices = torch.topk(scores_1, k=k_, dim=1) + + # [B, K, E] + scores_1_k_emb = entity_embeddings(scores_1_k_indices) + + # [B * K, E] + scores_1_k_emb_2d = scores_1_k_emb.reshape(batch_size * k_, emb_size) + + # [B * K, N] + scores_2, _ = score_candidates(s_emb=scores_1_k_emb_2d, p_emb=p_emb, candidates_emb=e_emb, k=None, entity_embeddings=entity_embeddings, scoring_function=scoring_function) - # [B, N, N] - scores_1 = scores_1.reshape(batch_size, nb_entities, 1).repeat(1, 1, nb_entities) - scores_2 = scores_2.reshape(batch_size, nb_entities, nb_entities) + # [B * K, N] + scores_1_k = scores_1_k.reshape(batch_size, k_, 1).repeat(1, 1, nb_entities) + scores_2 = scores_2.reshape(batch_size, k_, nb_entities) - res = t_norm(scores_1, scores_2) + res = t_norm(scores_1_k, scores_2) res, _ = torch.max(res, dim=1) return res From 46e953466b3aa588d21ac7578050d15f27bcf323 Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 26 Mar 2022 19:24:50 +0100 Subject: [PATCH 22/25] update --- CQD.md | 259 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 258 insertions(+), 1 deletion(-) diff --git a/CQD.md b/CQD.md index 30512c6..afa6d34 100644 --- a/CQD.md +++ b/CQD.md @@ -17,10 +17,267 @@ In this work we present CQD, a method that reuses a pretrained link predictor to Our code is based on an implementation of ComplEx-N3 available [here](https://github.com/facebookresearch/kbc). -### 1. Download the pre-trained models +## 1. 
Download the pre-trained models + +To download and decompress the pre-trained models, execute the following commands: ```bash $ mkdir models/ $ for i in "fb15k" "fb15k-237" "nell"; do for j in "betae" "q2b"; do wget -c http://data.neuralnoise.com/kgreasoning-cqd/$i-$j.tar.gz; done; done $ for i in *.tar.gz; do tar xvfz $i; done ``` + +## 2. Answer the complex queries + +One catch is that the query answering process in CQD depends on some hyperparameters, i.e. the "beam size" `k`, the t-norm to use (e.g. `min` or `prod`), and the normalisation function that maps scores to the `[0, 1]` interval; in our experiments, we select these on the validation set. Here are the commands to execute to evaluate CQD on each type of queries: + +### 2.1 -- FB15k + +1p queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 1p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete +[..] +Test 1p MRR at step 99999: 0.891426 +Test 1p HITS1 at step 99999: 0.857939 +Test 1p HITS3 at step 99999: 0.915589 +[..] +``` + +2p queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --valid_steps 20 --tasks 2p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-k 64 --cuda +[..] +Test 2p HITS3 at step 99999: 0.791121 +[..] +``` + +3p queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --valid_steps 20 --tasks 3p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-sigmoid --cqd-k 4 --cuda +[..] +Test 3p HITS3 at step 99999: 0.459223 +[..] +``` + +2i queries: + +```bash +PYTHONPATH=. 
python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2i --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-k 16 --cuda +[..] +Test 2i HITS3 at step 99999: 0.788954 +[..] +``` + +3i queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 3i --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-k 16 --cuda +[..] +Test 3i HITS3 at step 99999: 0.837378 +[..] +``` + +ip queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks ip --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-k 16 --cuda +[..] +Test ip HITS3 at step 99999: 0.649221 +[..] +``` + +pi queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks pi --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm prod --cqd-k 64 --cuda +[..] +Test pi HITS3 at step 99999: 0.681604 +[..] +``` + +2u queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2u --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm min --cqd-normalize --cqd-k 16 --cuda +[..] +Test 2u-DNF HITS3 at step 99999: 0.853601 +[..] +``` + +up queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks up --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-q2b --cqd discrete --cqd-t-norm min --cqd-sigmoid --cqd-k 16 --cuda +[..] +Test up-DNF HITS3 at step 99999: 0.709496 +[..] 
+``` + +### 2.2 -- FB15k-237 + +1p queries: + +```bash +$ PYTHONPATH=. python3 PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 1p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cuda +[..] +Test 1p HITS3 at step 99999: 0.511910 +[..] +``` + +2p queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-k 64 --cuda +[..] +Test 2p HITS3 at step 99999: 0.286640 +[..] +``` + +3p queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 3p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-sigmoid --cqd-k 4 --cuda +[..] +Test 3p HITS3 at step 99999: 0.199947 +[..] +``` + +2i queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2i --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 16 --cuda +[..] +Test 2i HITS3 at step 99999: 0.376709 +[..] +``` + +3i queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 3i --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 16 --cuda +[..] +Test 3i HITS3 at step 99999: 0.488725 +[..] +``` + +ip queries: + +```bash +$ PYTHONPATH=. 
python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks ip --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-k 16 --cuda
+[..]
+Test ip HITS3 at step 99999: 0.182000
+[..]
+```
+
+pi queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks pi --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 64 --cuda
+[..]
+Test pi HITS3 at step 99999: 0.267872
+[..]
+```
+
+2u queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2u --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm min --cqd-normalize --cqd-k 16 --cuda
+[..]
+Test 2u-DNF HITS3 at step 99999: 0.323751
+[..]
+```
+
+up queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks up --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cqd-t-norm prod --cqd-sigmoid --cqd-k 16 --cuda
+[..]
+Test up-DNF HITS3 at step 99999: 0.225360
+[..]
+```
+
+### 2.3 -- NELL 995
+
+1p queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 1p --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cuda
+[..]
+Test 1p HITS3 at step 99999: 0.663197
+[..]
+```
+
+2p queries:
+
+```bash
+$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2p --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-k 64 --cuda
+[..]
+Test 2p HITS3 at step 99999: 0.351218 +[..] +``` + +3p queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --valid_steps 20 --tasks 3p --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-sigmoid --cqd-k 2 --cuda +[..] +Test 3p HITS3 at step 99999: 0.263724 +[..] +``` + +2i queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2i --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 16 --cuda +[..] +Test 2i HITS3 at step 99999: 0.422821 +[..] +``` + +3i queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 3i --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 16 --cuda +[..] +Test 3i HITS3 at step 99999: 0.538633 +[..] +``` + +ip queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks ip --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-k 16 --cuda +[..] +Test ip HITS3 at step 99999: 0.234066 +[..] +``` + +pi queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks pi --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm prod --cqd-normalize --cqd-k 64 --cuda +[..] +Test pi HITS3 at step 99999: 0.315222 +[..] +``` + +2u queries: + +```bash +$ PYTHONPATH=. 
python3 main.py --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 2u --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm min --cqd-normalize --cqd-k 16 --cuda +[..] +Test 2u-DNF HITS3 at step 99999: 0.541287 +[..] +``` + +up queries: + +```bash +$ PYTHONPATH=. python3 main.py --do_valid --do_test --data_path data/NELL-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks up --print_on_screen --test_batch_size 1 --checkpoint_path models/nell-q2b --cqd discrete --cqd-t-norm min --cqd-sigmoid --cqd-k 16 --cuda +[..] +Test up-DNF HITS3 at step 99999: 0.290282 +[..] +``` From fba0ff4121aa225e2929793f434690846cbae22d Mon Sep 17 00:00:00 2001 From: pminervini Date: Sat, 26 Mar 2022 22:07:15 +0100 Subject: [PATCH 23/25] Fixing issue #2 --- cqd/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cqd/base.py b/cqd/base.py index 6f2ac23..2120395 100644 --- a/cqd/base.py +++ b/cqd/base.py @@ -56,7 +56,7 @@ def __init__(self, self.query_name_dict = query_name_dict sizes = (nentity, nrelation) - self.embeddings = nn.ModuleList([nn.Embedding(s, 2 * rank, sparse=True) for s in sizes[:2]]) + self.embeddings = nn.ModuleList([nn.Embedding(s, 2 * rank, sparse=False) for s in sizes[:2]]) self.embeddings[0].weight.data *= init_size self.embeddings[1].weight.data *= init_size From 71ec9a52b550407cae9153a293fcf2f775173f88 Mon Sep 17 00:00:00 2001 From: pminervini Date: Sun, 27 Mar 2022 22:03:34 +0200 Subject: [PATCH 24/25] update --- CQD.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/CQD.md b/CQD.md index afa6d34..3abb9bd 100644 --- a/CQD.md +++ b/CQD.md @@ -19,7 +19,7 @@ Our code is based on an implementation of ComplEx-N3 available [here](https://gi ## 1. 
Download the pre-trained models -To download and decompress the pre-trained models, execute the folloing commands: +To download and decompress the pre-trained models, execute the following commands: ```bash $ mkdir models/ @@ -27,6 +27,14 @@ $ for i in "fb15k" "fb15k-237" "nell"; do for j in "betae" "q2b"; do wget -c htt $ for i in *.tar.gz; do tar xvfz $i; done ``` +In case you need to re-train the models from scratch, use the following command lines: + +```bash +PYTHONPATH=. python3 main.py --do_train --do_valid --do_test --data_path data/FB15k-237-q2b -n 1 -b 2000 -d 1000 -lr 0.1 --warm_up_steps 100000000 --max_steps 100000 --cpu_num 0 --geo cqd --valid_steps 500 --tasks 1p --print_on_screen --test_batch_size 1000 --optimizer adagrad --reg_weight 0.1 --log_steps 500 --cuda --use-qa-iterator +PYTHONPATH=. python3 main.py --do_train --do_valid --do_test --data_path data/FB15k-q2b -n 1 -b 5000 -d 1000 -lr 0.1 --warm_up_steps 100000000 --max_steps 100000 --cpu_num 0 --geo cqd --valid_steps 500 --tasks 1p --print_on_screen --test_batch_size 1000 --optimizer adagrad --reg_weight 0.01 --log_steps 500 --cuda --use-qa-iterator +PYTHONPATH=. python3 main.py --do_train --do_valid --do_test --data_path data/NELL-q2b -n 1 -b 2000 -d 1000 -lr 0.1 --warm_up_steps 100000000 --max_steps 100000 --cpu_num 0 --geo cqd --valid_steps 500 --tasks 1p --print_on_screen --test_batch_size 1000 --optimizer adagrad --reg_weight 0.1 --log_steps 500 --cuda --use-qa-iterator +``` + ## 2. Answer the complex queries One catch is that the query answering process in CQD depends on some hyperparameters, i.e. the "beam size" `k`, the t-norm to use (e.g. `min` or `prod`), and the normalisation function that maps scores to the `[0, 1]` interval; in our experiments, we select these on the validation set. 
Here are the commands to execute to evaluate CQD on each type of queries: From 81060a2af7153b0b910704397cff8ea7e71b713a Mon Sep 17 00:00:00 2001 From: pminervini Date: Sun, 27 Mar 2022 22:05:15 +0200 Subject: [PATCH 25/25] update --- CQD.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CQD.md b/CQD.md index 3abb9bd..7177f07 100644 --- a/CQD.md +++ b/CQD.md @@ -129,7 +129,7 @@ Test up-DNF HITS3 at step 99999: 0.709496 1p queries: ```bash -$ PYTHONPATH=. python3 PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 1p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cuda +$ PYTHONPATH=. python3 main.py --do_test --data_path data/FB15k-237-q2b -n 1 -b 1000 -d 1000 --cpu_num 0 --geo cqd --tasks 1p --print_on_screen --test_batch_size 1 --checkpoint_path models/fb15k-237-q2b --cqd discrete --cuda [..] Test 1p HITS3 at step 99999: 0.511910 [..]