From de579ba102bf751826c59873a10dd76f9eb4c53c Mon Sep 17 00:00:00 2001 From: Natalie Date: Tue, 20 Aug 2019 10:47:01 -0500 Subject: [PATCH 1/9] Update for Python 3 Includes changes to cPickle (change to pickle), print (add parentheses), .iteritems (change to .items() ), and xrange (change to range) --- aggregators.py | 200 ++++++++++++++ data_utils.py | 110 ++++++++ encoders.py | 146 ++++++++++ graph.py | 695 +++++++++++++++++++++++++++++++++++++++++++++++ model.py | 189 +++++++++++++ train.py | 81 ++++++ train_helpers.py | 107 ++++++++ utils.py | 167 ++++++++++++ 8 files changed, 1695 insertions(+) create mode 100644 aggregators.py create mode 100644 data_utils.py create mode 100644 encoders.py create mode 100644 graph.py create mode 100644 model.py create mode 100644 train.py create mode 100644 train_helpers.py create mode 100644 utils.py diff --git a/aggregators.py b/aggregators.py new file mode 100644 index 0000000..0052886 --- /dev/null +++ b/aggregators.py @@ -0,0 +1,200 @@ +import torch +import torch.nn as nn +import itertools +from torch.nn import init +from torch.autograd import Variable +import torch.nn.functional as F + +import random +import math +import numpy as np + +""" +Set of modules for aggregating embeddings of neighbors. +These modules take as input embeddings of neighbors. +""" + +class MeanAggregator(nn.Module): + """ + Aggregates a node's embeddings using mean of neighbors' embeddings + """ + def __init__(self, features, cuda=False): + """ + Initializes the aggregator for a specific graph. + + features -- function mapping (node_list, features, offset) to feature values + see torch.nn.EmbeddingBag and forward function below docs for offset meaning. + """ + + super(MeanAggregator, self).__init__() + + self.features = features + self.cuda = cuda + + def forward(self, to_neighs, rel, keep_prob=0.5, max_keep=10): + """ + Aggregates embeddings for a batch of nodes. + keep_prob and max_keep are the parameters for edge/neighbour dropout. + + to_neighs -- list of neighbors of nodes + keep_prob -- probability of keeping a neighbor + max_keep -- maximum number of neighbors kept per node + """ + + # Local pointers to functions (speed hack) + _int = int + _set = set + _min = min + _len = len + _ceil = math.ceil + _sample = random.sample + samp_neighs = [_set(_sample(to_neigh, + _min(_int(_ceil(_len(to_neigh)*keep_prob)), max_keep) + )) for to_neigh in to_neighs] + unique_nodes_list = list(set.union(*samp_neighs)) + unique_nodes = {n:i for i,n in enumerate(unique_nodes_list)} + mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes))) + column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh] + row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))] + mask[row_indices, column_indices] = 1 + if self.cuda: + mask = mask.cuda() + num_neigh = mask.sum(1, keepdim=True) + mask = mask.div(num_neigh) + embed_matrix = self.features(unique_nodes_list, rel[-1]) + if len(embed_matrix.size()) == 1: + embed_matrix = embed_matrix.unsqueeze(dim=0) + to_feats = mask.mm(embed_matrix) + return to_feats + +class FastMeanAggregator(nn.Module): + """ + Aggregates a node's embeddings using mean of neighbors' embeddings + """ + def __init__(self, features, cuda=False): + """ + Initializes the aggregator for a specific graph. + + features -- function mapping (node_list, features, offset) to feature values + see torch.nn.EmbeddingBag and forward function below docs for offset meaning. 
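        (Unlike MeanAggregator above, this variant samples a fixed number of neighbors with replacement rather than using each node's exact neighbor set, trading accuracy for speed.)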
+ """ + + super(FastMeanAggregator, self).__init__() + + self.features = features + self.cuda = cuda + + def forward(self, to_neighs, rel, keep_prob=None, max_keep=25): + """ + Aggregates embeddings for a batch of nodes. + keep_prob and max_keep are the parameters for edge/neighbour dropout. + + to_neighs -- list of neighbors of nodes + keep_prob -- probability of keeping a neighbor + max_keep -- maximum number of neighbors kept per node + """ + _random = random.random + _int = int + _len = len + samp_neighs = [to_neigh[_int(_random()*_len(to_neigh))] for i in itertools.repeat(None, max_keep) + for to_neigh in to_neighs] + embed_matrix = self.features(samp_neighs, rel[-1]) + to_feats = embed_matrix.view(max_keep, len(to_neighs), embed_matrix.size()[1]) + return to_feats.mean(dim=0) + +class PoolAggregator(nn.Module): + """ + Aggregates a node's embeddings using mean pooling of neighbors' embeddings + """ + def __init__(self, features, feature_dims, cuda=False): + """ + Initializes the aggregator for a specific graph. + + features -- function mapping (node_list, features, offset) to feature values + see torch.nn.EmbeddingBag and forward function below docs for offset meaning. + """ + + super(PoolAggregator, self).__init__() + + self.features = features + self.feat_dims = feature_dims + self.pool_matrix = {} + for mode, feat_dim in self.feat_dims.items(): + self.pool_matrix[mode] = nn.Parameter(torch.FloatTensor(feat_dim, feat_dim)) + init.xavier_uniform(self.pool_matrix[mode]) + self.register_parameter(mode+"_pool", self.pool_matrix[mode]) + self.cuda = cuda + + def forward(self, to_neighs, rel, keep_prob=0.5, max_keep=10): + """ + Aggregates embeddings for a batch of nodes. + keep_prob and max_keep are the parameters for edge/neighbour dropout. + + to_neighs -- list of neighbors of nodes + keep_prob -- probability of keeping a neighbor + max_keep -- maximum number of neighbors kept per node + """ + _int = int + _set = set + _min = min + _len = len + _ceil = math.ceil + _sample = random.sample + samp_neighs = [_set(_sample(to_neigh, + _min(_int(_ceil(_len(to_neigh)*keep_prob)), max_keep) + )) for to_neigh in to_neighs] + unique_nodes_list = list(set.union(*samp_neighs)) + unique_nodes = {n:i for i,n in enumerate(unique_nodes_list)} + mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes))) + column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh] + row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))] + mask[row_indices, column_indices] = 1 + mode = rel[0] + if self.cuda: + mask = mask.cuda() + embed_matrix = self.features(unique_nodes, rel[-1]).mm(self.pool_matrix[mode]) + to_feats = F.relu(mask.mm(embed_matrix)) + return to_feats + +class FastPoolAggregator(nn.Module): + """ + Aggregates a node's embeddings using mean pooling of neighbors' embeddings + """ + def __init__(self, features, feature_dims, + cuda=False): + """ + Initializes the aggregator for a specific graph. + + features -- function mapping (node_list, features, offset) to feature values + see torch.nn.EmbeddingBag and forward function below docs for offset meaning. 
+ """ + + super(FastPoolAggregator, self).__init__() + + self.features = features + self.feat_dims = feature_dims + self.pool_matrix = {} + for mode, feat_dim in self.feat_dims.items(): + self.pool_matrix[mode] = nn.Parameter(torch.FloatTensor(feat_dim, feat_dim)) + init.xavier_uniform(self.pool_matrix[mode]) + self.register_parameter(mode+"_pool", self.pool_matrix[mode]) + self.cuda = cuda + + def forward(self, to_neighs, rel, keep_prob=0.5, max_keep=10): + """ + Aggregates embeddings for a batch of nodes. + keep_prob and max_keep are the parameters for edge/neighbour dropout. + + to_neighs -- list of neighbors of nodes + keep_prob -- probability of keeping a neighbor + max_keep -- maximum number of neighbors kept per node + """ + _random = random.random + _int = int + _len = len + samp_neighs = [to_neigh[_int(_random()*_len(to_neigh))] for i in itertools.repeat(None, max_keep) + for to_neigh in to_neighs] + mode = rel[0] + embed_matrix = self.features(samp_neighs, rel[-1]).mm(self.pool_matrix[mode]) + to_feats = embed_matrix.view(max_keep, len(to_neighs), embed_matrix.size()[1]) + return to_feats.mean(dim=0) diff --git a/data_utils.py b/data_utils.py new file mode 100644 index 0000000..8ef54ad --- /dev/null +++ b/data_utils.py @@ -0,0 +1,110 @@ +#import cPickle as pickle +import pickle +import torch +from collections import OrderedDict, defaultdict +from multiprocessing import Process +import random +import json +from netquery.data_utils import parallel_sample, load_queries_by_type, sample_clean_test + +from netquery.graph import Graph, Query, _reverse_edge + +def load_graph(data_dir, embed_dim): + rels, adj_lists, node_maps = pickle.load(open(data_dir+"/graph_data.pkl", "rb")) + node_maps = {m : {n : i for i, n in enumerate(id_list)} for m, id_list in node_maps.items()} + for m in node_maps: + node_maps[m][-1] = -1 + feature_dims = {m : embed_dim for m in rels} + feature_modules = {m : torch.nn.Embedding(len(node_maps[m])+1, embed_dim) for m in rels} + for mode in rels: + feature_modules[mode].weight.data.normal_(0, 1./embed_dim) + features = lambda nodes, mode : feature_modules[mode]( + torch.autograd.Variable(torch.LongTensor([node_maps[mode][n] for n in nodes])+1)) + graph = Graph(features, feature_dims, rels, adj_lists) + return graph, feature_modules, node_maps + +def sample_new_clean(data_dir): + graph_loader = lambda : load_graph(data_dir, 10)[0] + sample_clean_test(graph_loader, data_dir) + +def clean_test(): + test_edges = pickle.load(open("/dfs/scratch0/nqe-bio/test_edges.pkl", "rb")) + val_edges = pickle.load(open("/dfs/scratch0/nqe-bio/val_edges.pkl", "rb")) + deleted_edges = set([q[0][1] for q in test_edges] + [_reverse_edge(q[0][1]) for q in test_edges] + + [q[0][1] for q in val_edges] + [_reverse_edge(q[0][1]) for q in val_edges]) + + for i in range(2,4): + for kind in ["val", "test"]: + if kind == "val": + to_keep = 1000 + else: + to_keep = 10000 + test_queries = load_queries_by_type("/dfs/scratch0/nqe-bio/{:s}_queries_{:d}-split.pkl".format(kind, i), keep_graph=True) + print("Loaded", i, kind) + for query_type in test_queries: + test_queries[query_type] = [q for q in test_queries[query_type] if len(q.get_edges().intersection(deleted_edges)) > 0] + test_queries[query_type] = test_queries[query_type][:to_keep] + test_queries = [q.serialize() for queries in test_queries.values() for q in queries] + pickle.dump(test_queries, open("/dfs/scratch0/nqe-bio/{:s}_queries_{:d}-clean.pkl".format(kind, i), "wb"), protocol=pickle.HIGHEST_PROTOCOL) + print("Finished", i, kind) + + + 
+def make_train_test_edge_data(data_dir): + print("Loading graph...") + graph, _, _ = load_graph(data_dir, 10) + print("Getting all edges...") + edges = graph.get_all_edges() + split_point = int(0.1*len(edges)) + val_test_edges = edges[:split_point] + print("Getting negative samples...") + val_test_edge_negsamples = [graph.get_negative_edge_samples(e, 100) for e in val_test_edges] + print("Making and storing test queries.") + val_test_edge_queries = [Query(("1-chain", val_test_edges[i]), val_test_edge_negsamples[i], None, 100) for i in range(split_point)] + val_split_point = int(0.1*len(val_test_edge_queries)) + val_queries = val_test_edge_queries[:val_split_point] + test_queries = val_test_edge_queries[val_split_point:] + pickle.dump([q.serialize() for q in val_queries], open(data_dir+"/val_edges.pkl", "w"), protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump([q.serialize() for q in test_queries], open(data_dir+"/test_edges.pkl", "w"), protocol=pickle.HIGHEST_PROTOCOL) + + print("Removing test edges...") + graph.remove_edges(val_test_edges) + print("Making and storing train queries.") + train_edges = graph.get_all_edges() + train_queries = [Query(("1-chain", e), None, None) for e in train_edges] + pickle.dump([q.serialize() for q in train_queries], open(data_dir+"/train_edges.pkl", "w"), protocol=pickle.HIGHEST_PROTOCOL) + +def _discard_negatives(file_name, small_prop=0.9): + queries = pickle.load(open(file_name, "rb")) +# queries = [q if random.random() > small_prop else (q[0],[random.choice(tuple(q[1]))], None if q[2] is None else [random.choice(tuple(q[2]))]) for q in queries] + queries = [q if random.random() > small_prop else (q[0],[random.choice(list(q[1]))], None if q[2] is None else [random.choice(list(q[2]))]) for q in queries] + pickle.dump(queries, open(file_name.split(".")[0] + "-split.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) + print("Finished", file_name) + +def discard_negatives(data_dir): + _discard_negatives(data_dir + "/val_edges.pkl") + _discard_negatives(data_dir + "/test_edges.pkl") + for i in range(2,4): + _discard_negatives(data_dir + "/val_queries_{:d}.pkl".format(i)) + _discard_negatives(data_dir + "/test_queries_{:d}.pkl".format(i)) + + +def make_train_test_query_data(data_dir): + graph, _, _ = load_graph(data_dir, 10) + queries_2, queries_3 = parallel_sample(graph, 20, 50000, data_dir, test=False) + t_queries_2, t_queries_3 = parallel_sample(graph, 20, 5000, data_dir, test=True) + t_queries_2 = list(set(t_queries_2) - set(queries_2)) + t_queries_3 = list(set(t_queries_3) - set(queries_3)) + pickle.dump([q.serialize() for q in queries_2], open(data_dir + "/train_queries_2.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump([q.serialize() for q in queries_3], open(data_dir + "/train_queries_3.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump([q.serialize() for q in t_queries_2[10000:]], open(data_dir + "/test_queries_2.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump([q.serialize() for q in t_queries_3[10000:]], open(data_dir + "/test_queries_3.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump([q.serialize() for q in t_queries_2[:10000]], open(data_dir + "/val_queries_2.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) + pickle.dump([q.serialize() for q in t_queries_3[:10000]], open(data_dir + "/val_queries_3.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) + + +if __name__ == "__main__": + #make_train_test_query_data("/dfs/scratch0/nqe-bio/") + #make_train_test_edge_data("/dfs/scratch0/nqe-bio/") + 
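    # Re-sample the "clean" validation/test queries from the stored graph (see sample_new_clean above):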
sample_new_clean("/dfs/scratch0/nqe-bio/") + #clean_test() diff --git a/encoders.py b/encoders.py new file mode 100644 index 0000000..772a174 --- /dev/null +++ b/encoders.py @@ -0,0 +1,146 @@ +import torch +import torch.nn as nn +from torch.nn import init +import torch.nn.functional as F + +""" +Set of modules for encoding nodes. +These modules take as input node ids and output embeddings. +""" + +class DirectEncoder(nn.Module): + """ + Encodes a node as a embedding via direct lookup. + (i.e., this is just like basic node2vec or matrix factorization) + """ + def __init__(self, features, feature_modules): + """ + Initializes the model for a specific graph. + + features -- function mapping (node_list, features, offset) to feature values + see torch.nn.EmbeddingBag and forward function below docs for offset meaning. + feature_modules -- This should be a map from mode -> torch.nn.EmbeddingBag + """ + super(DirectEncoder, self).__init__() + for name, module in feature_modules.items(): + self.add_module("feat-"+name, module) + self.features = features + + def forward(self, nodes, mode, offset=None, **kwargs): + """ + Generates embeddings for a batch of nodes. + + nodes -- list of nodes + mode -- string desiginating the mode of the nodes + offsets -- specifies how the embeddings are aggregated. + see torch.nn.EmbeddingBag for format. + No aggregation if offsets is None + """ + + if offset is None: + embeds = self.features(nodes, mode).t() + norm = embeds.norm(p=2, dim=0, keepdim=True) + return embeds.div(norm.expand_as(embeds)) + else: + return self.features(nodes, mode, offset).t() + +class Encoder(nn.Module): + """ + Encodes a node's using a GCN/GraphSage approach + """ + def __init__(self, features, feature_dims, + out_dims, relations, adj_lists, aggregator, + base_model=None, cuda=False, + layer_norm=False, + feature_modules={}): + """ + Initializes the model for a specific graph. + + features -- function mapping (node_list, features, offset) to feature values + see torch.nn.EmbeddingBag and forward function below docs for offset meaning. + feature_dims -- output dimension of each of the feature functions. 
+ out_dims -- embedding dimensions for each mode (i.e., output dimensions) + relations -- map from mode -> out_going_relations + adj_lists -- map from relation_tuple -> node -> list of node's neighbors + base_model -- if features are from another encoder, pass it here for training + cuda -- whether or not to move params to the GPU + feature_modules -- if features come from torch.nn module, pass the modules here for training + """ + + super(Encoder, self).__init__() + + self.features = features + self.feat_dims = feature_dims + self.adj_lists = adj_lists + self.relations = relations + self.aggregator = aggregator + for name, module in feature_modules.items(): + self.add_module("feat-"+name, module) + if base_model != None: + self.base_model = base_model + + self.out_dims = out_dims + self.cuda = cuda + self.aggregator.cuda = cuda + self.layer_norm = layer_norm + self.compress_dims = {} + for source_mode in relations: + self.compress_dims[source_mode] = self.feat_dims[source_mode] + for (to_mode, _) in relations[source_mode]: + self.compress_dims[source_mode] += self.feat_dims[to_mode] + + self.self_params = {} + self.compress_params = {} + self.lns = {} + for mode, feat_dim in self.feat_dims.items(): + if self.layer_norm: + self.lns[mode] = LayerNorm(out_dims[mode]) + self.add_module(mode+"_ln", self.lns[mode]) + self.compress_params[mode] = nn.Parameter( + torch.FloatTensor(out_dims[mode], self.compress_dims[mode])) + init.xavier_uniform(self.compress_params[mode]) + self.register_parameter(mode+"_compress", self.compress_params[mode]) + + def forward(self, nodes, mode, keep_prob=0.5, max_keep=10): + """ + Generates embeddings for a batch of nodes. + + nodes -- list of nodes + mode -- string desiginating the mode of the nodes + """ + self_feat = self.features(nodes, mode).t() + neigh_feats = [] + for to_r in self.relations[mode]: + rel = (mode, to_r[1], to_r[0]) + to_neighs = [[-1] if node == -1 else self.adj_lists[rel][node] for node in nodes] + + # Special null neighbor for nodes with no edges of this type + to_neighs = [[-1] if len(l) == 0 else l for l in to_neighs] + to_feats = self.aggregator.forward(to_neighs, rel, keep_prob, max_keep) + to_feats = to_feats.t() + neigh_feats.append(to_feats) + + neigh_feats.append(self_feat) + combined = torch.cat(neigh_feats, dim=0) + combined = self.compress_params[mode].mm(combined) + if self.layer_norm: + combined = self.lns[mode](combined.t()).t() + combined = F.relu(combined) + return combined + + +class LayerNorm(nn.Module): + """ + Simple layer norm object optionally used with the convolutional encoder. 
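    Normalizes activations to zero mean and unit variance along the last dimension, then applies a learned per-feature scale (gamma) and shift (beta).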
+ """ + + def __init__(self, feature_dim, eps=1e-6): + super(LayerNorm, self).__init__() + self.gamma = nn.Parameter(torch.ones((feature_dim,))) + self.beta = nn.Parameter(torch.zeros((feature_dim,))) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.gamma * (x - mean) / (std + self.eps) + self.beta diff --git a/graph.py b/graph.py new file mode 100644 index 0000000..8e3e4cc --- /dev/null +++ b/graph.py @@ -0,0 +1,695 @@ +from collections import OrderedDict, defaultdict +import random + +def _reverse_relation(relation): + return (relation[-1], relation[1], relation[0]) + +def _reverse_edge(edge): + return (edge[-1], _reverse_relation(edge[1]), edge[0]) + + +class Formula(): + + def __init__(self, query_type, rels): + self.query_type = query_type + self.target_mode = rels[0][0] + self.rels = rels + if query_type == "1-chain" or query_type == "2-chain" or query_type == "3-chain": + self.anchor_modes = (rels[-1][-1],) + elif query_type == "2-inter" or query_type == "3-inter": + self.anchor_modes = tuple([rel[-1] for rel in rels]) + elif query_type == "3-inter_chain": + self.anchor_modes = (rels[0][-1], rels[1][-1][-1]) + elif query_type == "3-chain_inter": + self.anchor_modes = (rels[1][0][-1], rels[1][1][-1]) + + def __hash__(self): + return hash((self.query_type, self.rels)) + + def __eq__(self, other): + return ((self.query_type, self.rels)) == ((other.query_type, other.rels)) + + def __neq__(self, other): + return ((self.query_type, self.rels)) != ((other.query_type, other.rels)) + + def __str__(self): + return self.query_type + ": " + str(self.rels) + +class Query(): + + def __init__(self, query_graph, neg_samples, hard_neg_samples, neg_sample_max=100, keep_graph=False): + query_type = query_graph[0] + if query_type == "1-chain" or query_type == "2-chain" or query_type == "3-chain": + self.formula = Formula(query_type, tuple([query_graph[i][1] for i in range(1, len(query_graph))])) + self.anchor_nodes = (query_graph[-1][-1],) + elif query_type == "2-inter" or query_type == "3-inter": + self.formula = Formula(query_type, tuple([query_graph[i][1] for i in range(1, len(query_graph))])) + self.anchor_nodes = tuple([query_graph[i][-1] for i in range(1, len(query_graph))]) + elif query_type == "3-inter_chain": + self.formula = Formula(query_type, (query_graph[1][1], (query_graph[2][0][1], query_graph[2][1][1]))) + self.anchor_nodes = (query_graph[1][-1], query_graph[2][-1][-1]) + elif query_type == "3-chain_inter": + self.formula = Formula(query_type, (query_graph[1][1], (query_graph[2][0][1], query_graph[2][1][1]))) + self.anchor_nodes = (query_graph[2][0][-1], query_graph[2][1][-1]) + self.target_node = query_graph[1][0] + if keep_graph: + self.query_graph = query_graph + else: + self.query_graph = None + if not neg_samples is None: + self.neg_samples = list(neg_samples) if len(neg_samples) < neg_sample_max else random.sample(neg_samples, neg_sample_max) + else: + self.neg_samples = None + if not hard_neg_samples is None: + self.hard_neg_samples = list(hard_neg_samples) if len(hard_neg_samples) <= neg_sample_max else random.sample(hard_neg_samples, neg_sample_max) + else: + self.hard_neg_samples = None + + def contains_edge(self, edge): + if self.query_graph is None: + raise Exception("Can only test edge contain if graph is kept. 
Reinit with keep_graph=True") + edges = self.query_graph[1:] + if "inter_chain" in self.query_graph[0] or "chain_inter" in self.query_graph[0]: + edges = (edges[0], edges[1][0], edges[1][1]) + return edge in edges or (edge[1], _reverse_relation(edge[1]), edge[0]) in edges + + def get_edges(self): + if self.query_graph is None: + raise Exception("Can only test edge contain if graph is kept. Reinit with keep_graph=True") + edges = self.query_graph[1:] + if "inter_chain" in self.query_graph[0] or "chain_inter" in self.query_graph[0]: + edges = (edges[0], edges[1][0], edges[1][1]) + return set(edges).union(set([(e[-1], _reverse_relation(e[1]), e[0]) for e in edges])) + + def __hash__(self): + return hash((self.formula, self.target_node, self.anchor_nodes)) + + def __eq__(self, other): + return (self.formula, self.target_node, self.anchor_nodes) == (other.formula, other.target_node, other.anchor_nodes) + + def __neq__(self, other): + return self.__hash__() != other.__hash__() + + def serialize(self): + if self.query_graph is None: + raise Exception("Cannot serialize query loaded with query graph!") + return (self.query_graph, self.neg_samples, self.hard_neg_samples) + + @staticmethod + def deserialize(serial_info, keep_graph=False): + return Query(serial_info[0], serial_info[1], serial_info[2], None if serial_info[1] is None else len(serial_info[1]), keep_graph=keep_graph) + + + +class Graph(): + """ + Simple container for heteregeneous graph data. + """ + def __init__(self, features, feature_dims, relations, adj_lists): + self.features = features + self.feature_dims = feature_dims + self.relations = relations + self.adj_lists = adj_lists + self.full_sets = defaultdict(set) + self.full_lists = {} + self.meta_neighs = defaultdict(dict) + for rel, adjs in self.adj_lists.items(): + full_set = set(self.adj_lists[rel].keys()) + self.full_sets[rel[0]] = self.full_sets[rel[0]].union(full_set) + for mode, full_set in self.full_sets.items(): + self.full_lists[mode] = list(full_set) + self._cache_edge_counts() + self._make_flat_adj_lists() + + def _make_flat_adj_lists(self): + self.flat_adj_lists = defaultdict(lambda : defaultdict(list)) + for rel, adjs in self.adj_lists.items(): + for node, neighs in adjs.items(): + self.flat_adj_lists[rel[0]][node].extend([(rel, neigh) for neigh in neighs]) + + def _cache_edge_counts(self): + self.edges = 0. + self.rel_edges = {} + for r1 in self.relations: + for r2 in self.relations[r1]: + rel = (r1,r2[1], r2[0]) + self.rel_edges[rel] = 0. + for adj_list in self.adj_lists[rel].values(): + self.rel_edges[rel] += len(adj_list) + self.edges += 1. 
+ self.rel_weights = OrderedDict() + self.mode_edges = defaultdict(float) + self.mode_weights = OrderedDict() + for rel, edge_count in self.rel_edges.items(): + self.rel_weights[rel] = edge_count / self.edges + self.mode_edges[rel[0]] += edge_count + for mode, edge_count in self.mode_edges.items(): + self.mode_weights[mode] = edge_count / self.edges + + def remove_edges(self, edge_list): + for edge in edge_list: + try: + self.adj_lists[edge[1]][edge[0]].remove(edge[-1]) + except Exception: + continue + + try: + self.adj_lists[_reverse_relation(edge[1])][edge[-1]].remove(edge[0]) + except Exception: + continue + self.meta_neighs = defaultdict(dict) + self._cache_edge_counts() + self._make_flat_adj_lists() + + def get_all_edges(self, seed=0, exclude_rels=set([])): + """ + Returns all edges in the form (node1, relation, node2) + """ + edges = [] + random.seed(seed) + for rel, adjs in self.adj_lists.items(): + if rel in exclude_rels: + continue + for node, neighs in adjs.items(): + edges.extend([(node, rel, neigh) for neigh in neighs if neigh != -1]) + random.shuffle(edges) + return edges + + def get_all_edges_byrel(self, seed=0, + exclude_rels=set([])): + random.seed(seed) + edges = defaultdict(list) + for rel, adjs in self.adj_lists.items(): + if rel in exclude_rels: + continue + for node, neighs in adjs.items(): + edges[(rel,)].extend([(node, neigh) for neigh in neighs if neigh != -1]) + + def get_negative_edge_samples(self, edge, num, rejection_sample=True): + if rejection_sample: + neg_nodes = set([]) + counter = 0 + while len(neg_nodes) < num: + neg_node = random.choice(self.full_lists[edge[1][0]]) + if not neg_node in self.adj_lists[_reverse_relation(edge[1])][edge[2]]: + neg_nodes.add(neg_node) + counter += 1 + if counter > 100*num: + return self.get_negative_edge_samples(edge, num, rejection_sample=False) + else: + neg_nodes = self.full_sets[edge[1][0]] - self.adj_lists[_reverse_relation(edge[1])][edge[2]] + neg_nodes = list(neg_nodes) if len(neg_nodes) <= num else random.sample(list(neg_nodes), num) + return neg_nodes + + def sample_test_queries(self, train_graph, q_types, samples_per_type, neg_sample_max, verbose=True): + queries = [] + for q_type in q_types: + sampled = 0 + while sampled < samples_per_type: + q = self.sample_query_subgraph_bytype(q_type) + if q is None or not train_graph._is_negative(q, q[1][0], False): + continue + negs, hard_negs = self.get_negative_samples(q) + if negs is None or ("inter" in q[0] and hard_negs is None): + continue + query = Query(q, negs, hard_negs, neg_sample_max=neg_sample_max, keep_graph=True) + queries.append(query) + sampled += 1 + if sampled % 1000 == 0 and verbose: + print("Sampled", sampled) + return queries + + def sample_queries(self, arity, num_samples, neg_sample_max, verbose=True): + sampled = 0 + queries = [] + while sampled < num_samples: + q = self.sample_query_subgraph(arity) + if q is None: + continue + negs, hard_negs = self.get_negative_samples(q) + if negs is None or ("inter" in q[0] and hard_negs is None): + continue + query = Query(q, negs, hard_negs, neg_sample_max=neg_sample_max, keep_graph=True) + queries.append(query) + sampled += 1 + if sampled % 1000 == 0 and verbose: + print("Sampled", sampled) + return queries + + + def get_negative_samples(self, query): + if query[0] == "3-chain" or query[0] == "2-chain": + edges = query[1:] + rels = [_reverse_relation(edge[1]) for edge in edges[::-1]] + meta_neighs = self.get_metapath_neighs(query[-1][-1], tuple(rels)) + negative_samples = self.full_sets[query[1][1][0]] - 
meta_neighs + if len(negative_samples) == 0: + return None, None + else: + return negative_samples, None + elif query[0] == "2-inter" or query[0] == "3-inter": + rel_1 = _reverse_relation(query[1][1]) + union_neighs = self.adj_lists[rel_1][query[1][-1]] + inter_neighs = self.adj_lists[rel_1][query[1][-1]] + for i in range(2,len(query)): + rel = _reverse_relation(query[i][1]) + union_neighs = union_neighs.union(self.adj_lists[rel][query[i][-1]]) + inter_neighs = inter_neighs.intersection(self.adj_lists[rel][query[i][-1]]) + neg_samples = self.full_sets[query[1][1][0]] - inter_neighs + hard_neg_samples = union_neighs - inter_neighs + if len(neg_samples) == 0 or len(hard_neg_samples) == 0: + return None, None + return neg_samples, hard_neg_samples + elif query[0] == "3-inter_chain": + rel_1 = _reverse_relation(query[1][1]) + union_neighs = self.adj_lists[rel_1][query[1][-1]] + inter_neighs = self.adj_lists[rel_1][query[1][-1]] + chain_rels = [_reverse_relation(edge[1]) for edge in query[2][::-1]] + chain_neighs = self.get_metapath_neighs(query[2][-1][-1], tuple(chain_rels)) + union_neighs = union_neighs.union(chain_neighs) + inter_neighs = inter_neighs.intersection(chain_neighs) + neg_samples = self.full_sets[query[1][1][0]] - inter_neighs + hard_neg_samples = union_neighs - inter_neighs + if len(neg_samples) == 0 or len(hard_neg_samples) == 0: + return None, None + return neg_samples, hard_neg_samples + elif query[0] == "3-chain_inter": + inter_rel_1 = _reverse_relation(query[-1][0][1]) + inter_neighs_1 = self.adj_lists[inter_rel_1][query[-1][0][-1]] + inter_rel_2 = _reverse_relation(query[-1][1][1]) + inter_neighs_2 = self.adj_lists[inter_rel_2][query[-1][1][-1]] + + inter_neighs = inter_neighs_1.intersection(inter_neighs_2) + union_neighs = inter_neighs_1.union(inter_neighs_2) + rel = _reverse_relation(query[1][1]) + pos_nodes = set([n for neigh in inter_neighs for n in self.adj_lists[rel][neigh]]) + union_pos_nodes = set([n for neigh in union_neighs for n in self.adj_lists[rel][neigh]]) + neg_samples = self.full_sets[query[1][1][0]] - pos_nodes + hard_neg_samples = union_pos_nodes - pos_nodes + if len(neg_samples) == 0 or len(hard_neg_samples) == 0: + return None, None + return neg_samples, hard_neg_samples + + def sample_edge(self, node, mode): + rel, neigh = random.choice(self.flat_adj_lists[mode][node]) + edge = (node, rel, neigh) + return edge + + def sample_query_subgraph_bytype(self, q_type, start_node=None): + if start_node is None: + start_rel = random.choice(self.adj_lists.keys()) + node = random.choice(self.adj_lists[start_rel].keys()) + mode = start_rel[0] + else: + node, mode = start_node + + if q_type[0] == "3": + if q_type == "3-chain" or q_type == "3-chain_inter": + num_edges = 1 + elif q_type == "3-inter_chain": + num_edges = 2 + elif q_type == "3-inter": + num_edges = 3 + if num_edges > len(self.flat_adj_lists[mode][node]): + return None + if num_edges == 1: + rel, neigh = random.choice(self.flat_adj_lists[mode][node]) + edge = (node, rel, neigh) + next_query = self.sample_query_subgraph_bytype( + "2-chain" if q_type == "3-chain" else "2-inter", start_node=(neigh, rel[0])) + if next_query is None: + return None + if next_query[0] == "2-chain": + return ("3-chain", edge, next_query[1], next_query[2]) + else: + return ("3-chain_inter", edge, (next_query[1], next_query[2])) + elif num_edges == 2: + rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) + edge_1 = (node, rel_1, neigh_1) + neigh_2 = neigh_1 + rel_2 = rel_1 + while (neigh_1, rel_1) == (neigh_2, rel_2): 
+ rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) + edge_2 = (node, rel_2, neigh_2) + return ("3-inter_chain", edge_1, (edge_2, self.sample_edge(neigh_2, rel_2[-1]))) + elif num_edges == 3: + rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) + edge_1 = (node, rel_1, neigh_1) + neigh_2 = neigh_1 + rel_2 = rel_1 + while (rel_1, neigh_1) == (rel_2, neigh_2): + rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) + edge_2 = (node, rel_2, neigh_2) + neigh_3 = neigh_1 + rel_3 = rel_1 + while ((rel_1, neigh_1) == (rel_3, neigh_3)) or ((rel_2, neigh_2) == (rel_3, neigh_3)): + rel_3, neigh_3 = random.choice(self.flat_adj_lists[mode][node]) + edge_3 = (node, rel_3, neigh_3) + return ("3-inter", edge_1, edge_2, edge_3) + + if q_type[0] == "2": + num_edges = 1 if q_type == "2-chain" else 2 + if num_edges > len(self.flat_adj_lists[mode][node]): + return None + if num_edges == 1: + rel, neigh = random.choice(self.flat_adj_lists[mode][node]) + edge = (node, rel, neigh) + return ("2-chain", edge, self.sample_edge(neigh, rel[-1])) + elif num_edges == 2: + rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) + edge_1 = (node, rel_1, neigh_1) + neigh_2 = neigh_1 + rel_2 = rel_1 + while (neigh_1, rel_1) == (neigh_2, rel_2): + rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) + edge_2 = (node, rel_2, neigh_2) + return ("2-inter", edge_1, edge_2) + + + def sample_query_subgraph(self, arity, start_node=None): + if start_node is None: + start_rel = random.choice(self.adj_lists.keys()) + node = random.choice(self.adj_lists[start_rel].keys()) + mode = start_rel[0] + else: + node, mode = start_node + if arity > 3 or arity < 2: + raise Exception("Only arity of at most 3 is supported for queries") + + if arity == 3: + # 1/2 prob of 1 edge, 1/4 prob of 2, 1/4 prob of 3 + num_edges = random.choice([1,1,2,3]) + if num_edges > len(self.flat_adj_lists[mode][node]): + return None + if num_edges == 1: + rel, neigh = random.choice(self.flat_adj_lists[mode][node]) + edge = (node, rel, neigh) + next_query = self.sample_query_subgraph(2, start_node=(neigh, rel[0])) + if next_query is None: + return None + if next_query[0] == "2-chain": + return ("3-chain", edge, next_query[1], next_query[2]) + else: + return ("3-chain_inter", edge, (next_query[1], next_query[2])) + elif num_edges == 2: + rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) + edge_1 = (node, rel_1, neigh_1) + neigh_2 = neigh_1 + rel_2 = rel_1 + while (neigh_1, rel_1) == (neigh_2, rel_2): + rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) + edge_2 = (node, rel_2, neigh_2) + return ("3-inter_chain", edge_1, (edge_2, self.sample_edge(neigh_2, rel_2[-1]))) + elif num_edges == 3: + rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) + edge_1 = (node, rel_1, neigh_1) + neigh_2 = neigh_1 + rel_2 = rel_1 + while (rel_1, neigh_1) == (rel_2, neigh_2): + rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) + edge_2 = (node, rel_2, neigh_2) + neigh_3 = neigh_1 + rel_3 = rel_1 + while ((rel_1, neigh_1) == (rel_3, neigh_3)) or ((rel_2, neigh_2) == (rel_3, neigh_3)): + rel_3, neigh_3 = random.choice(self.flat_adj_lists[mode][node]) + edge_3 = (node, rel_3, neigh_3) + return ("3-inter", edge_1, edge_2, edge_3) + + if arity == 2: + num_edges = random.choice([1,2]) + if num_edges > len(self.flat_adj_lists[mode][node]): + return None + if num_edges == 1: + rel, neigh = random.choice(self.flat_adj_lists[mode][node]) + edge = (node, rel, neigh) + return ("2-chain", edge, 
self.sample_edge(neigh, rel[-1])) + elif num_edges == 2: + rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) + edge_1 = (node, rel_1, neigh_1) + neigh_2 = neigh_1 + rel_2 = rel_1 + while (neigh_1, rel_1) == (neigh_2, rel_2): + rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) + edge_2 = (node, rel_2, neigh_2) + return ("2-inter", edge_1, edge_2) + + def get_metapath_neighs(self, node, rels): + if node in self.meta_neighs[rels]: + return self.meta_neighs[rels][node] + current_set = [node] + for rel in rels: + current_set = set([neigh for n in current_set for neigh in self.adj_lists[rel][n]]) + self.meta_neighs[rels][node] = current_set + return current_set + + ## TESTING CODE + + def _check_edge(self, query, i): + return query[i][-1] in self.adj_lists[query[i][1]][query[i][0]] + + def _is_subgraph(self, query, verbose): + if query[0] == "3-chain": + for i in range(3): + if not self._check_edge(query, i+1): + raise Exception(str(query)) + if not (query[1][-1] == query[2][0] and query[2][-1] == query[3][0]): + raise Exception(str(query)) + if query[0] == "2-chain": + for i in range(2): + if not self._check_edge(query, i+1): + raise Exception(str(query)) + if not query[1][-1] == query[2][0]: + raise Exception(str(query)) + if query[0] == "2-inter": + for i in range(2): + if not self._check_edge(query, i+1): + raise Exception(str(query)) + if not query[1][0] == query[2][0]: + raise Exception(str(query)) + if query[0] == "3-inter": + for i in range(3): + if not self._check_edge(query, i+1): + raise Exception(str(query)) + if not (query[1][0] == query[2][0] and query[2][0] == query[3][0]): + raise Exception(str(query)) + if query[0] == "3-inter_chain": + if not (self._check_edge(query, 1) and self._check_edge(query[2], 0) and self._check_edge(query[2], 1)): + raise Exception(str(query)) + if not (query[1][0] == query[2][0][0] and query[2][0][-1] == query[2][1][0]): + raise Exception(str(query)) + if query[0] == "3-chain_inter": + if not (self._check_edge(query, 1) and self._check_edge(query[2], 0) and self._check_edge(query[2], 1)): + raise Exception(str(query)) + if not (query[1][-1] == query[2][0][0] and query[2][0][0] == query[2][1][0]): + raise Exception(str(query)) + return True + + def _is_negative(self, query, neg_node, is_hard): + if query[0] == "2-chain": + query = (query[0], (neg_node, query[1][1], query[1][2]), query[2]) + if query[2][-1] in self.get_metapath_neighs(query[1][0], (query[1][1], query[2][1])): + return False + if query[0] == "3-chain": + query = (query[0], (neg_node, query[1][1], query[1][2]), query[2], query[3]) + if query[3][-1] in self.get_metapath_neighs(query[1][0], (query[1][1], query[2][1], query[3][1])): + return False + if query[0] == "2-inter": + query = (query[0], (neg_node, query[1][1], query[1][2]), (neg_node, query[2][1], query[2][2])) + if not is_hard: + if self._check_edge(query, 1) and self._check_edge(query, 2): + return False + else: + if (self._check_edge(query, 1) and self._check_edge(query, 2)) or not (self._check_edge(query, 1) or self._check_edge(query, 2)): + return False + if query[0] == "3-inter": + query = (query[0], (neg_node, query[1][1], query[1][2]), (neg_node, query[2][1], query[2][2]), (neg_node, query[3][1], query[3][2])) + if not is_hard: + if self._check_edge(query, 1) and self._check_edge(query, 2) and self._check_edge(query, 3): + return False + else: + if (self._check_edge(query, 1) and self._check_edge(query, 2) and self._check_edge(query, 3))\ + or not (self._check_edge(query, 1) or 
self._check_edge(query, 2) or self._check_edge(query, 3)): + return False + if query[0] == "3-inter_chain": + query = (query[0], (neg_node, query[1][1], query[1][2]), ((neg_node, query[2][0][1], query[2][0][2]), query[2][1])) + meta_check = lambda : query[2][-1][-1] in self.get_metapath_neighs(query[1][0], (query[2][0][1], query[2][1][1])) + neigh_check = lambda : self._check_edge(query, 1) + if not is_hard: + if meta_check() and neigh_check(): + return False + else: + if (meta_check() and neigh_check()) or not (meta_check() or neigh_check()): + return False + if query[0] == "3-chain_inter": + query = (query[0], (neg_node, query[1][1], query[1][2]), query[2]) + target_neigh = self.adj_lists[query[1][1]][neg_node] + neigh_1 = self.adj_lists[_reverse_relation(query[2][0][1])][query[2][0][-1]] + neigh_2 = self.adj_lists[_reverse_relation(query[2][1][1])][query[2][1][-1]] + if not is_hard: + if target_neigh in neigh_1.intersection(neigh_2): + return False + else: + if target_neigh in neigh_1.intersection(neigh_2) and not target_neigh in neigh_1.union(neigh_2): + return False + return True + + + + def _run_test(self, num_samples=1000): + for i in range(num_samples): + q = self.sample_query_subgraph(2) + if q is None: + continue + self._is_subgraph(q, True) + negs, hard_negs = self.get_negative_samples(q) + if not negs is None: + for n in negs: + self._is_negative(q, n, False) + if not hard_negs is None: + for n in hard_negs: + self._is_negative(q, n, True) + q = self.sample_query_subgraph(3) + if q is None: + continue + self._is_subgraph(q, True) + negs, hard_negs = self.get_negative_samples(q) + if not negs is None: + for n in negs: + self._is_negative(q, n, False) + if not hard_negs is None: + for n in hard_negs: + self._is_negative(q, n, True) + return True + + + """ + TO DELETE? 
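    (Note: everything from here to the closing triple quotes sits inside a string literal, so none of these samplers is executed.)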
+ def sample_chain_from_node(self, length, node, rel): + rels = [rel] + for cur_len in range(length-1): + next_rel = random.choice(self.relations[rels[-1][-1]]) + rels.append((rels[-1][-1], next_rel[-1], next_rel[0])) + + rels = tuple(rels) + meta_neighs = self.get_metapath_neighs(node, rels) + rev_rel = _reverse_relation(rels[-1]) + full_set = self.full_sets[rev_rel] + diff_set = full_set - meta_neighs + if len(meta_neighs) == 0 or len(diff_set) == 0: + return None, None, None + chain = (node, random.choice(list(meta_neighs))) + neg_chain = (node, random.choice(list(diff_set))) + return chain, neg_chain, rels + + def sample_chain(self, length, start_mode): + rel = random.choice(self.relations[start_mode]) + rel = (start_mode, rel[-1], rel[0]) + if len(self.adj_lists[rel]) == 0: + return None, None, None + node = random.choice(self.adj_lists[rel].keys()) + return self.sample_chain_from_node(length, node, rel) + + def sample_chains(self, length, anchor_weights, num_samples): + sampled = 0 + graph_chains = defaultdict(list) + neg_chains = defaultdict(list) + while sampled < num_samples: + anchor_mode = anchor_weights.keys()[np.argmax(np.random.multinomial(1, anchor_weights.values()))] + chain, neg_chain, rels = self.sample_chain(length, anchor_mode) + if chain is None: + continue + graph_chains[rels].append(chain) + neg_chains[rels].append(neg_chain) + sampled += 1 + return graph_chains, neg_chains + + + def sample_polytree_rootinter(self, length, target_mode, try_out=100): + num_chains = random.randint(2,length) + added = 0 + nodes = [] + rels_list = [] + + for i in range(num_chains): + remaining = length-added-num_chains + if i != num_chains - 1: + remaining = remaining if remaining == 0 else random.randint(0, remaining) + added += remaining + chain_len = 1 + remaining + if i == 0: + chain, _, rels = self.sample_chain(chain_len, target_mode) + try_count = 0 + while chain is None and try_count <= try_out: + chain, _, rels = self.sample_chain(chain_len, target_mode) + try_count += 1 + + if chain is None: + return None, None, None, None, None + target_node = chain[0] + nodes.append(chain[-1]) + rels_list.append(tuple([_reverse_relation(rel) for rel in rels[::-1]])) + else: + rel = random.choice([r for r in self.relations[target_mode] + if len(self.adj_lists[(target_mode, r[-1], r[0])][target_node]) > 0]) + rel = (target_mode, rel[-1], rel[0]) + chain, _, rels = self.sample_chain_from_node(chain_len, target_node, rel) + try_count = 0 + while chain is None and try_count <= try_out: + chain, _, rels = self.sample_chain_from_node(chain_len, target_node, rel) + if chain is None: + try_count += 1 + elif chain[-1] in nodes: + chain = None + if chain is None: + return None, None, None, None, None + nodes.append(chain[-1]) + rels_list.append(tuple([_reverse_relation(rel) for rel in rels[::-1]])) + + for i in range(len(nodes)): + meta_neighs = self.get_metapath_neighs(nodes[i], rels_list[i]) + if i == 0: + meta_neighs_inter = meta_neighs + meta_neighs_union = meta_neighs + else: + meta_neighs_inter = meta_neighs_inter.intersection(meta_neighs) + meta_neighs_union = meta_neighs_union.union(meta_neighs) + hard_neg_nodes = list(meta_neighs_union-meta_neighs_inter) + neg_nodes = list(self.full_sets[rels[0]]-meta_neighs_inter) + if len(neg_nodes) == 0: + return None, None, None, None, None + if len(hard_neg_nodes) == 0: + return None, None, None, None, None + + return target_node, neg_nodes, hard_neg_nodes, tuple(nodes), tuple(rels_list) + + + def sample_polytrees_parallel(self, length, thread_samples, 
threads, try_out=100): + pool = Pool(threads) + sample_func = partial(self.sample_polytree, length) + sizes = [thread_samples for _ in range(threads)] + results = pool.map(sample_func, sizes) + polytrees = {} + neg_polytrees = {} + hard_neg_polytrees = {} + for p, n, h in results: + polytrees.update(p) + neg_polytrees.update(n) + hard_neg_polytrees.updarte(h) + return polytrees, neg_polytrees, hard_neg_polytrees + + def sample_polytrees(self, length, num_samples, try_out=1): + samples = 0 + polytrees = defaultdict(list) + neg_polytrees = defaultdict(list) + hard_neg_polytrees = defaultdict(list) + while samples < num_samples: + t, n, h_n, nodes, rels = self.sample_polytree(length, random.choice(self.relations.keys())) + if t is None: + continue + samples += 1 + polytrees[rels].append((t, nodes)) + neg_polytrees[rels].append((n, nodes)) + hard_neg_polytrees[rels].append((h_n, nodes)) + return polytrees, neg_polytrees, hard_neg_polytrees + + """ diff --git a/model.py b/model.py new file mode 100644 index 0000000..216a0cb --- /dev/null +++ b/model.py @@ -0,0 +1,189 @@ +import torch +import torch.nn as nn +import numpy as np + +import random +from netquery.graph import _reverse_relation + +EPS = 10e-6 + +""" +End-to-end autoencoder models for representation learning on +heteregenous graphs/networks +""" + +class MetapathEncoderDecoder(nn.Module): + """ + Encoder decoder model that reasons over metapaths + """ + + def __init__(self, graph, enc, dec): + """ + graph -- simple graph object; see graph.py + enc --- an encoder module that generates embeddings (see encoders.py) + dec --- an decoder module that predicts compositional relationships, i.e. metapaths, between nodes given embeddings. (see decoders.py) + Note that the decoder must be an *compositional/metapath* decoder (i.e., with name Metapath*.py) + """ + super(MetapathEncoderDecoder, self).__init__() + self.enc = enc + self.dec = dec + self.graph = graph + + def forward(self, nodes1, nodes2, rels): + """ + Returns a vector of 'relationship scores' for pairs of nodes being connected by the given metapath (sequence of relations). + Essentially, the returned scores are the predicted likelihood of the node pairs being connected + by the given metapath, where the pairs are given by the ordering in nodes1 and nodes2, + i.e. the first node id in nodes1 is paired with the first node id in nodes2. + """ + return self.dec.forward(self.enc.forward(nodes1, rels[0][0]), + self.enc.forward(nodes2, rels[-1][-1]), + rels) + + def margin_loss(self, nodes1, nodes2, rels): + """ + Standard max-margin based loss function. + Maximizes relationaship scores for true pairs vs negative samples. 
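        Concretely: loss = mean(max(0, 1 - (pos_score - neg_score))) over the batch.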
+ """ + affs = self.forward(nodes1, nodes2, rels) + neg_nodes = [random.randint(1,len(self.graph.adj_lists[_reverse_relation[rels[-1]]])-1) for _ in range(len(nodes1))] + neg_affs = self.forward(nodes1, neg_nodes, + rels) + margin = 1 - (affs - neg_affs) + margin = torch.clamp(margin, min=0) + loss = margin.mean() + return loss + +class QueryEncoderDecoder(nn.Module): + """ + Encoder decoder model that reasons about edges, metapaths and intersections + """ + + def __init__(self, graph, enc, path_dec, inter_dec): + super(QueryEncoderDecoder, self).__init__() + self.enc = enc + self.path_dec = path_dec + self.inter_dec = inter_dec + self.graph = graph + self.cos = nn.CosineSimilarity(dim=0) + + def forward(self, formula, queries, source_nodes): + if formula.query_type == "1-chain" or formula.query_type == "2-chain" or formula.query_type == "3-chain": + # a chain is simply a call to the path decoder + return self.path_dec.forward( + self.enc.forward(source_nodes, formula.target_mode), + self.enc.forward([query.anchor_nodes[0] for query in queries], formula.anchor_modes[0]), + formula.rels) + elif formula.query_type == "2-inter" or formula.query_type == "3-inter" or formula.query_type == "3-inter_chain": + target_embeds = self.enc(source_nodes, formula.target_mode) + + embeds1 = self.enc([query.anchor_nodes[0] for query in queries], formula.anchor_modes[0]) + embeds1 = self.path_dec.project(embeds1, _reverse_relation(formula.rels[0])) + + embeds2 = self.enc([query.anchor_nodes[1] for query in queries], formula.anchor_modes[1]) + if len(formula.rels[1]) == 2: + for i_rel in formula.rels[1][::-1]: + embeds2 = self.path_dec.project(embeds2, _reverse_relation(i_rel)) + else: + embeds2 = self.path_dec.project(embeds2, _reverse_relation(formula.rels[1])) + + if formula.query_type == "3-inter": + embeds3 = self.enc([query.anchor_nodes[2] for query in queries], formula.anchor_modes[2]) + embeds3 = self.path_dec.project(embeds3, _reverse_relation(formula.rels[2])) + + query_intersection = self.inter_dec(embeds1, embeds2, formula.target_mode, embeds3) + else: + query_intersection = self.inter_dec(embeds1, embeds2, formula.target_mode) + scores = self.cos(target_embeds, query_intersection) + return scores + elif formula.query_type == "3-chain_inter": + target_embeds = self.enc(source_nodes, formula.target_mode) + + embeds1 = self.enc([query.anchor_nodes[0] for query in queries], formula.anchor_modes[0]) + embeds1 = self.path_dec.project(embeds1, _reverse_relation(formula.rels[1][0])) + embeds2 = self.enc([query.anchor_nodes[1] for query in queries], formula.anchor_modes[1]) + embeds2 = self.path_dec.project(embeds2, _reverse_relation(formula.rels[1][1])) + query_intersection = self.inter_dec(embeds1, embeds2, formula.rels[0][-1]) + query_intersection = self.path_dec.project(query_intersection, _reverse_relation(formula.rels[0])) + scores = self.cos(target_embeds, query_intersection) + return scores + + + def margin_loss(self, formula, queries, hard_negatives=False, margin=1): + if not "inter" in formula.query_type and hard_negatives: + raise Exception("Hard negative examples can only be used with intersection queries") + elif hard_negatives: + neg_nodes = [random.choice(query.hard_neg_samples) for query in queries] + elif formula.query_type == "1-chain": + neg_nodes = [random.choice(self.graph.full_lists[formula.target_mode]) for _ in queries] + else: + neg_nodes = [random.choice(query.neg_samples) for query in queries] + + affs = self.forward(formula, queries, [query.target_node for query in queries]) + 
neg_affs = self.forward(formula, queries, neg_nodes) + loss = margin - (affs - neg_affs) + loss = torch.clamp(loss, min=0) + loss = loss.mean() + return loss + +class SoftAndEncoderDecoder(nn.Module): + """ + Encoder decoder model that reasons about edges, metapaths and intersections + """ + + def __init__(self, graph, enc, path_dec): + super(SoftAndEncoderDecoder, self).__init__() + self.enc = enc + self.path_dec = path_dec + self.graph = graph + self.cos = nn.CosineSimilarity(dim=0) + + def forward(self, formula, queries, source_nodes): + if formula.query_type == "1-chain": + # a chain is simply a call to the path decoder + return self.path_dec.forward( + self.enc.forward(source_nodes, formula.target_mode), + self.enc.forward([query.anchor_nodes[0] for query in queries], formula.anchor_modes[0]), + formula.rels) + elif formula.query_type == "2-inter" or formula.query_type == "3-inter": + target_embeds = self.enc(source_nodes, formula.target_mode) + + embeds1 = self.enc([query.anchor_nodes[0] for query in queries], formula.anchor_modes[0]) + embeds1 = self.path_dec.project(embeds1, _reverse_relation(formula.rels[0])) + + embeds2 = self.enc([query.anchor_nodes[1] for query in queries], formula.anchor_modes[1]) + if len(formula.rels[1]) == 2: + for i_rel in formula.rels[1][::-1]: + embeds2 = self.path_dec.project(embeds2, _reverse_relation(i_rel)) + else: + embeds2 = self.path_dec.project(embeds2, _reverse_relation(formula.rels[1])) + + scores1 = self.cos(target_embeds, embeds1) + scores2 = self.cos(target_embeds, embeds2) + if formula.query_type == "3-inter": + embeds3 = self.enc([query.anchor_nodes[2] for query in queries], formula.anchor_modes[2]) + embeds3 = self.path_dec.project(embeds3, _reverse_relation(formula.rels[2])) + scores3 = self.cos(target_embeds, embeds2) + scores = scores1 * scores2 * scores3 + else: + scores = scores1 * scores2 + return scores + else: + raise Exception("Query type not supported for this model.") + + def margin_loss(self, formula, queries, hard_negatives=False, margin=1): + if not "inter" in formula.query_type and hard_negatives: + raise Exception("Hard negative examples can only be used with intersection queries") + elif hard_negatives: + neg_nodes = [random.choice(query.hard_neg_samples) for query in queries] + elif formula.query_type == "1-chain": + neg_nodes = [random.choice(self.graph.full_lists[formula.target_mode]) for _ in queries] + else: + neg_nodes = [random.choice(query.neg_samples) for query in queries] + + affs = self.forward(formula, queries, [query.target_node for query in queries]) + neg_affs = self.forward(formula, queries, neg_nodes) + loss = margin - (affs - neg_affs) + loss = torch.clamp(loss, min=0) + loss = loss.mean() + return loss diff --git a/train.py b/train.py new file mode 100644 index 0000000..808f695 --- /dev/null +++ b/train.py @@ -0,0 +1,81 @@ +from argparse import ArgumentParser + +from netquery.utils import * +from netquery.bio.data_utils import load_graph +from netquery.data_utils import load_queries_by_formula, load_test_queries_by_formula +from netquery.model import QueryEncoderDecoder +from netquery.train_helpers import run_train + +from torch import optim + +parser = ArgumentParser() +parser.add_argument("--embed_dim", type=int, default=128) +parser.add_argument("--data_dir", type=str, default="./bio_data/") +parser.add_argument("--lr", type=float, default=0.01) +parser.add_argument("--depth", type=int, default=0) +parser.add_argument("--batch_size", type=int, default=512) +parser.add_argument("--max_iter", type=int, 
default=100000000) +parser.add_argument("--max_burn_in", type=int, default=1000000) +parser.add_argument("--val_every", type=int, default=5000) +parser.add_argument("--tol", type=float, default=0.0001) +parser.add_argument("--cuda", action='store_true') +parser.add_argument("--log_dir", type=str, default="./") +parser.add_argument("--model_dir", type=str, default="./") +parser.add_argument("--decoder", type=str, default="bilinear") +parser.add_argument("--inter_decoder", type=str, default="mean") +parser.add_argument("--opt", type=str, default="adam") +args = parser.parse_args() + +print("Loading graph data..") +graph, feature_modules, node_maps = load_graph(args.data_dir, args.embed_dim) +if args.cuda: + graph.features = cudify(feature_modules, node_maps) +out_dims = {mode:args.embed_dim for mode in graph.relations} + +print("Loading edge data..") +train_queries = load_queries_by_formula(args.data_dir + "/train_edges.pkl") +val_queries = load_test_queries_by_formula(args.data_dir + "/val_edges.pkl") +test_queries = load_test_queries_by_formula(args.data_dir + "/test_edges.pkl") + +print("Loading query data..") +for i in range(2,4): + train_queries.update(load_queries_by_formula(args.data_dir + "/train_queries_{:d}.pkl".format(i))) + i_val_queries = load_test_queries_by_formula(args.data_dir + "/val_queries_{:d}.pkl".format(i)) + val_queries["one_neg"].update(i_val_queries["one_neg"]) + val_queries["full_neg"].update(i_val_queries["full_neg"]) + i_test_queries = load_test_queries_by_formula(args.data_dir + "/test_queries_{:d}.pkl".format(i)) + test_queries["one_neg"].update(i_test_queries["one_neg"]) + test_queries["full_neg"].update(i_test_queries["full_neg"]) + + +enc = get_encoder(args.depth, graph, out_dims, feature_modules, args.cuda) +dec = get_metapath_decoder(graph, enc.out_dims if args.depth > 0 else out_dims, args.decoder) +inter_dec = get_intersection_decoder(graph, out_dims, args.inter_decoder) + +enc_dec = QueryEncoderDecoder(graph, enc, dec, inter_dec) +if args.cuda: + enc_dec.cuda() + +if args.opt == "sgd": + optimizer = optim.SGD(filter(lambda p : p.requires_grad, enc_dec.parameters()), lr=args.lr, momentum=0) +elif args.opt == "adam": + optimizer = optim.Adam(filter(lambda p : p.requires_grad, enc_dec.parameters()), lr=args.lr) + +log_file = args.log_dir + "/{data:s}-{depth:d}-{embed_dim:d}-{lr:f}-{decoder:s}-{inter_decoder:s}.log".format( + data=args.data_dir.strip().split("/")[-1], + depth=args.depth, + embed_dim=args.embed_dim, + lr=args.lr, + decoder=args.decoder, + inter_decoder=args.inter_decoder) +model_file = args.model_dir + "/{data:s}-{depth:d}-{embed_dim:d}-{lr:f}-{decoder:s}-{inter_decoder:s}.log".format( + data=args.data_dir.strip().split("/")[-1], + depth=args.depth, + embed_dim=args.embed_dim, + lr=args.lr, + decoder=args.decoder, + inter_decoder=args.inter_decoder) +logger = setup_logging(log_file) + +run_train(enc_dec, optimizer, train_queries, val_queries, test_queries, logger, max_burn_in=args.max_burn_in, val_every=args.val_every, model_file=model_file) +torch.save(enc_dec.state_dict(), model_file) diff --git a/train_helpers.py b/train_helpers.py new file mode 100644 index 0000000..690a710 --- /dev/null +++ b/train_helpers.py @@ -0,0 +1,107 @@ +import numpy as np +from netquery.utils import eval_auc_queries, eval_perc_queries +import torch + +def check_conv(vals, window=2, tol=1e-6): + if len(vals) < 2 * window: + return False + conv = np.mean(vals[-window:]) - np.mean(vals[-2*window:-window]) + return conv < tol + +def update_loss(loss, losses, 
ema_loss, ema_alpha=0.01): + losses.append(loss) + if ema_loss is None: + ema_loss = loss + else: + ema_loss = (1-ema_alpha)*ema_loss + ema_alpha*loss + return losses, ema_loss + +def run_eval(model, queries, iteration, logger, by_type=False): + vals = {} + def _print_by_rel(rel_aucs, logger): + for rels, auc in rel_aucs.items(): + logger.info(str(rels) + "\t" + str(auc)) + for query_type in queries["one_neg"]: + auc, rel_aucs = eval_auc_queries(queries["one_neg"][query_type], model) + perc = eval_perc_queries(queries["full_neg"][query_type], model) + vals[query_type] = auc + logger.info("{:s} val AUC: {:f} val perc {:f}; iteration: {:d}".format(query_type, auc, perc, iteration)) + if by_type: + _print_by_rel(rel_aucs, logger) + if "inter" in query_type: + auc, rel_aucs = eval_auc_queries(queries["one_neg"][query_type], model, hard_negatives=True) + perc = eval_perc_queries(queries["full_neg"][query_type], model, hard_negatives=True) + logger.info("Hard-{:s} val AUC: {:f} val perc {:f}; iteration: {:d}".format(query_type, auc, perc, iteration)) + if by_type: + _print_by_rel(rel_aucs, logger) + vals[query_type + "hard"] = auc + return vals + +def run_train(model, optimizer, train_queries, val_queries, test_queries, logger, + max_burn_in=100000, batch_size=512, log_every=100, val_every=1000, tol=1e-6, + max_iter=int(10e7), inter_weight=0.005, path_weight=0.01, model_file=None): + edge_conv = False + ema_loss = None + vals = [] + losses = [] + conv_test = None + for i in range(max_iter): + + optimizer.zero_grad() + loss = run_batch(train_queries["1-chain"], model, i, batch_size) + if not edge_conv and (check_conv(vals) or len(losses) >= max_burn_in): + logger.info("Edge converged at iteration {:d}".format(i-1)) + logger.info("Testing at edge conv...") + conv_test = run_eval(model, test_queries, i, logger) + conv_test = np.mean(list(conv_test.values())) + edge_conv = True + losses = [] + ema_loss = None + vals = [] + if not model_file is None: + torch.save(model.state_dict(), model_file+"-edge_conv") + + if edge_conv: + for query_type in train_queries: + if query_type == "1-chain": + continue + if "inter" in query_type: + loss += inter_weight*run_batch(train_queries[query_type], model, i, batch_size) + loss += inter_weight*run_batch(train_queries[query_type], model, i, batch_size, hard_negatives=True) + else: + loss += path_weight*run_batch(train_queries[query_type], model, i, batch_size) + if check_conv(vals): + logger.info("Fully converged at iteration {:d}".format(i)) + break + + losses, ema_loss = update_loss(loss.data[0], losses, ema_loss) + loss.backward() + optimizer.step() + + if i % log_every == 0: + logger.info("Iter: {:d}; ema_loss: {:f}".format(i, ema_loss)) + + if i >= val_every and i % val_every == 0: + v = run_eval(model, val_queries, i, logger) + if edge_conv: + vals.append(np.mean(list(v.values()))) + else: + vals.append(v["1-chain"]) + + v = run_eval(model, test_queries, i, logger) + logger.info("Test macro-averaged val: {:f}".format(np.mean(list(v.values())))) + logger.info("Improvement from edge conv: {:f}".format((np.mean(list(v.values()))-conv_test)/conv_test)) + +def run_batch(train_queries, enc_dec, iter_count, batch_size, hard_negatives=False): + num_queries = [float(len(queries)) for queries in train_queries.values()] + denom = float(sum(num_queries)) + formula_index = np.argmax(np.random.multinomial(1, + np.array(num_queries)/denom)) + formula = list(train_queries.keys())[formula_index] + n = len(train_queries[formula]) + start = (iter_count * batch_size) % n + end = min(((iter_count+1) *
batch_size) % n, n) + end = n if end <= start else end + queries = train_queries[formula][start:end] + loss = enc_dec.margin_loss(formula, queries, hard_negatives=hard_negatives) + return loss diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..e652efd --- /dev/null +++ b/utils.py @@ -0,0 +1,167 @@ +import numpy as np +import scipy +import scipy.stats as stats +import torch +from sklearn.metrics import roc_auc_score +from netquery.decoders import BilinearMetapathDecoder, TransEMetapathDecoder, BilinearDiagMetapathDecoder, SetIntersection, SimpleSetIntersection +from netquery.encoders import DirectEncoder, Encoder +from netquery.aggregators import MeanAggregator +#import cPickle as pickle +import pickle +import logging +import random + +""" +Misc utility functions.. +""" + +def cudify(feature_modules, node_maps=None): + if node_maps is None: + features = lambda nodes, mode : feature_modules[mode]( + torch.autograd.Variable(torch.LongTensor(nodes)+1).cuda()) + else: + features = lambda nodes, mode : feature_modules[mode]( + torch.autograd.Variable(torch.LongTensor([node_maps[mode][n] for n in nodes])+1).cuda()) + return features + +def _get_perc_scores(scores, lengths): + perc_scores = [] + cum_sum = 0 + neg_scores = scores[len(lengths):] + for i, length in enumerate(lengths): + perc_scores.append(stats.percentileofscore(neg_scores[cum_sum:cum_sum+length], scores[i])) + cum_sum += length + return perc_scores + +def eval_auc_queries(test_queries, enc_dec, batch_size=1000, hard_negatives=False, seed=0): + predictions = [] + labels = [] + formula_aucs = {} + random.seed(seed) + for formula in test_queries: + formula_labels = [] + formula_predictions = [] + formula_queries = test_queries[formula] + offset = 0 + while offset < len(formula_queries): + max_index = min(offset+batch_size, len(formula_queries)) + batch_queries = formula_queries[offset:max_index] + if hard_negatives: + lengths = [1 for j in range(offset, max_index)] + negatives = [random.choice(formula_queries[j].hard_neg_samples) for j in range(offset, max_index)] + else: + lengths = [1 for j in range(offset, max_index)] + negatives = [random.choice(formula_queries[j].neg_samples) for j in range(offset, max_index)] + offset += batch_size + + formula_labels.extend([1 for _ in range(len(lengths))]) + batch_scores = enc_dec.forward(formula, + batch_queries+[b for i, b in enumerate(batch_queries) for _ in range(lengths[i])], + [q.target_node for q in batch_queries] + negatives) + batch_scores = batch_scores.data.tolist() + formula_predictions.extend(batch_scores) + formula_aucs[formula] = roc_auc_score(formula_labels, np.nan_to_num(formula_predictions)) + labels.extend(formula_labels) + predictions.extend(formula_predictions) + overall_auc = roc_auc_score(labels, np.nan_to_num(predictions)) + return overall_auc, formula_aucs + + +def eval_perc_queries(test_queries, enc_dec, batch_size=1000, hard_negatives=False): + perc_scores = [] + for formula in test_queries: + formula_queries = test_queries[formula] + offset = 0 + while offset < len(formula_queries): + max_index = min(offset+batch_size, len(formula_queries)) + batch_queries = formula_queries[offset:max_index] + if hard_negatives: + lengths = [len(formula_queries[j].hard_neg_samples) for j in range(offset, max_index)] + negatives = [n for j in range(offset, max_index) for n in formula_queries[j].hard_neg_samples] + else: + lengths = [len(formula_queries[j].neg_samples) for j in range(offset, max_index)] + negatives = [n for j in range(offset, max_index) for n in 
formula_queries[j].neg_samples] + offset += batch_size + + batch_scores = enc_dec.forward(formula, + batch_queries+[b for i, b in enumerate(batch_queries) for _ in range(lengths[i])], + [q.target_node for q in batch_queries] + negatives) + batch_scores = batch_scores.data.tolist() + perc_scores.extend(_get_perc_scores(batch_scores, lengths)) + return np.mean(perc_scores) + +def get_encoder(depth, graph, out_dims, feature_modules, cuda): + if depth < 0 or depth > 3: + raise Exception("Depth must be between 0 and 3 (inclusive)") + + if depth == 0: + enc = DirectEncoder(graph.features, feature_modules) + else: + aggregator1 = MeanAggregator(graph.features) + enc1 = Encoder(graph.features, + graph.feature_dims, + out_dims, + graph.relations, + graph.adj_lists, feature_modules=feature_modules, + cuda=cuda, aggregator=aggregator1) + enc = enc1 + if depth >= 2: + aggregator2 = MeanAggregator(lambda nodes, mode : enc1(nodes, mode).t().squeeze()) + enc2 = Encoder(lambda nodes, mode : enc1(nodes, mode).t().squeeze(), + enc1.out_dims, + out_dims, + graph.relations, + graph.adj_lists, base_model=enc1, + cuda=cuda, aggregator=aggregator2) + enc = enc2 + if depth >= 3: + aggregator3 = MeanAggregator(lambda nodes, mode : enc2(nodes, mode).t().squeeze()) + enc3 = Encoder(lambda nodes, mode : enc1(nodes, mode).t().squeeze(), + enc2.out_dims, + out_dims, + graph.relations, + graph.adj_lists, base_model=enc2, + cuda=cuda, aggregator=aggregator3) + enc = enc3 + return enc + +def get_metapath_decoder(graph, out_dims, decoder): + if decoder == "bilinear": + dec = BilinearMetapathDecoder(graph.relations, out_dims) + elif decoder == "transe": + dec = TransEMetapathDecoder(graph.relations, out_dims) + elif decoder == "bilinear-diag": + dec = BilinearDiagMetapathDecoder(graph.relations, out_dims) + else: + raise Exception("Metapath decoder not recognized.") + return dec + +def get_intersection_decoder(graph, out_dims, decoder): + if decoder == "mean": + dec = SetIntersection(out_dims, out_dims, agg_func=torch.mean) + elif decoder == "mean-simple": + dec = SimpleSetIntersection(agg_func=torch.mean) + elif decoder == "min": + dec = SetIntersection(out_dims, out_dims, agg_func=torch.min) + elif decoder == "min-simple": + dec = SimpleSetIntersection(agg_func=torch.min) + else: + raise Exception("Intersection decoder not recognized.") + return dec + +def setup_logging(log_file, console=True): + logging.basicConfig(level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + filename=log_file, + filemode='w') + if console: + console = logging.StreamHandler() + # optional, set the logging level + console.setLevel(logging.INFO) + # set a format which is the same for console use + formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + # tell the handler to use this format + console.setFormatter(formatter) + # add the handler to the root logger + logging.getLogger('').addHandler(console) + return logging From 8e9c264737bc3cf0ce116f02d28552afd8032bcc Mon Sep 17 00:00:00 2001 From: Natalie Date: Tue, 20 Aug 2019 10:53:42 -0500 Subject: [PATCH 2/9] delete after upload in wrong folder --- aggregators.py | 200 ------------------------------------------------- 1 file changed, 200 deletions(-) delete mode 100644 aggregators.py diff --git a/aggregators.py b/aggregators.py deleted file mode 100644 index 0052886..0000000 --- a/aggregators.py +++ /dev/null @@ -1,200 +0,0 @@ -import torch -import torch.nn as nn -import itertools -from torch.nn import init -from torch.autograd import Variable -import 
torch.nn.functional as F - -import random -import math -import numpy as np - -""" -Set of modules for aggregating embeddings of neighbors. -These modules take as input embeddings of neighbors. -""" - -class MeanAggregator(nn.Module): - """ - Aggregates a node's embeddings using mean of neighbors' embeddings - """ - def __init__(self, features, cuda=False): - """ - Initializes the aggregator for a specific graph. - - features -- function mapping (node_list, features, offset) to feature values - see torch.nn.EmbeddingBag and forward function below docs for offset meaning. - """ - - super(MeanAggregator, self).__init__() - - self.features = features - self.cuda = cuda - - def forward(self, to_neighs, rel, keep_prob=0.5, max_keep=10): - """ - Aggregates embeddings for a batch of nodes. - keep_prob and max_keep are the parameters for edge/neighbour dropout. - - to_neighs -- list of neighbors of nodes - keep_prob -- probability of keeping a neighbor - max_keep -- maximum number of neighbors kept per node - """ - - # Local pointers to functions (speed hack) - _int = int - _set = set - _min = min - _len = len - _ceil = math.ceil - _sample = random.sample - samp_neighs = [_set(_sample(to_neigh, - _min(_int(_ceil(_len(to_neigh)*keep_prob)), max_keep) - )) for to_neigh in to_neighs] - unique_nodes_list = list(set.union(*samp_neighs)) - unique_nodes = {n:i for i,n in enumerate(unique_nodes_list)} - mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes))) - column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh] - row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))] - mask[row_indices, column_indices] = 1 - if self.cuda: - mask = mask.cuda() - num_neigh = mask.sum(1, keepdim=True) - mask = mask.div(num_neigh) - embed_matrix = self.features(unique_nodes_list, rel[-1]) - if len(embed_matrix.size()) == 1: - embed_matrix = embed_matrix.unsqueeze(dim=0) - to_feats = mask.mm(embed_matrix) - return to_feats - -class FastMeanAggregator(nn.Module): - """ - Aggregates a node's embeddings using mean of neighbors' embeddings - """ - def __init__(self, features, cuda=False): - """ - Initializes the aggregator for a specific graph. - - features -- function mapping (node_list, features, offset) to feature values - see torch.nn.EmbeddingBag and forward function below docs for offset meaning. - """ - - super(FastMeanAggregator, self).__init__() - - self.features = features - self.cuda = cuda - - def forward(self, to_neighs, rel, keep_prob=None, max_keep=25): - """ - Aggregates embeddings for a batch of nodes. - keep_prob and max_keep are the parameters for edge/neighbour dropout. - - to_neighs -- list of neighbors of nodes - keep_prob -- probability of keeping a neighbor - max_keep -- maximum number of neighbors kept per node - """ - _random = random.random - _int = int - _len = len - samp_neighs = [to_neigh[_int(_random()*_len(to_neigh))] for i in itertools.repeat(None, max_keep) - for to_neigh in to_neighs] - embed_matrix = self.features(samp_neighs, rel[-1]) - to_feats = embed_matrix.view(max_keep, len(to_neighs), embed_matrix.size()[1]) - return to_feats.mean(dim=0) - -class PoolAggregator(nn.Module): - """ - Aggregates a node's embeddings using mean pooling of neighbors' embeddings - """ - def __init__(self, features, feature_dims, cuda=False): - """ - Initializes the aggregator for a specific graph. 
- - features -- function mapping (node_list, features, offset) to feature values - see torch.nn.EmbeddingBag and forward function below docs for offset meaning. - """ - - super(PoolAggregator, self).__init__() - - self.features = features - self.feat_dims = feature_dims - self.pool_matrix = {} - for mode, feat_dim in self.feat_dims.items(): - self.pool_matrix[mode] = nn.Parameter(torch.FloatTensor(feat_dim, feat_dim)) - init.xavier_uniform(self.pool_matrix[mode]) - self.register_parameter(mode+"_pool", self.pool_matrix[mode]) - self.cuda = cuda - - def forward(self, to_neighs, rel, keep_prob=0.5, max_keep=10): - """ - Aggregates embeddings for a batch of nodes. - keep_prob and max_keep are the parameters for edge/neighbour dropout. - - to_neighs -- list of neighbors of nodes - keep_prob -- probability of keeping a neighbor - max_keep -- maximum number of neighbors kept per node - """ - _int = int - _set = set - _min = min - _len = len - _ceil = math.ceil - _sample = random.sample - samp_neighs = [_set(_sample(to_neigh, - _min(_int(_ceil(_len(to_neigh)*keep_prob)), max_keep) - )) for to_neigh in to_neighs] - unique_nodes_list = list(set.union(*samp_neighs)) - unique_nodes = {n:i for i,n in enumerate(unique_nodes_list)} - mask = Variable(torch.zeros(len(samp_neighs), len(unique_nodes))) - column_indices = [unique_nodes[n] for samp_neigh in samp_neighs for n in samp_neigh] - row_indices = [i for i in range(len(samp_neighs)) for j in range(len(samp_neighs[i]))] - mask[row_indices, column_indices] = 1 - mode = rel[0] - if self.cuda: - mask = mask.cuda() - embed_matrix = self.features(unique_nodes, rel[-1]).mm(self.pool_matrix[mode]) - to_feats = F.relu(mask.mm(embed_matrix)) - return to_feats - -class FastPoolAggregator(nn.Module): - """ - Aggregates a node's embeddings using mean pooling of neighbors' embeddings - """ - def __init__(self, features, feature_dims, - cuda=False): - """ - Initializes the aggregator for a specific graph. - - features -- function mapping (node_list, features, offset) to feature values - see torch.nn.EmbeddingBag and forward function below docs for offset meaning. - """ - - super(FastPoolAggregator, self).__init__() - - self.features = features - self.feat_dims = feature_dims - self.pool_matrix = {} - for mode, feat_dim in self.feat_dims.items(): - self.pool_matrix[mode] = nn.Parameter(torch.FloatTensor(feat_dim, feat_dim)) - init.xavier_uniform(self.pool_matrix[mode]) - self.register_parameter(mode+"_pool", self.pool_matrix[mode]) - self.cuda = cuda - - def forward(self, to_neighs, rel, keep_prob=0.5, max_keep=10): - """ - Aggregates embeddings for a batch of nodes. - keep_prob and max_keep are the parameters for edge/neighbour dropout. 
- - to_neighs -- list of neighbors of nodes - keep_prob -- probability of keeping a neighbor - max_keep -- maximum number of neighbors kept per node - """ - _random = random.random - _int = int - _len = len - samp_neighs = [to_neigh[_int(_random()*_len(to_neigh))] for i in itertools.repeat(None, max_keep) - for to_neigh in to_neighs] - mode = rel[0] - embed_matrix = self.features(samp_neighs, rel[-1]).mm(self.pool_matrix[mode]) - to_feats = embed_matrix.view(max_keep, len(to_neighs), embed_matrix.size()[1]) - return to_feats.mean(dim=0) From 595a04da2e15e252977dd40ec7cdb63589a362fd Mon Sep 17 00:00:00 2001 From: Natalie Date: Tue, 20 Aug 2019 10:54:05 -0500 Subject: [PATCH 3/9] delete after upload in wrong folder --- data_utils.py | 110 -------------------------------------------------- 1 file changed, 110 deletions(-) delete mode 100644 data_utils.py diff --git a/data_utils.py b/data_utils.py deleted file mode 100644 index 8ef54ad..0000000 --- a/data_utils.py +++ /dev/null @@ -1,110 +0,0 @@ -#import cPickle as pickle -import pickle -import torch -from collections import OrderedDict, defaultdict -from multiprocessing import Process -import random -import json -from netquery.data_utils import parallel_sample, load_queries_by_type, sample_clean_test - -from netquery.graph import Graph, Query, _reverse_edge - -def load_graph(data_dir, embed_dim): - rels, adj_lists, node_maps = pickle.load(open(data_dir+"/graph_data.pkl", "rb")) - node_maps = {m : {n : i for i, n in enumerate(id_list)} for m, id_list in node_maps.items()} - for m in node_maps: - node_maps[m][-1] = -1 - feature_dims = {m : embed_dim for m in rels} - feature_modules = {m : torch.nn.Embedding(len(node_maps[m])+1, embed_dim) for m in rels} - for mode in rels: - feature_modules[mode].weight.data.normal_(0, 1./embed_dim) - features = lambda nodes, mode : feature_modules[mode]( - torch.autograd.Variable(torch.LongTensor([node_maps[mode][n] for n in nodes])+1)) - graph = Graph(features, feature_dims, rels, adj_lists) - return graph, feature_modules, node_maps - -def sample_new_clean(data_dir): - graph_loader = lambda : load_graph(data_dir, 10)[0] - sample_clean_test(graph_loader, data_dir) - -def clean_test(): - test_edges = pickle.load(open("/dfs/scratch0/nqe-bio/test_edges.pkl", "rb")) - val_edges = pickle.load(open("/dfs/scratch0/nqe-bio/val_edges.pkl", "rb")) - deleted_edges = set([q[0][1] for q in test_edges] + [_reverse_edge(q[0][1]) for q in test_edges] + - [q[0][1] for q in val_edges] + [_reverse_edge(q[0][1]) for q in val_edges]) - - for i in range(2,4): - for kind in ["val", "test"]: - if kind == "val": - to_keep = 1000 - else: - to_keep = 10000 - test_queries = load_queries_by_type("/dfs/scratch0/nqe-bio/{:s}_queries_{:d}-split.pkl".format(kind, i), keep_graph=True) - print("Loaded", i, kind) - for query_type in test_queries: - test_queries[query_type] = [q for q in test_queries[query_type] if len(q.get_edges().intersection(deleted_edges)) > 0] - test_queries[query_type] = test_queries[query_type][:to_keep] - test_queries = [q.serialize() for queries in test_queries.values() for q in queries] - pickle.dump(test_queries, open("/dfs/scratch0/nqe-bio/{:s}_queries_{:d}-clean.pkl".format(kind, i), "wb"), protocol=pickle.HIGHEST_PROTOCOL) - print("Finished", i, kind) - - - -def make_train_test_edge_data(data_dir): - print("Loading graph...") - graph, _, _ = load_graph(data_dir, 10) - print("Getting all edges...") - edges = graph.get_all_edges() - split_point = int(0.1*len(edges)) - val_test_edges = edges[:split_point] - 
print("Getting negative samples...") - val_test_edge_negsamples = [graph.get_negative_edge_samples(e, 100) for e in val_test_edges] - print("Making and storing test queries.") - val_test_edge_queries = [Query(("1-chain", val_test_edges[i]), val_test_edge_negsamples[i], None, 100) for i in range(split_point)] - val_split_point = int(0.1*len(val_test_edge_queries)) - val_queries = val_test_edge_queries[:val_split_point] - test_queries = val_test_edge_queries[val_split_point:] - pickle.dump([q.serialize() for q in val_queries], open(data_dir+"/val_edges.pkl", "w"), protocol=pickle.HIGHEST_PROTOCOL) - pickle.dump([q.serialize() for q in test_queries], open(data_dir+"/test_edges.pkl", "w"), protocol=pickle.HIGHEST_PROTOCOL) - - print("Removing test edges...") - graph.remove_edges(val_test_edges) - print("Making and storing train queries.") - train_edges = graph.get_all_edges() - train_queries = [Query(("1-chain", e), None, None) for e in train_edges] - pickle.dump([q.serialize() for q in train_queries], open(data_dir+"/train_edges.pkl", "w"), protocol=pickle.HIGHEST_PROTOCOL) - -def _discard_negatives(file_name, small_prop=0.9): - queries = pickle.load(open(file_name, "rb")) -# queries = [q if random.random() > small_prop else (q[0],[random.choice(tuple(q[1]))], None if q[2] is None else [random.choice(tuple(q[2]))]) for q in queries] - queries = [q if random.random() > small_prop else (q[0],[random.choice(list(q[1]))], None if q[2] is None else [random.choice(list(q[2]))]) for q in queries] - pickle.dump(queries, open(file_name.split(".")[0] + "-split.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) - print("Finished", file_name) - -def discard_negatives(data_dir): - _discard_negatives(data_dir + "/val_edges.pkl") - _discard_negatives(data_dir + "/test_edges.pkl") - for i in range(2,4): - _discard_negatives(data_dir + "/val_queries_{:d}.pkl".format(i)) - _discard_negatives(data_dir + "/test_queries_{:d}.pkl".format(i)) - - -def make_train_test_query_data(data_dir): - graph, _, _ = load_graph(data_dir, 10) - queries_2, queries_3 = parallel_sample(graph, 20, 50000, data_dir, test=False) - t_queries_2, t_queries_3 = parallel_sample(graph, 20, 5000, data_dir, test=True) - t_queries_2 = list(set(t_queries_2) - set(queries_2)) - t_queries_3 = list(set(t_queries_3) - set(queries_3)) - pickle.dump([q.serialize() for q in queries_2], open(data_dir + "/train_queries_2.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) - pickle.dump([q.serialize() for q in queries_3], open(data_dir + "/train_queries_3.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) - pickle.dump([q.serialize() for q in t_queries_2[10000:]], open(data_dir + "/test_queries_2.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) - pickle.dump([q.serialize() for q in t_queries_3[10000:]], open(data_dir + "/test_queries_3.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) - pickle.dump([q.serialize() for q in t_queries_2[:10000]], open(data_dir + "/val_queries_2.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) - pickle.dump([q.serialize() for q in t_queries_3[:10000]], open(data_dir + "/val_queries_3.pkl", "wb"), protocol=pickle.HIGHEST_PROTOCOL) - - -if __name__ == "__main__": - #make_train_test_query_data("/dfs/scratch0/nqe-bio/") - #make_train_test_edge_data("/dfs/scratch0/nqe-bio/") - sample_new_clean("/dfs/scratch0/nqe-bio/") - #clean_test() From fda29059c7d8d7d418e406bf89bab8bdd228201e Mon Sep 17 00:00:00 2001 From: Natalie Date: Tue, 20 Aug 2019 10:54:16 -0500 Subject: [PATCH 4/9] delete after upload in wrong folder --- encoders.py | 146 
---------------------------------------------------- 1 file changed, 146 deletions(-) delete mode 100644 encoders.py diff --git a/encoders.py b/encoders.py deleted file mode 100644 index 772a174..0000000 --- a/encoders.py +++ /dev/null @@ -1,146 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import init -import torch.nn.functional as F - -""" -Set of modules for encoding nodes. -These modules take as input node ids and output embeddings. -""" - -class DirectEncoder(nn.Module): - """ - Encodes a node as a embedding via direct lookup. - (i.e., this is just like basic node2vec or matrix factorization) - """ - def __init__(self, features, feature_modules): - """ - Initializes the model for a specific graph. - - features -- function mapping (node_list, features, offset) to feature values - see torch.nn.EmbeddingBag and forward function below docs for offset meaning. - feature_modules -- This should be a map from mode -> torch.nn.EmbeddingBag - """ - super(DirectEncoder, self).__init__() - for name, module in feature_modules.items(): - self.add_module("feat-"+name, module) - self.features = features - - def forward(self, nodes, mode, offset=None, **kwargs): - """ - Generates embeddings for a batch of nodes. - - nodes -- list of nodes - mode -- string desiginating the mode of the nodes - offsets -- specifies how the embeddings are aggregated. - see torch.nn.EmbeddingBag for format. - No aggregation if offsets is None - """ - - if offset is None: - embeds = self.features(nodes, mode).t() - norm = embeds.norm(p=2, dim=0, keepdim=True) - return embeds.div(norm.expand_as(embeds)) - else: - return self.features(nodes, mode, offset).t() - -class Encoder(nn.Module): - """ - Encodes a node's using a GCN/GraphSage approach - """ - def __init__(self, features, feature_dims, - out_dims, relations, adj_lists, aggregator, - base_model=None, cuda=False, - layer_norm=False, - feature_modules={}): - """ - Initializes the model for a specific graph. - - features -- function mapping (node_list, features, offset) to feature values - see torch.nn.EmbeddingBag and forward function below docs for offset meaning. - feature_dims -- output dimension of each of the feature functions. 
- out_dims -- embedding dimensions for each mode (i.e., output dimensions) - relations -- map from mode -> out_going_relations - adj_lists -- map from relation_tuple -> node -> list of node's neighbors - base_model -- if features are from another encoder, pass it here for training - cuda -- whether or not to move params to the GPU - feature_modules -- if features come from torch.nn module, pass the modules here for training - """ - - super(Encoder, self).__init__() - - self.features = features - self.feat_dims = feature_dims - self.adj_lists = adj_lists - self.relations = relations - self.aggregator = aggregator - for name, module in feature_modules.items(): - self.add_module("feat-"+name, module) - if base_model != None: - self.base_model = base_model - - self.out_dims = out_dims - self.cuda = cuda - self.aggregator.cuda = cuda - self.layer_norm = layer_norm - self.compress_dims = {} - for source_mode in relations: - self.compress_dims[source_mode] = self.feat_dims[source_mode] - for (to_mode, _) in relations[source_mode]: - self.compress_dims[source_mode] += self.feat_dims[to_mode] - - self.self_params = {} - self.compress_params = {} - self.lns = {} - for mode, feat_dim in self.feat_dims.items(): - if self.layer_norm: - self.lns[mode] = LayerNorm(out_dims[mode]) - self.add_module(mode+"_ln", self.lns[mode]) - self.compress_params[mode] = nn.Parameter( - torch.FloatTensor(out_dims[mode], self.compress_dims[mode])) - init.xavier_uniform(self.compress_params[mode]) - self.register_parameter(mode+"_compress", self.compress_params[mode]) - - def forward(self, nodes, mode, keep_prob=0.5, max_keep=10): - """ - Generates embeddings for a batch of nodes. - - nodes -- list of nodes - mode -- string desiginating the mode of the nodes - """ - self_feat = self.features(nodes, mode).t() - neigh_feats = [] - for to_r in self.relations[mode]: - rel = (mode, to_r[1], to_r[0]) - to_neighs = [[-1] if node == -1 else self.adj_lists[rel][node] for node in nodes] - - # Special null neighbor for nodes with no edges of this type - to_neighs = [[-1] if len(l) == 0 else l for l in to_neighs] - to_feats = self.aggregator.forward(to_neighs, rel, keep_prob, max_keep) - to_feats = to_feats.t() - neigh_feats.append(to_feats) - - neigh_feats.append(self_feat) - combined = torch.cat(neigh_feats, dim=0) - combined = self.compress_params[mode].mm(combined) - if self.layer_norm: - combined = self.lns[mode](combined.t()).t() - combined = F.relu(combined) - return combined - - -class LayerNorm(nn.Module): - """ - Simple layer norm object optionally used with the convolutional encoder. 
- """ - - def __init__(self, feature_dim, eps=1e-6): - super(LayerNorm, self).__init__() - self.gamma = nn.Parameter(torch.ones((feature_dim,))) - self.beta = nn.Parameter(torch.zeros((feature_dim,))) - self.eps = eps - - def forward(self, x): - mean = x.mean(-1, keepdim=True) - std = x.std(-1, keepdim=True) - return self.gamma * (x - mean) / (std + self.eps) + self.beta From a7a384083cd4fb7656f2cf1a9b2560f86ee08b17 Mon Sep 17 00:00:00 2001 From: Natalie Date: Tue, 20 Aug 2019 10:54:29 -0500 Subject: [PATCH 5/9] delete after upload in wrong folder --- graph.py | 695 ------------------------------------------------------- 1 file changed, 695 deletions(-) delete mode 100644 graph.py diff --git a/graph.py b/graph.py deleted file mode 100644 index 8e3e4cc..0000000 --- a/graph.py +++ /dev/null @@ -1,695 +0,0 @@ -from collections import OrderedDict, defaultdict -import random - -def _reverse_relation(relation): - return (relation[-1], relation[1], relation[0]) - -def _reverse_edge(edge): - return (edge[-1], _reverse_relation(edge[1]), edge[0]) - - -class Formula(): - - def __init__(self, query_type, rels): - self.query_type = query_type - self.target_mode = rels[0][0] - self.rels = rels - if query_type == "1-chain" or query_type == "2-chain" or query_type == "3-chain": - self.anchor_modes = (rels[-1][-1],) - elif query_type == "2-inter" or query_type == "3-inter": - self.anchor_modes = tuple([rel[-1] for rel in rels]) - elif query_type == "3-inter_chain": - self.anchor_modes = (rels[0][-1], rels[1][-1][-1]) - elif query_type == "3-chain_inter": - self.anchor_modes = (rels[1][0][-1], rels[1][1][-1]) - - def __hash__(self): - return hash((self.query_type, self.rels)) - - def __eq__(self, other): - return ((self.query_type, self.rels)) == ((other.query_type, other.rels)) - - def __neq__(self, other): - return ((self.query_type, self.rels)) != ((other.query_type, other.rels)) - - def __str__(self): - return self.query_type + ": " + str(self.rels) - -class Query(): - - def __init__(self, query_graph, neg_samples, hard_neg_samples, neg_sample_max=100, keep_graph=False): - query_type = query_graph[0] - if query_type == "1-chain" or query_type == "2-chain" or query_type == "3-chain": - self.formula = Formula(query_type, tuple([query_graph[i][1] for i in range(1, len(query_graph))])) - self.anchor_nodes = (query_graph[-1][-1],) - elif query_type == "2-inter" or query_type == "3-inter": - self.formula = Formula(query_type, tuple([query_graph[i][1] for i in range(1, len(query_graph))])) - self.anchor_nodes = tuple([query_graph[i][-1] for i in range(1, len(query_graph))]) - elif query_type == "3-inter_chain": - self.formula = Formula(query_type, (query_graph[1][1], (query_graph[2][0][1], query_graph[2][1][1]))) - self.anchor_nodes = (query_graph[1][-1], query_graph[2][-1][-1]) - elif query_type == "3-chain_inter": - self.formula = Formula(query_type, (query_graph[1][1], (query_graph[2][0][1], query_graph[2][1][1]))) - self.anchor_nodes = (query_graph[2][0][-1], query_graph[2][1][-1]) - self.target_node = query_graph[1][0] - if keep_graph: - self.query_graph = query_graph - else: - self.query_graph = None - if not neg_samples is None: - self.neg_samples = list(neg_samples) if len(neg_samples) < neg_sample_max else random.sample(neg_samples, neg_sample_max) - else: - self.neg_samples = None - if not hard_neg_samples is None: - self.hard_neg_samples = list(hard_neg_samples) if len(hard_neg_samples) <= neg_sample_max else random.sample(hard_neg_samples, neg_sample_max) - else: - self.hard_neg_samples = None - - 
def contains_edge(self, edge): - if self.query_graph is None: - raise Exception("Can only test edge contain if graph is kept. Reinit with keep_graph=True") - edges = self.query_graph[1:] - if "inter_chain" in self.query_graph[0] or "chain_inter" in self.query_graph[0]: - edges = (edges[0], edges[1][0], edges[1][1]) - return edge in edges or (edge[1], _reverse_relation(edge[1]), edge[0]) in edges - - def get_edges(self): - if self.query_graph is None: - raise Exception("Can only test edge contain if graph is kept. Reinit with keep_graph=True") - edges = self.query_graph[1:] - if "inter_chain" in self.query_graph[0] or "chain_inter" in self.query_graph[0]: - edges = (edges[0], edges[1][0], edges[1][1]) - return set(edges).union(set([(e[-1], _reverse_relation(e[1]), e[0]) for e in edges])) - - def __hash__(self): - return hash((self.formula, self.target_node, self.anchor_nodes)) - - def __eq__(self, other): - return (self.formula, self.target_node, self.anchor_nodes) == (other.formula, other.target_node, other.anchor_nodes) - - def __neq__(self, other): - return self.__hash__() != other.__hash__() - - def serialize(self): - if self.query_graph is None: - raise Exception("Cannot serialize query loaded with query graph!") - return (self.query_graph, self.neg_samples, self.hard_neg_samples) - - @staticmethod - def deserialize(serial_info, keep_graph=False): - return Query(serial_info[0], serial_info[1], serial_info[2], None if serial_info[1] is None else len(serial_info[1]), keep_graph=keep_graph) - - - -class Graph(): - """ - Simple container for heteregeneous graph data. - """ - def __init__(self, features, feature_dims, relations, adj_lists): - self.features = features - self.feature_dims = feature_dims - self.relations = relations - self.adj_lists = adj_lists - self.full_sets = defaultdict(set) - self.full_lists = {} - self.meta_neighs = defaultdict(dict) - for rel, adjs in self.adj_lists.items(): - full_set = set(self.adj_lists[rel].keys()) - self.full_sets[rel[0]] = self.full_sets[rel[0]].union(full_set) - for mode, full_set in self.full_sets.items(): - self.full_lists[mode] = list(full_set) - self._cache_edge_counts() - self._make_flat_adj_lists() - - def _make_flat_adj_lists(self): - self.flat_adj_lists = defaultdict(lambda : defaultdict(list)) - for rel, adjs in self.adj_lists.items(): - for node, neighs in adjs.items(): - self.flat_adj_lists[rel[0]][node].extend([(rel, neigh) for neigh in neighs]) - - def _cache_edge_counts(self): - self.edges = 0. - self.rel_edges = {} - for r1 in self.relations: - for r2 in self.relations[r1]: - rel = (r1,r2[1], r2[0]) - self.rel_edges[rel] = 0. - for adj_list in self.adj_lists[rel].values(): - self.rel_edges[rel] += len(adj_list) - self.edges += 1. 
- self.rel_weights = OrderedDict() - self.mode_edges = defaultdict(float) - self.mode_weights = OrderedDict() - for rel, edge_count in self.rel_edges.items(): - self.rel_weights[rel] = edge_count / self.edges - self.mode_edges[rel[0]] += edge_count - for mode, edge_count in self.mode_edges.items(): - self.mode_weights[mode] = edge_count / self.edges - - def remove_edges(self, edge_list): - for edge in edge_list: - try: - self.adj_lists[edge[1]][edge[0]].remove(edge[-1]) - except Exception: - continue - - try: - self.adj_lists[_reverse_relation(edge[1])][edge[-1]].remove(edge[0]) - except Exception: - continue - self.meta_neighs = defaultdict(dict) - self._cache_edge_counts() - self._make_flat_adj_lists() - - def get_all_edges(self, seed=0, exclude_rels=set([])): - """ - Returns all edges in the form (node1, relation, node2) - """ - edges = [] - random.seed(seed) - for rel, adjs in self.adj_lists.items(): - if rel in exclude_rels: - continue - for node, neighs in adjs.items(): - edges.extend([(node, rel, neigh) for neigh in neighs if neigh != -1]) - random.shuffle(edges) - return edges - - def get_all_edges_byrel(self, seed=0, - exclude_rels=set([])): - random.seed(seed) - edges = defaultdict(list) - for rel, adjs in self.adj_lists.items(): - if rel in exclude_rels: - continue - for node, neighs in adjs.items(): - edges[(rel,)].extend([(node, neigh) for neigh in neighs if neigh != -1]) - - def get_negative_edge_samples(self, edge, num, rejection_sample=True): - if rejection_sample: - neg_nodes = set([]) - counter = 0 - while len(neg_nodes) < num: - neg_node = random.choice(self.full_lists[edge[1][0]]) - if not neg_node in self.adj_lists[_reverse_relation(edge[1])][edge[2]]: - neg_nodes.add(neg_node) - counter += 1 - if counter > 100*num: - return self.get_negative_edge_samples(edge, num, rejection_sample=False) - else: - neg_nodes = self.full_sets[edge[1][0]] - self.adj_lists[_reverse_relation(edge[1])][edge[2]] - neg_nodes = list(neg_nodes) if len(neg_nodes) <= num else random.sample(list(neg_nodes), num) - return neg_nodes - - def sample_test_queries(self, train_graph, q_types, samples_per_type, neg_sample_max, verbose=True): - queries = [] - for q_type in q_types: - sampled = 0 - while sampled < samples_per_type: - q = self.sample_query_subgraph_bytype(q_type) - if q is None or not train_graph._is_negative(q, q[1][0], False): - continue - negs, hard_negs = self.get_negative_samples(q) - if negs is None or ("inter" in q[0] and hard_negs is None): - continue - query = Query(q, negs, hard_negs, neg_sample_max=neg_sample_max, keep_graph=True) - queries.append(query) - sampled += 1 - if sampled % 1000 == 0 and verbose: - print("Sampled", sampled) - return queries - - def sample_queries(self, arity, num_samples, neg_sample_max, verbose=True): - sampled = 0 - queries = [] - while sampled < num_samples: - q = self.sample_query_subgraph(arity) - if q is None: - continue - negs, hard_negs = self.get_negative_samples(q) - if negs is None or ("inter" in q[0] and hard_negs is None): - continue - query = Query(q, negs, hard_negs, neg_sample_max=neg_sample_max, keep_graph=True) - queries.append(query) - sampled += 1 - if sampled % 1000 == 0 and verbose: - print("Sampled", sampled) - return queries - - - def get_negative_samples(self, query): - if query[0] == "3-chain" or query[0] == "2-chain": - edges = query[1:] - rels = [_reverse_relation(edge[1]) for edge in edges[::-1]] - meta_neighs = self.get_metapath_neighs(query[-1][-1], tuple(rels)) - negative_samples = self.full_sets[query[1][1][0]] - 
meta_neighs - if len(negative_samples) == 0: - return None, None - else: - return negative_samples, None - elif query[0] == "2-inter" or query[0] == "3-inter": - rel_1 = _reverse_relation(query[1][1]) - union_neighs = self.adj_lists[rel_1][query[1][-1]] - inter_neighs = self.adj_lists[rel_1][query[1][-1]] - for i in range(2,len(query)): - rel = _reverse_relation(query[i][1]) - union_neighs = union_neighs.union(self.adj_lists[rel][query[i][-1]]) - inter_neighs = inter_neighs.intersection(self.adj_lists[rel][query[i][-1]]) - neg_samples = self.full_sets[query[1][1][0]] - inter_neighs - hard_neg_samples = union_neighs - inter_neighs - if len(neg_samples) == 0 or len(hard_neg_samples) == 0: - return None, None - return neg_samples, hard_neg_samples - elif query[0] == "3-inter_chain": - rel_1 = _reverse_relation(query[1][1]) - union_neighs = self.adj_lists[rel_1][query[1][-1]] - inter_neighs = self.adj_lists[rel_1][query[1][-1]] - chain_rels = [_reverse_relation(edge[1]) for edge in query[2][::-1]] - chain_neighs = self.get_metapath_neighs(query[2][-1][-1], tuple(chain_rels)) - union_neighs = union_neighs.union(chain_neighs) - inter_neighs = inter_neighs.intersection(chain_neighs) - neg_samples = self.full_sets[query[1][1][0]] - inter_neighs - hard_neg_samples = union_neighs - inter_neighs - if len(neg_samples) == 0 or len(hard_neg_samples) == 0: - return None, None - return neg_samples, hard_neg_samples - elif query[0] == "3-chain_inter": - inter_rel_1 = _reverse_relation(query[-1][0][1]) - inter_neighs_1 = self.adj_lists[inter_rel_1][query[-1][0][-1]] - inter_rel_2 = _reverse_relation(query[-1][1][1]) - inter_neighs_2 = self.adj_lists[inter_rel_2][query[-1][1][-1]] - - inter_neighs = inter_neighs_1.intersection(inter_neighs_2) - union_neighs = inter_neighs_1.union(inter_neighs_2) - rel = _reverse_relation(query[1][1]) - pos_nodes = set([n for neigh in inter_neighs for n in self.adj_lists[rel][neigh]]) - union_pos_nodes = set([n for neigh in union_neighs for n in self.adj_lists[rel][neigh]]) - neg_samples = self.full_sets[query[1][1][0]] - pos_nodes - hard_neg_samples = union_pos_nodes - pos_nodes - if len(neg_samples) == 0 or len(hard_neg_samples) == 0: - return None, None - return neg_samples, hard_neg_samples - - def sample_edge(self, node, mode): - rel, neigh = random.choice(self.flat_adj_lists[mode][node]) - edge = (node, rel, neigh) - return edge - - def sample_query_subgraph_bytype(self, q_type, start_node=None): - if start_node is None: - start_rel = random.choice(self.adj_lists.keys()) - node = random.choice(self.adj_lists[start_rel].keys()) - mode = start_rel[0] - else: - node, mode = start_node - - if q_type[0] == "3": - if q_type == "3-chain" or q_type == "3-chain_inter": - num_edges = 1 - elif q_type == "3-inter_chain": - num_edges = 2 - elif q_type == "3-inter": - num_edges = 3 - if num_edges > len(self.flat_adj_lists[mode][node]): - return None - if num_edges == 1: - rel, neigh = random.choice(self.flat_adj_lists[mode][node]) - edge = (node, rel, neigh) - next_query = self.sample_query_subgraph_bytype( - "2-chain" if q_type == "3-chain" else "2-inter", start_node=(neigh, rel[0])) - if next_query is None: - return None - if next_query[0] == "2-chain": - return ("3-chain", edge, next_query[1], next_query[2]) - else: - return ("3-chain_inter", edge, (next_query[1], next_query[2])) - elif num_edges == 2: - rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) - edge_1 = (node, rel_1, neigh_1) - neigh_2 = neigh_1 - rel_2 = rel_1 - while (neigh_1, rel_1) == (neigh_2, rel_2): 
- rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) - edge_2 = (node, rel_2, neigh_2) - return ("3-inter_chain", edge_1, (edge_2, self.sample_edge(neigh_2, rel_2[-1]))) - elif num_edges == 3: - rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) - edge_1 = (node, rel_1, neigh_1) - neigh_2 = neigh_1 - rel_2 = rel_1 - while (rel_1, neigh_1) == (rel_2, neigh_2): - rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) - edge_2 = (node, rel_2, neigh_2) - neigh_3 = neigh_1 - rel_3 = rel_1 - while ((rel_1, neigh_1) == (rel_3, neigh_3)) or ((rel_2, neigh_2) == (rel_3, neigh_3)): - rel_3, neigh_3 = random.choice(self.flat_adj_lists[mode][node]) - edge_3 = (node, rel_3, neigh_3) - return ("3-inter", edge_1, edge_2, edge_3) - - if q_type[0] == "2": - num_edges = 1 if q_type == "2-chain" else 2 - if num_edges > len(self.flat_adj_lists[mode][node]): - return None - if num_edges == 1: - rel, neigh = random.choice(self.flat_adj_lists[mode][node]) - edge = (node, rel, neigh) - return ("2-chain", edge, self.sample_edge(neigh, rel[-1])) - elif num_edges == 2: - rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) - edge_1 = (node, rel_1, neigh_1) - neigh_2 = neigh_1 - rel_2 = rel_1 - while (neigh_1, rel_1) == (neigh_2, rel_2): - rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) - edge_2 = (node, rel_2, neigh_2) - return ("2-inter", edge_1, edge_2) - - - def sample_query_subgraph(self, arity, start_node=None): - if start_node is None: - start_rel = random.choice(self.adj_lists.keys()) - node = random.choice(self.adj_lists[start_rel].keys()) - mode = start_rel[0] - else: - node, mode = start_node - if arity > 3 or arity < 2: - raise Exception("Only arity of at most 3 is supported for queries") - - if arity == 3: - # 1/2 prob of 1 edge, 1/4 prob of 2, 1/4 prob of 3 - num_edges = random.choice([1,1,2,3]) - if num_edges > len(self.flat_adj_lists[mode][node]): - return None - if num_edges == 1: - rel, neigh = random.choice(self.flat_adj_lists[mode][node]) - edge = (node, rel, neigh) - next_query = self.sample_query_subgraph(2, start_node=(neigh, rel[0])) - if next_query is None: - return None - if next_query[0] == "2-chain": - return ("3-chain", edge, next_query[1], next_query[2]) - else: - return ("3-chain_inter", edge, (next_query[1], next_query[2])) - elif num_edges == 2: - rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) - edge_1 = (node, rel_1, neigh_1) - neigh_2 = neigh_1 - rel_2 = rel_1 - while (neigh_1, rel_1) == (neigh_2, rel_2): - rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) - edge_2 = (node, rel_2, neigh_2) - return ("3-inter_chain", edge_1, (edge_2, self.sample_edge(neigh_2, rel_2[-1]))) - elif num_edges == 3: - rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) - edge_1 = (node, rel_1, neigh_1) - neigh_2 = neigh_1 - rel_2 = rel_1 - while (rel_1, neigh_1) == (rel_2, neigh_2): - rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) - edge_2 = (node, rel_2, neigh_2) - neigh_3 = neigh_1 - rel_3 = rel_1 - while ((rel_1, neigh_1) == (rel_3, neigh_3)) or ((rel_2, neigh_2) == (rel_3, neigh_3)): - rel_3, neigh_3 = random.choice(self.flat_adj_lists[mode][node]) - edge_3 = (node, rel_3, neigh_3) - return ("3-inter", edge_1, edge_2, edge_3) - - if arity == 2: - num_edges = random.choice([1,2]) - if num_edges > len(self.flat_adj_lists[mode][node]): - return None - if num_edges == 1: - rel, neigh = random.choice(self.flat_adj_lists[mode][node]) - edge = (node, rel, neigh) - return ("2-chain", edge, 
self.sample_edge(neigh, rel[-1])) - elif num_edges == 2: - rel_1, neigh_1 = random.choice(self.flat_adj_lists[mode][node]) - edge_1 = (node, rel_1, neigh_1) - neigh_2 = neigh_1 - rel_2 = rel_1 - while (neigh_1, rel_1) == (neigh_2, rel_2): - rel_2, neigh_2 = random.choice(self.flat_adj_lists[mode][node]) - edge_2 = (node, rel_2, neigh_2) - return ("2-inter", edge_1, edge_2) - - def get_metapath_neighs(self, node, rels): - if node in self.meta_neighs[rels]: - return self.meta_neighs[rels][node] - current_set = [node] - for rel in rels: - current_set = set([neigh for n in current_set for neigh in self.adj_lists[rel][n]]) - self.meta_neighs[rels][node] = current_set - return current_set - - ## TESTING CODE - - def _check_edge(self, query, i): - return query[i][-1] in self.adj_lists[query[i][1]][query[i][0]] - - def _is_subgraph(self, query, verbose): - if query[0] == "3-chain": - for i in range(3): - if not self._check_edge(query, i+1): - raise Exception(str(query)) - if not (query[1][-1] == query[2][0] and query[2][-1] == query[3][0]): - raise Exception(str(query)) - if query[0] == "2-chain": - for i in range(2): - if not self._check_edge(query, i+1): - raise Exception(str(query)) - if not query[1][-1] == query[2][0]: - raise Exception(str(query)) - if query[0] == "2-inter": - for i in range(2): - if not self._check_edge(query, i+1): - raise Exception(str(query)) - if not query[1][0] == query[2][0]: - raise Exception(str(query)) - if query[0] == "3-inter": - for i in range(3): - if not self._check_edge(query, i+1): - raise Exception(str(query)) - if not (query[1][0] == query[2][0] and query[2][0] == query[3][0]): - raise Exception(str(query)) - if query[0] == "3-inter_chain": - if not (self._check_edge(query, 1) and self._check_edge(query[2], 0) and self._check_edge(query[2], 1)): - raise Exception(str(query)) - if not (query[1][0] == query[2][0][0] and query[2][0][-1] == query[2][1][0]): - raise Exception(str(query)) - if query[0] == "3-chain_inter": - if not (self._check_edge(query, 1) and self._check_edge(query[2], 0) and self._check_edge(query[2], 1)): - raise Exception(str(query)) - if not (query[1][-1] == query[2][0][0] and query[2][0][0] == query[2][1][0]): - raise Exception(str(query)) - return True - - def _is_negative(self, query, neg_node, is_hard): - if query[0] == "2-chain": - query = (query[0], (neg_node, query[1][1], query[1][2]), query[2]) - if query[2][-1] in self.get_metapath_neighs(query[1][0], (query[1][1], query[2][1])): - return False - if query[0] == "3-chain": - query = (query[0], (neg_node, query[1][1], query[1][2]), query[2], query[3]) - if query[3][-1] in self.get_metapath_neighs(query[1][0], (query[1][1], query[2][1], query[3][1])): - return False - if query[0] == "2-inter": - query = (query[0], (neg_node, query[1][1], query[1][2]), (neg_node, query[2][1], query[2][2])) - if not is_hard: - if self._check_edge(query, 1) and self._check_edge(query, 2): - return False - else: - if (self._check_edge(query, 1) and self._check_edge(query, 2)) or not (self._check_edge(query, 1) or self._check_edge(query, 2)): - return False - if query[0] == "3-inter": - query = (query[0], (neg_node, query[1][1], query[1][2]), (neg_node, query[2][1], query[2][2]), (neg_node, query[3][1], query[3][2])) - if not is_hard: - if self._check_edge(query, 1) and self._check_edge(query, 2) and self._check_edge(query, 3): - return False - else: - if (self._check_edge(query, 1) and self._check_edge(query, 2) and self._check_edge(query, 3))\ - or not (self._check_edge(query, 1) or 
self._check_edge(query, 2) or self._check_edge(query, 3)): - return False - if query[0] == "3-inter_chain": - query = (query[0], (neg_node, query[1][1], query[1][2]), ((neg_node, query[2][0][1], query[2][0][2]), query[2][1])) - meta_check = lambda : query[2][-1][-1] in self.get_metapath_neighs(query[1][0], (query[2][0][1], query[2][1][1])) - neigh_check = lambda : self._check_edge(query, 1) - if not is_hard: - if meta_check() and neigh_check(): - return False - else: - if (meta_check() and neigh_check()) or not (meta_check() or neigh_check()): - return False - if query[0] == "3-chain_inter": - query = (query[0], (neg_node, query[1][1], query[1][2]), query[2]) - target_neigh = self.adj_lists[query[1][1]][neg_node] - neigh_1 = self.adj_lists[_reverse_relation(query[2][0][1])][query[2][0][-1]] - neigh_2 = self.adj_lists[_reverse_relation(query[2][1][1])][query[2][1][-1]] - if not is_hard: - if target_neigh in neigh_1.intersection(neigh_2): - return False - else: - if target_neigh in neigh_1.intersection(neigh_2) and not target_neigh in neigh_1.union(neigh_2): - return False - return True - - - - def _run_test(self, num_samples=1000): - for i in range(num_samples): - q = self.sample_query_subgraph(2) - if q is None: - continue - self._is_subgraph(q, True) - negs, hard_negs = self.get_negative_samples(q) - if not negs is None: - for n in negs: - self._is_negative(q, n, False) - if not hard_negs is None: - for n in hard_negs: - self._is_negative(q, n, True) - q = self.sample_query_subgraph(3) - if q is None: - continue - self._is_subgraph(q, True) - negs, hard_negs = self.get_negative_samples(q) - if not negs is None: - for n in negs: - self._is_negative(q, n, False) - if not hard_negs is None: - for n in hard_negs: - self._is_negative(q, n, True) - return True - - - """ - TO DELETE? 
- def sample_chain_from_node(self, length, node, rel): - rels = [rel] - for cur_len in range(length-1): - next_rel = random.choice(self.relations[rels[-1][-1]]) - rels.append((rels[-1][-1], next_rel[-1], next_rel[0])) - - rels = tuple(rels) - meta_neighs = self.get_metapath_neighs(node, rels) - rev_rel = _reverse_relation(rels[-1]) - full_set = self.full_sets[rev_rel] - diff_set = full_set - meta_neighs - if len(meta_neighs) == 0 or len(diff_set) == 0: - return None, None, None - chain = (node, random.choice(list(meta_neighs))) - neg_chain = (node, random.choice(list(diff_set))) - return chain, neg_chain, rels - - def sample_chain(self, length, start_mode): - rel = random.choice(self.relations[start_mode]) - rel = (start_mode, rel[-1], rel[0]) - if len(self.adj_lists[rel]) == 0: - return None, None, None - node = random.choice(self.adj_lists[rel].keys()) - return self.sample_chain_from_node(length, node, rel) - - def sample_chains(self, length, anchor_weights, num_samples): - sampled = 0 - graph_chains = defaultdict(list) - neg_chains = defaultdict(list) - while sampled < num_samples: - anchor_mode = anchor_weights.keys()[np.argmax(np.random.multinomial(1, anchor_weights.values()))] - chain, neg_chain, rels = self.sample_chain(length, anchor_mode) - if chain is None: - continue - graph_chains[rels].append(chain) - neg_chains[rels].append(neg_chain) - sampled += 1 - return graph_chains, neg_chains - - - def sample_polytree_rootinter(self, length, target_mode, try_out=100): - num_chains = random.randint(2,length) - added = 0 - nodes = [] - rels_list = [] - - for i in range(num_chains): - remaining = length-added-num_chains - if i != num_chains - 1: - remaining = remaining if remaining == 0 else random.randint(0, remaining) - added += remaining - chain_len = 1 + remaining - if i == 0: - chain, _, rels = self.sample_chain(chain_len, target_mode) - try_count = 0 - while chain is None and try_count <= try_out: - chain, _, rels = self.sample_chain(chain_len, target_mode) - try_count += 1 - - if chain is None: - return None, None, None, None, None - target_node = chain[0] - nodes.append(chain[-1]) - rels_list.append(tuple([_reverse_relation(rel) for rel in rels[::-1]])) - else: - rel = random.choice([r for r in self.relations[target_mode] - if len(self.adj_lists[(target_mode, r[-1], r[0])][target_node]) > 0]) - rel = (target_mode, rel[-1], rel[0]) - chain, _, rels = self.sample_chain_from_node(chain_len, target_node, rel) - try_count = 0 - while chain is None and try_count <= try_out: - chain, _, rels = self.sample_chain_from_node(chain_len, target_node, rel) - if chain is None: - try_count += 1 - elif chain[-1] in nodes: - chain = None - if chain is None: - return None, None, None, None, None - nodes.append(chain[-1]) - rels_list.append(tuple([_reverse_relation(rel) for rel in rels[::-1]])) - - for i in range(len(nodes)): - meta_neighs = self.get_metapath_neighs(nodes[i], rels_list[i]) - if i == 0: - meta_neighs_inter = meta_neighs - meta_neighs_union = meta_neighs - else: - meta_neighs_inter = meta_neighs_inter.intersection(meta_neighs) - meta_neighs_union = meta_neighs_union.union(meta_neighs) - hard_neg_nodes = list(meta_neighs_union-meta_neighs_inter) - neg_nodes = list(self.full_sets[rels[0]]-meta_neighs_inter) - if len(neg_nodes) == 0: - return None, None, None, None, None - if len(hard_neg_nodes) == 0: - return None, None, None, None, None - - return target_node, neg_nodes, hard_neg_nodes, tuple(nodes), tuple(rels_list) - - - def sample_polytrees_parallel(self, length, thread_samples, 
threads, try_out=100): - pool = Pool(threads) - sample_func = partial(self.sample_polytree, length) - sizes = [thread_samples for _ in range(threads)] - results = pool.map(sample_func, sizes) - polytrees = {} - neg_polytrees = {} - hard_neg_polytrees = {} - for p, n, h in results: - polytrees.update(p) - neg_polytrees.update(n) - hard_neg_polytrees.updarte(h) - return polytrees, neg_polytrees, hard_neg_polytrees - - def sample_polytrees(self, length, num_samples, try_out=1): - samples = 0 - polytrees = defaultdict(list) - neg_polytrees = defaultdict(list) - hard_neg_polytrees = defaultdict(list) - while samples < num_samples: - t, n, h_n, nodes, rels = self.sample_polytree(length, random.choice(self.relations.keys())) - if t is None: - continue - samples += 1 - polytrees[rels].append((t, nodes)) - neg_polytrees[rels].append((n, nodes)) - hard_neg_polytrees[rels].append((h_n, nodes)) - return polytrees, neg_polytrees, hard_neg_polytrees - - """ From c4649ba1a92f4ed8e098d318f14ecdbfb981cd08 Mon Sep 17 00:00:00 2001 From: Natalie Date: Tue, 20 Aug 2019 10:54:41 -0500 Subject: [PATCH 6/9] delete after upload in wrong folder --- model.py | 189 ------------------------------------------------------- 1 file changed, 189 deletions(-) delete mode 100644 model.py diff --git a/model.py b/model.py deleted file mode 100644 index 216a0cb..0000000 --- a/model.py +++ /dev/null @@ -1,189 +0,0 @@ -import torch -import torch.nn as nn -import numpy as np - -import random -from netquery.graph import _reverse_relation - -EPS = 10e-6 - -""" -End-to-end autoencoder models for representation learning on -heteregenous graphs/networks -""" - -class MetapathEncoderDecoder(nn.Module): - """ - Encoder decoder model that reasons over metapaths - """ - - def __init__(self, graph, enc, dec): - """ - graph -- simple graph object; see graph.py - enc --- an encoder module that generates embeddings (see encoders.py) - dec --- an decoder module that predicts compositional relationships, i.e. metapaths, between nodes given embeddings. (see decoders.py) - Note that the decoder must be an *compositional/metapath* decoder (i.e., with name Metapath*.py) - """ - super(MetapathEncoderDecoder, self).__init__() - self.enc = enc - self.dec = dec - self.graph = graph - - def forward(self, nodes1, nodes2, rels): - """ - Returns a vector of 'relationship scores' for pairs of nodes being connected by the given metapath (sequence of relations). - Essentially, the returned scores are the predicted likelihood of the node pairs being connected - by the given metapath, where the pairs are given by the ordering in nodes1 and nodes2, - i.e. the first node id in nodes1 is paired with the first node id in nodes2. - """ - return self.dec.forward(self.enc.forward(nodes1, rels[0][0]), - self.enc.forward(nodes2, rels[-1][-1]), - rels) - - def margin_loss(self, nodes1, nodes2, rels): - """ - Standard max-margin based loss function. - Maximizes relationaship scores for true pairs vs negative samples. 
-        """
-        affs = self.forward(nodes1, nodes2, rels)
-        # negative examples: random node ids drawn from the target mode of the final relation
-        neg_nodes = [random.randint(1,len(self.graph.adj_lists[_reverse_relation(rels[-1])])-1) for _ in range(len(nodes1))]
-        neg_affs = self.forward(nodes1, neg_nodes,
-            rels)
-        margin = 1 - (affs - neg_affs)
-        margin = torch.clamp(margin, min=0)
-        loss = margin.mean()
-        return loss
-
-class QueryEncoderDecoder(nn.Module):
-    """
-    Encoder decoder model that reasons about edges, metapaths and intersections
-    """
-
-    def __init__(self, graph, enc, path_dec, inter_dec):
-        super(QueryEncoderDecoder, self).__init__()
-        self.enc = enc
-        self.path_dec = path_dec
-        self.inter_dec = inter_dec
-        self.graph = graph
-        self.cos = nn.CosineSimilarity(dim=0)
-
-    def forward(self, formula, queries, source_nodes):
-        if formula.query_type == "1-chain" or formula.query_type == "2-chain" or formula.query_type == "3-chain":
-            # a chain is simply a call to the path decoder
-            return self.path_dec.forward(
-                    self.enc.forward(source_nodes, formula.target_mode),
-                    self.enc.forward([query.anchor_nodes[0] for query in queries], formula.anchor_modes[0]),
-                    formula.rels)
-        elif formula.query_type == "2-inter" or formula.query_type == "3-inter" or formula.query_type == "3-inter_chain":
-            target_embeds = self.enc(source_nodes, formula.target_mode)
-
-            embeds1 = self.enc([query.anchor_nodes[0] for query in queries], formula.anchor_modes[0])
-            embeds1 = self.path_dec.project(embeds1, _reverse_relation(formula.rels[0]))
-
-            embeds2 = self.enc([query.anchor_nodes[1] for query in queries], formula.anchor_modes[1])
-            if len(formula.rels[1]) == 2:
-                for i_rel in formula.rels[1][::-1]:
-                    embeds2 = self.path_dec.project(embeds2, _reverse_relation(i_rel))
-            else:
-                embeds2 = self.path_dec.project(embeds2, _reverse_relation(formula.rels[1]))
-
-            if formula.query_type == "3-inter":
-                embeds3 = self.enc([query.anchor_nodes[2] for query in queries], formula.anchor_modes[2])
-                embeds3 = self.path_dec.project(embeds3, _reverse_relation(formula.rels[2]))
-
-                query_intersection = self.inter_dec(embeds1, embeds2, formula.target_mode, embeds3)
-            else:
-                query_intersection = self.inter_dec(embeds1, embeds2, formula.target_mode)
-            scores = self.cos(target_embeds, query_intersection)
-            return scores
-        elif formula.query_type == "3-chain_inter":
-            target_embeds = self.enc(source_nodes, formula.target_mode)
-
-            embeds1 = self.enc([query.anchor_nodes[0] for query in queries], formula.anchor_modes[0])
-            embeds1 = self.path_dec.project(embeds1, _reverse_relation(formula.rels[1][0]))
-            embeds2 = self.enc([query.anchor_nodes[1] for query in queries], formula.anchor_modes[1])
-            embeds2 = self.path_dec.project(embeds2, _reverse_relation(formula.rels[1][1]))
-            query_intersection = self.inter_dec(embeds1, embeds2, formula.rels[0][-1])
-            query_intersection = self.path_dec.project(query_intersection, _reverse_relation(formula.rels[0]))
-            scores = self.cos(target_embeds, query_intersection)
-            return scores
-
-
-    def margin_loss(self, formula, queries, hard_negatives=False, margin=1):
-        if "inter" not in formula.query_type and hard_negatives:
-            raise Exception("Hard negative examples can only be used with intersection queries")
-        elif hard_negatives:
-            neg_nodes = [random.choice(query.hard_neg_samples) for query in queries]
-        elif formula.query_type == "1-chain":
-            neg_nodes = [random.choice(self.graph.full_lists[formula.target_mode]) for _ in queries]
-        else:
-            neg_nodes = [random.choice(query.neg_samples) for query in queries]
-
-        affs = self.forward(formula, queries, [query.target_node for query in queries])
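-        # Max-margin (hinge) loss: penalize whenever a sampled negative scores within `margin` of the true target.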
-        neg_affs = self.forward(formula, queries, neg_nodes)
-        loss = margin - (affs - neg_affs)
-        loss = torch.clamp(loss, min=0)
-        loss = loss.mean()
-        return loss
-
-class SoftAndEncoderDecoder(nn.Module):
-    """
-    Encoder decoder model that reasons about edges, metapaths and intersections
-    """
-
-    def __init__(self, graph, enc, path_dec):
-        super(SoftAndEncoderDecoder, self).__init__()
-        self.enc = enc
-        self.path_dec = path_dec
-        self.graph = graph
-        self.cos = nn.CosineSimilarity(dim=0)
-
-    def forward(self, formula, queries, source_nodes):
-        if formula.query_type == "1-chain":
-            # a chain is simply a call to the path decoder
-            return self.path_dec.forward(
-                    self.enc.forward(source_nodes, formula.target_mode),
-                    self.enc.forward([query.anchor_nodes[0] for query in queries], formula.anchor_modes[0]),
-                    formula.rels)
-        elif formula.query_type == "2-inter" or formula.query_type == "3-inter":
-            target_embeds = self.enc(source_nodes, formula.target_mode)
-
-            embeds1 = self.enc([query.anchor_nodes[0] for query in queries], formula.anchor_modes[0])
-            embeds1 = self.path_dec.project(embeds1, _reverse_relation(formula.rels[0]))
-
-            embeds2 = self.enc([query.anchor_nodes[1] for query in queries], formula.anchor_modes[1])
-            if len(formula.rels[1]) == 2:
-                for i_rel in formula.rels[1][::-1]:
-                    embeds2 = self.path_dec.project(embeds2, _reverse_relation(i_rel))
-            else:
-                embeds2 = self.path_dec.project(embeds2, _reverse_relation(formula.rels[1]))
-
-            scores1 = self.cos(target_embeds, embeds1)
-            scores2 = self.cos(target_embeds, embeds2)
-            if formula.query_type == "3-inter":
-                embeds3 = self.enc([query.anchor_nodes[2] for query in queries], formula.anchor_modes[2])
-                embeds3 = self.path_dec.project(embeds3, _reverse_relation(formula.rels[2]))
-                scores3 = self.cos(target_embeds, embeds3)
-                scores = scores1 * scores2 * scores3
-            else:
-                scores = scores1 * scores2
-            return scores
-        else:
-            raise Exception("Query type not supported for this model.")
-
-    def margin_loss(self, formula, queries, hard_negatives=False, margin=1):
-        if "inter" not in formula.query_type and hard_negatives:
-            raise Exception("Hard negative examples can only be used with intersection queries")
-        elif hard_negatives:
-            neg_nodes = [random.choice(query.hard_neg_samples) for query in queries]
-        elif formula.query_type == "1-chain":
-            neg_nodes = [random.choice(self.graph.full_lists[formula.target_mode]) for _ in queries]
-        else:
-            neg_nodes = [random.choice(query.neg_samples) for query in queries]
-
-        affs = self.forward(formula, queries, [query.target_node for query in queries])
-        neg_affs = self.forward(formula, queries, neg_nodes)
-        loss = margin - (affs - neg_affs)
-        loss = torch.clamp(loss, min=0)
-        loss = loss.mean()
-        return loss

From e9332d8ade53773ba2bc1e98cc885479884bd957 Mon Sep 17 00:00:00 2001
From: Natalie
Date: Tue, 20 Aug 2019 10:54:53 -0500
Subject: [PATCH 7/9] delete after upload in wrong folder

---
 train.py | 81 --------------------------------------------------------
 1 file changed, 81 deletions(-)
 delete mode 100644 train.py

diff --git a/train.py b/train.py
deleted file mode 100644
index 808f695..0000000
--- a/train.py
+++ /dev/null
@@ -1,81 +0,0 @@
-from argparse import ArgumentParser
-
-from netquery.utils import *
-from netquery.bio.data_utils import load_graph
-from netquery.data_utils import load_queries_by_formula, load_test_queries_by_formula
-from netquery.model import QueryEncoderDecoder
-from netquery.train_helpers import run_train
-
-from torch import optim
-
-parser = ArgumentParser()
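-# Command-line options: embedding size, optimization settings, data/log/model paths, and decoder choices.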
-parser.add_argument("--embed_dim", type=int, default=128)
-parser.add_argument("--data_dir", type=str, default="./bio_data/")
-parser.add_argument("--lr", type=float, default=0.01)
-parser.add_argument("--depth", type=int, default=0)
-parser.add_argument("--batch_size", type=int, default=512)
-parser.add_argument("--max_iter", type=int, default=100000000)
-parser.add_argument("--max_burn_in", type=int, default=1000000)
-parser.add_argument("--val_every", type=int, default=5000)
-parser.add_argument("--tol", type=float, default=0.0001)
-parser.add_argument("--cuda", action='store_true')
-parser.add_argument("--log_dir", type=str, default="./")
-parser.add_argument("--model_dir", type=str, default="./")
-parser.add_argument("--decoder", type=str, default="bilinear")
-parser.add_argument("--inter_decoder", type=str, default="mean")
-parser.add_argument("--opt", type=str, default="adam")
-args = parser.parse_args()
-
-print("Loading graph data..")
-graph, feature_modules, node_maps = load_graph(args.data_dir, args.embed_dim)
-if args.cuda:
-    graph.features = cudify(feature_modules, node_maps)
-out_dims = {mode:args.embed_dim for mode in graph.relations}
-
-print("Loading edge data..")
-train_queries = load_queries_by_formula(args.data_dir + "/train_edges.pkl")
-val_queries = load_test_queries_by_formula(args.data_dir + "/val_edges.pkl")
-test_queries = load_test_queries_by_formula(args.data_dir + "/test_edges.pkl")
-
-print("Loading query data..")
-for i in range(2,4):
-    train_queries.update(load_queries_by_formula(args.data_dir + "/train_queries_{:d}.pkl".format(i)))
-    i_val_queries = load_test_queries_by_formula(args.data_dir + "/val_queries_{:d}.pkl".format(i))
-    val_queries["one_neg"].update(i_val_queries["one_neg"])
-    val_queries["full_neg"].update(i_val_queries["full_neg"])
-    i_test_queries = load_test_queries_by_formula(args.data_dir + "/test_queries_{:d}.pkl".format(i))
-    test_queries["one_neg"].update(i_test_queries["one_neg"])
-    test_queries["full_neg"].update(i_test_queries["full_neg"])
-
-
-enc = get_encoder(args.depth, graph, out_dims, feature_modules, args.cuda)
-dec = get_metapath_decoder(graph, enc.out_dims if args.depth > 0 else out_dims, args.decoder)
-inter_dec = get_intersection_decoder(graph, out_dims, args.inter_decoder)
-
-enc_dec = QueryEncoderDecoder(graph, enc, dec, inter_dec)
-if args.cuda:
-    enc_dec.cuda()
-
-if args.opt == "sgd":
-    optimizer = optim.SGD(filter(lambda p : p.requires_grad, enc_dec.parameters()), lr=args.lr, momentum=0)
-elif args.opt == "adam":
-    optimizer = optim.Adam(filter(lambda p : p.requires_grad, enc_dec.parameters()), lr=args.lr)
-
-log_file = args.log_dir + "/{data:s}-{depth:d}-{embed_dim:d}-{lr:f}-{decoder:s}-{inter_decoder:s}.log".format(
-    data=args.data_dir.strip().split("/")[-1],
-    depth=args.depth,
-    embed_dim=args.embed_dim,
-    lr=args.lr,
-    decoder=args.decoder,
-    inter_decoder=args.inter_decoder)
-model_file = args.model_dir + "/{data:s}-{depth:d}-{embed_dim:d}-{lr:f}-{decoder:s}-{inter_decoder:s}.log".format(
-    data=args.data_dir.strip().split("/")[-1],
-    depth=args.depth,
-    embed_dim=args.embed_dim,
-    lr=args.lr,
-    decoder=args.decoder,
-    inter_decoder=args.inter_decoder)
-logger = setup_logging(log_file)
-
-run_train(enc_dec, optimizer, train_queries, val_queries, test_queries, logger, max_burn_in=args.max_burn_in, val_every=args.val_every, model_file=model_file)
-torch.save(enc_dec.state_dict(), model_file)

From cfe72d645f183f6697c20374968c76505587f093 Mon Sep 17 00:00:00 2001
From: Natalie
Date: Tue, 20 Aug 2019 10:55:09 -0500
Subject: [PATCH 8/9] delete after upload in wrong folder

---
 train_helpers.py | 107 -----------------------------------------------
 1 file changed, 107 deletions(-)
 delete mode 100644 train_helpers.py

diff --git a/train_helpers.py b/train_helpers.py
deleted file mode 100644
index 690a710..0000000
--- a/train_helpers.py
+++ /dev/null
@@ -1,107 +0,0 @@
-import numpy as np
-from netquery.utils import eval_auc_queries, eval_perc_queries
-import torch
-
-def check_conv(vals, window=2, tol=1e-6):
-    # Converged when the mean of the last `window` validation scores improves by
-    # less than `tol` over the mean of the preceding window.
-    if len(vals) < 2 * window:
-        return False
-    conv = np.mean(vals[-window:]) - np.mean(vals[-2*window:-window])
-    return conv < tol
-
-def update_loss(loss, losses, ema_loss, ema_alpha=0.01):
-    losses.append(loss)
-    if ema_loss is None:
-        ema_loss = loss
-    else:
-        ema_loss = (1-ema_alpha)*ema_loss + ema_alpha*loss
-    return losses, ema_loss
-
-def run_eval(model, queries, iteration, logger, by_type=False):
-    vals = {}
-    def _print_by_rel(rel_aucs, logger):
-        for rels, auc in rel_aucs.items():
-            logger.info(str(rels) + "\t" + str(auc))
-    for query_type in queries["one_neg"]:
-        auc, rel_aucs = eval_auc_queries(queries["one_neg"][query_type], model)
-        perc = eval_perc_queries(queries["full_neg"][query_type], model)
-        vals[query_type] = auc
-        logger.info("{:s} val AUC: {:f} val perc {:f}; iteration: {:d}".format(query_type, auc, perc, iteration))
-        if by_type:
-            _print_by_rel(rel_aucs, logger)
-        if "inter" in query_type:
-            auc, rel_aucs = eval_auc_queries(queries["one_neg"][query_type], model, hard_negatives=True)
-            perc = eval_perc_queries(queries["full_neg"][query_type], model, hard_negatives=True)
-            logger.info("Hard-{:s} val AUC: {:f} val perc {:f}; iteration: {:d}".format(query_type, auc, perc, iteration))
-            if by_type:
-                _print_by_rel(rel_aucs, logger)
-            vals[query_type + "hard"] = auc
-    return vals
-
-def run_train(model, optimizer, train_queries, val_queries, test_queries, logger,
-        max_burn_in=100000, batch_size=512, log_every=100, val_every=1000, tol=1e-6,
-        max_iter=int(10e7), inter_weight=0.005, path_weight=0.01, model_file=None):
-    edge_conv = False
-    ema_loss = None
-    vals = []
-    losses = []
-    conv_test = None
-    for i in range(max_iter):
-
-        optimizer.zero_grad()
-        loss = run_batch(train_queries["1-chain"], model, i, batch_size)
-        if not edge_conv and (check_conv(vals) or len(losses) >= max_burn_in):
-            logger.info("Edge converged at iteration {:d}".format(i-1))
-            logger.info("Testing at edge conv...")
-            conv_test = run_eval(model, test_queries, i, logger)
-            conv_test = np.mean(list(conv_test.values()))
-            edge_conv = True
-            losses = []
-            ema_loss = None
-            vals = []
-            if model_file is not None:
-                torch.save(model.state_dict(), model_file+"-edge_conv")
-
-        if edge_conv:
-            for query_type in train_queries:
-                if query_type == "1-chain":
-                    continue
-                if "inter" in query_type:
-                    loss += inter_weight*run_batch(train_queries[query_type], model, i, batch_size)
-                    loss += inter_weight*run_batch(train_queries[query_type], model, i, batch_size, hard_negatives=True)
-                else:
-                    loss += path_weight*run_batch(train_queries[query_type], model, i, batch_size)
-            if check_conv(vals):
-                logger.info("Fully converged at iteration {:d}".format(i))
-                break
-
-        losses, ema_loss = update_loss(loss.data[0], losses, ema_loss)
-        loss.backward()
-        optimizer.step()
-
-        if i % log_every == 0:
-            logger.info("Iter: {:d}; ema_loss: {:f}".format(i, ema_loss))
-
-        if i >= val_every and i % val_every == 0:
-            v = run_eval(model, val_queries, i, logger)
-            if edge_conv:
-                vals.append(np.mean(list(v.values())))
-            else:
-                vals.append(v["1-chain"])
-
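-    # Final evaluation on the held-out test queries once training stops.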
-    v = run_eval(model, test_queries, i, logger)
-    logger.info("Test macro-averaged val: {:f}".format(np.mean(list(v.values()))))
-    logger.info("Improvement from edge conv: {:f}".format((np.mean(list(v.values()))-conv_test)/conv_test))
-
-def run_batch(train_queries, enc_dec, iter_count, batch_size, hard_negatives=False):
-    # Pick a formula with probability proportional to its number of queries, then take the
-    # next batch_size-sized slice of that formula's queries (cycling with iter_count).
-    num_queries = [float(len(queries)) for queries in train_queries.values()]
-    denom = float(sum(num_queries))
-    formula_index = np.argmax(np.random.multinomial(1,
-        np.array(num_queries)/denom))
-    formula = list(train_queries.keys())[formula_index]
-    n = len(train_queries[formula])
-    start = (iter_count * batch_size) % n
-    end = min(((iter_count+1) * batch_size) % n, n)
-    end = n if end <= start else end
-    queries = train_queries[formula][start:end]
-    loss = enc_dec.margin_loss(formula, queries, hard_negatives=hard_negatives)
-    return loss

From 3169f759ad9ed9f137cbc2b79c3883a85d3dc20a Mon Sep 17 00:00:00 2001
From: Natalie
Date: Tue, 20 Aug 2019 10:55:19 -0500
Subject: [PATCH 9/9] delete after upload in wrong folder

---
 utils.py | 167 ------------------------------------------------------
 1 file changed, 167 deletions(-)
 delete mode 100644 utils.py

diff --git a/utils.py b/utils.py
deleted file mode 100644
index e652efd..0000000
--- a/utils.py
+++ /dev/null
@@ -1,167 +0,0 @@
-import numpy as np
-import scipy
-import scipy.stats as stats
-import torch
-from sklearn.metrics import roc_auc_score
-from netquery.decoders import BilinearMetapathDecoder, TransEMetapathDecoder, BilinearDiagMetapathDecoder, SetIntersection, SimpleSetIntersection
-from netquery.encoders import DirectEncoder, Encoder
-from netquery.aggregators import MeanAggregator
-#import cPickle as pickle
-import pickle
-import logging
-import random
-
-"""
-Misc utility functions..
-"""
-
-def cudify(feature_modules, node_maps=None):
-    if node_maps is None:
-        features = lambda nodes, mode : feature_modules[mode](
-            torch.autograd.Variable(torch.LongTensor(nodes)+1).cuda())
-    else:
-        features = lambda nodes, mode : feature_modules[mode](
-            torch.autograd.Variable(torch.LongTensor([node_maps[mode][n] for n in nodes])+1).cuda())
-    return features
-
-def _get_perc_scores(scores, lengths):
-    # Percentile of each positive score relative to its own set of negative scores.
-    perc_scores = []
-    cum_sum = 0
-    neg_scores = scores[len(lengths):]
-    for i, length in enumerate(lengths):
-        perc_scores.append(stats.percentileofscore(neg_scores[cum_sum:cum_sum+length], scores[i]))
-        cum_sum += length
-    return perc_scores
-
-def eval_auc_queries(test_queries, enc_dec, batch_size=1000, hard_negatives=False, seed=0):
-    predictions = []
-    labels = []
-    formula_aucs = {}
-    random.seed(seed)
-    for formula in test_queries:
-        formula_labels = []
-        formula_predictions = []
-        formula_queries = test_queries[formula]
-        offset = 0
-        while offset < len(formula_queries):
-            max_index = min(offset+batch_size, len(formula_queries))
-            batch_queries = formula_queries[offset:max_index]
-            if hard_negatives:
-                lengths = [1 for j in range(offset, max_index)]
-                negatives = [random.choice(formula_queries[j].hard_neg_samples) for j in range(offset, max_index)]
-            else:
-                lengths = [1 for j in range(offset, max_index)]
-                negatives = [random.choice(formula_queries[j].neg_samples) for j in range(offset, max_index)]
-            offset += batch_size
-
-            formula_labels.extend([1 for _ in range(len(lengths))])
-            formula_labels.extend([0 for _ in range(len(negatives))])
-            batch_scores = enc_dec.forward(formula,
-                batch_queries+[b for i, b in enumerate(batch_queries) for _ in range(lengths[i])],
-                [q.target_node for q in batch_queries] + negatives)
-            batch_scores = batch_scores.data.tolist()
-            formula_predictions.extend(batch_scores)
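-        # Per-formula AUC over positive targets (label 1) and sampled negatives (label 0).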
-        formula_aucs[formula] = roc_auc_score(formula_labels, np.nan_to_num(formula_predictions))
-        labels.extend(formula_labels)
-        predictions.extend(formula_predictions)
-    overall_auc = roc_auc_score(labels, np.nan_to_num(predictions))
-    return overall_auc, formula_aucs
-
-
-def eval_perc_queries(test_queries, enc_dec, batch_size=1000, hard_negatives=False):
-    perc_scores = []
-    for formula in test_queries:
-        formula_queries = test_queries[formula]
-        offset = 0
-        while offset < len(formula_queries):
-            max_index = min(offset+batch_size, len(formula_queries))
-            batch_queries = formula_queries[offset:max_index]
-            if hard_negatives:
-                lengths = [len(formula_queries[j].hard_neg_samples) for j in range(offset, max_index)]
-                negatives = [n for j in range(offset, max_index) for n in formula_queries[j].hard_neg_samples]
-            else:
-                lengths = [len(formula_queries[j].neg_samples) for j in range(offset, max_index)]
-                negatives = [n for j in range(offset, max_index) for n in formula_queries[j].neg_samples]
-            offset += batch_size
-
-            batch_scores = enc_dec.forward(formula,
-                batch_queries+[b for i, b in enumerate(batch_queries) for _ in range(lengths[i])],
-                [q.target_node for q in batch_queries] + negatives)
-            batch_scores = batch_scores.data.tolist()
-            perc_scores.extend(_get_perc_scores(batch_scores, lengths))
-    return np.mean(perc_scores)
-
-def get_encoder(depth, graph, out_dims, feature_modules, cuda):
-    # depth 0 uses raw feature embeddings; higher depths stack encoders so that each
-    # layer aggregates the previous layer's output embeddings.
-    if depth < 0 or depth > 3:
-        raise Exception("Depth must be between 0 and 3 (inclusive)")
-
-    if depth == 0:
-        enc = DirectEncoder(graph.features, feature_modules)
-    else:
-        aggregator1 = MeanAggregator(graph.features)
-        enc1 = Encoder(graph.features,
-            graph.feature_dims,
-            out_dims,
-            graph.relations,
-            graph.adj_lists, feature_modules=feature_modules,
-            cuda=cuda, aggregator=aggregator1)
-        enc = enc1
-        if depth >= 2:
-            aggregator2 = MeanAggregator(lambda nodes, mode : enc1(nodes, mode).t().squeeze())
-            enc2 = Encoder(lambda nodes, mode : enc1(nodes, mode).t().squeeze(),
-                enc1.out_dims,
-                out_dims,
-                graph.relations,
-                graph.adj_lists, base_model=enc1,
-                cuda=cuda, aggregator=aggregator2)
-            enc = enc2
-        if depth >= 3:
-            aggregator3 = MeanAggregator(lambda nodes, mode : enc2(nodes, mode).t().squeeze())
-            enc3 = Encoder(lambda nodes, mode : enc2(nodes, mode).t().squeeze(),
-                enc2.out_dims,
-                out_dims,
-                graph.relations,
-                graph.adj_lists, base_model=enc2,
-                cuda=cuda, aggregator=aggregator3)
-            enc = enc3
-    return enc
-
-def get_metapath_decoder(graph, out_dims, decoder):
-    if decoder == "bilinear":
-        dec = BilinearMetapathDecoder(graph.relations, out_dims)
-    elif decoder == "transe":
-        dec = TransEMetapathDecoder(graph.relations, out_dims)
-    elif decoder == "bilinear-diag":
-        dec = BilinearDiagMetapathDecoder(graph.relations, out_dims)
-    else:
-        raise Exception("Metapath decoder not recognized.")
-    return dec
-
-def get_intersection_decoder(graph, out_dims, decoder):
-    if decoder == "mean":
-        dec = SetIntersection(out_dims, out_dims, agg_func=torch.mean)
-    elif decoder == "mean-simple":
-        dec = SimpleSetIntersection(agg_func=torch.mean)
-    elif decoder == "min":
-        dec = SetIntersection(out_dims, out_dims, agg_func=torch.min)
-    elif decoder == "min-simple":
-        dec = SimpleSetIntersection(agg_func=torch.min)
-    else:
-        raise Exception("Intersection decoder not recognized.")
-    return dec
-
-def setup_logging(log_file, console=True):
-    logging.basicConfig(level=logging.INFO,
-        format='%(asctime)s - %(levelname)s - %(message)s',
-        filename=log_file,
-        filemode='w')
-    if console:
-        console = logging.StreamHandler()
-        # optional, set the logging level
-        console.setLevel(logging.INFO)
-        # use the same format for the console output
-        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
-        # tell the handler to use this format
-        console.setFormatter(formatter)
-        # add the handler to the root logger
-        logging.getLogger('').addHandler(console)
-    return logging