utils.py
import itertools as it
from collections import defaultdict
import networkx as nx
import numpy as np


def load_graphs(config, logger):
    # dataset_dir = args.data_dir+args.dataset+'/'
    dataset_dir = f"{config['data_dir']}/{config['dataset']}/"
    graphs = {}
    graphs['simplicies_train'] = read_simplicies(dataset_dir, mode='train')
    graphs['G_train'] = construct_graph(graphs['simplicies_train'])
    graphs['G_weighted'] = construct_weighted_graph(graphs['simplicies_train'])
    # graphs['simplicies_test'] = read_simplicies(dataset_dir, mode='test')
    # graphs['G_test'] = construct_graph(graphs['simplicies_test'])
    logger.info('Finished loading graphs.')
    logger.info(f"Nodes train: {graphs['G_train'].number_of_nodes()}")
    logger.info(f"Simplices train: {len(graphs['simplicies_train'])}")
    return graphs
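
# Example usage (illustrative sketch): `load_graphs` expects a config mapping
# with the 'data_dir' and 'dataset' keys used above and a standard `logging`
# logger; the directory and dataset names below are hypothetical.
#
#   import logging
#   logging.basicConfig(level=logging.INFO)
#   config = {'data_dir': 'data', 'dataset': 'my_dataset'}
#   graphs = load_graphs(config, logging.getLogger(__name__))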


def read_simplicies(file_dir, mode='train'):
    simplicies = []
    with open(file_dir + '{}.txt'.format(mode), 'r') as f:
        for line in f.readlines():
            simplicies.append(tuple(sorted(set([int(node) for node in line.strip().split(' ')]))))
    print(f'number of simplices: {len(simplicies)}')
    print(f'number of unique simplices: {len(set(simplicies))}')
    try:
        # sanity check
        assert (len(set(simplicies)) == len(simplicies))  # no duplicate
        nodes = set([node for simplex in simplicies for node in simplex])
        assert (min(nodes) == 0)
        assert (max(nodes) == len(nodes) - 1)  # compact indexing
    except AssertionError:
        print('Node index should begin with 0, reindexing the hypergraphs ...')
        all_nodes = sorted(set(n for s in simplicies for n in s))
        node2i = {node: i for i, node in enumerate(all_nodes)}
        simplicies = [tuple(sorted(set([node2i[n] for n in s]))) for s in simplicies]
    return set(simplicies)
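
# Example input format (inferred from the parser above): each line of
# `<mode>.txt` lists the node ids of one simplex, separated by single spaces.
# For a hypothetical file containing
#
#   0 1 2
#   2 3
#   4
#
# read_simplicies would return a set of sorted node-id tuples:
# {(0, 1, 2), (2, 3), (4,)}.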


def construct_graph(simplicies):
    # expand every simplex into a clique over its nodes (the 2-section of the hypergraph)
    G = nx.Graph()
    for s in simplicies:
        if len(s) == 1:
            G.add_node(s[0])
            continue
        for e in it.combinations(s, 2):
            G.add_edge(*e)
    print('number of nodes in constructed graph', G.number_of_nodes())
    return G


def construct_weighted_graph(simplicies):
    # same clique expansion, but each edge weight counts how many simplices contain both endpoints
    G = nx.Graph()
    for s in simplicies:
        if len(s) == 1:
            G.add_node(s[0])
            continue
        for u, v in it.combinations(s, 2):
            if G.has_edge(u, v):
                G[u][v]['weight'] += 1
            else:
                G.add_edge(u, v, weight=1)
    return G
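
# Example (small sketch): two simplices sharing the pair (1, 2) produce an edge
# of weight 2, since the weight counts the simplices containing both endpoints.
#
#   G = construct_weighted_graph({(0, 1, 2), (1, 2, 3)})
#   G[1][2]['weight']   # -> 2
#   G[0][1]['weight']   # -> 1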


def construct_decomposed_graph(simplicies, level=2, limit=7):
    # nodes are the `level`-element subsets of each simplex (simplices larger than
    # `limit` are skipped); two subsets are adjacent when they come from the same simplex
    G = nx.Graph()
    for s in simplicies:
        if len(s) <= limit and len(s) >= level:
            hyper_simplicies = list(it.combinations(s, level))
            if len(hyper_simplicies) == 1:
                G.add_node(hyper_simplicies[0])
                continue
            for e in it.combinations(hyper_simplicies, 2):
                G.add_edge(*e)
    return G
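
# Example (sketch): with level=2, the nodes of the decomposed graph are the
# 2-element subsets of each simplex, connected when they share a simplex.
#
#   G = construct_decomposed_graph({(0, 1, 2)}, level=2)
#   sorted(G.nodes())   # -> [(0, 1), (0, 2), (1, 2)], forming a triangle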


def get_edges(clique):
    return it.combinations(clique, 2)


def get_clique_gain(weighted_adjacency_matrix, clique):
    gain = 0
    edges = get_edges(clique)
    for x, y in edges:
        if weighted_adjacency_matrix[x, y] > 0:
            gain += 2
        else:
            gain -= 2
    return gain
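
# Example (sketch): each clique edge still backed by a positive weight
# contributes +2 and every other edge contributes -2, so for a weighted
# adjacency matrix in which only the pair (0, 1) is positive:
#
#   W = np.zeros((3, 3)); W[0, 1] = W[1, 0] = 1
#   get_clique_gain(W, (0, 1, 2))   # -> 2 - 2 - 2 = -2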


def lazy_clique_edge_cover(
    weighted_adjacency_matrix, clique_candidates, cliques_quota
):
    # build edge table: for each edge, the indices of the candidate cliques containing it
    edge_table = defaultdict(list)
    for idx, clique in enumerate(clique_candidates):
        edges = get_edges(clique)
        for edge in edges:
            edge_table[frozenset(edge)].append(idx)

    # default clique gain: every edge of the clique starts out covered (+2 per edge)
    def get_default_gain(x):
        return len(x) * (len(x) - 1)

    clique_gain = np.array([get_default_gain(clique) for clique in clique_candidates])

    # the greedy process (import probability in here?)
    current_clique_idxs = []
    for i in range(cliques_quota):
        # mask already selected cliques, then pick the candidate with the highest gain
        clique_gain[current_clique_idxs] = -10000
        best_clique_idx = np.argmax(clique_gain)
        current_clique_idxs.append(best_clique_idx)
        best_clique = clique_candidates[best_clique_idx]
        best_clique_edges = list(get_edges(best_clique))
        # update the weighted_adjacency_matrix
        to_update_edges = []
        for edge in best_clique_edges:
            x, y = edge
            weighted_adjacency_matrix[x, y] -= 1
            weighted_adjacency_matrix[y, x] -= 1
            if weighted_adjacency_matrix[x, y] == 0:
                to_update_edges.append(edge)
        # find the related cliques and update their gains
        for edge in to_update_edges:
            edge = frozenset(edge)
            cliques_idx = edge_table[edge]
            for idx in cliques_idx:
                clique_gain[idx] = clique_gain[idx] - 4
    return [clique_candidates[idx] for idx in current_clique_idxs]
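
# Example usage (illustrative sketch, not from the original code): cover the
# edges of the weighted graph with a greedy budget of cliques. The candidate
# list and the quota below are hypothetical; note the matrix is modified in
# place, hence the copy.
#
#   W = nx.to_numpy_array(G_weighted, nodelist=sorted(G_weighted.nodes()))
#   candidates = [tuple(sorted(c)) for c in nx.find_cliques(G_weighted)]
#   chosen = lazy_clique_edge_cover(W.copy(), candidates, cliques_quota=10)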


def get_performance_wrt_ground_truth(reconstructed, ground_truth):
    correct_cliques = reconstructed & ground_truth
    precision = len(correct_cliques) / len(reconstructed) if len(reconstructed) > 0 else 0
    recall = len(correct_cliques) / len(ground_truth)
    f1 = 2 * precision * recall / (precision + recall) if precision * recall > 0 else 0
    jaccard = len(correct_cliques) / len(reconstructed | ground_truth)
    return precision, recall, f1, jaccard
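

# Minimal smoke test (illustrative sketch, not part of the original module):
# builds a tiny hypergraph in memory and runs the clique-cover + evaluation
# pipeline end to end.
if __name__ == '__main__':
    simplices = {(0, 1, 2), (1, 2, 3), (4,)}
    G_w = construct_weighted_graph(simplices)
    W = nx.to_numpy_array(G_w, nodelist=sorted(G_w.nodes()))
    candidates = [tuple(sorted(c)) for c in nx.find_cliques(G_w)]
    reconstructed = set(lazy_clique_edge_cover(W.copy(), candidates, cliques_quota=2))
    print(get_performance_wrt_ground_truth(reconstructed, simplices))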