-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpredict_edges.py
More file actions
100 lines (79 loc) · 3.8 KB
/
predict_edges.py
File metadata and controls
100 lines (79 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import numpy as np
import argparse
import os
from tqdm import trange, tqdm
from train_edges_ood_id import EdgePredictor
def predict_batch(model, ood_embedding, id_embedding, batch_pairs, device):
    """Score a batch of (OOD, ID) index pairs with ``model``.

    For every pair ``(i, j)`` the input row is ``ood_embedding[i]``
    concatenated with ``id_embedding[j]``; all rows are stacked, moved to
    ``device`` and passed through ``model`` with gradients disabled.

    Args:
        model: Callable applied to the stacked pair tensor.
        ood_embedding: Tensor indexed along dim 0 by the first pair element.
        id_embedding: Tensor indexed along dim 0 by the second pair element.
        batch_pairs: Sequence of ``(i, j)`` integer index pairs.
        device: Device the model input is moved to.

    Returns:
        The squeezed model output, guaranteed to have at least one
        dimension (a single-pair batch yields shape ``(1,)``). An empty
        ``batch_pairs`` yields an empty tensor without calling ``model``.
    """
    if len(batch_pairs) == 0:
        return torch.empty(0, device=device)

    # Gather all rows with two indexing ops instead of one torch.cat per pair.
    ood_index = torch.tensor(
        [i for i, _ in batch_pairs], dtype=torch.long, device=ood_embedding.device
    )
    id_index = torch.tensor(
        [j for _, j in batch_pairs], dtype=torch.long, device=id_embedding.device
    )
    batch_pairs_tensor = torch.cat(
        [ood_embedding[ood_index], id_embedding[id_index]], dim=1
    ).to(device)

    with torch.no_grad():
        batch_scores = model(batch_pairs_tensor).squeeze()

    # squeeze() collapses a single-pair batch to a 0-dim tensor; restore the
    # batch dimension so callers can always iterate over the result.
    if batch_scores.ndim == 0:
        batch_scores = batch_scores.unsqueeze(0)
    return batch_scores
def update_edge_scores(edge_scores, batch_scores, batch_pairs):
    """Write a batch of scores into ``edge_scores`` in place.

    ``batch_scores[k]`` is stored at ``edge_scores[i, j]`` where
    ``(i, j) = batch_pairs[k]``.

    Args:
        edge_scores: 2-D tensor that is modified in place.
        batch_scores: 1-D tensor (or sequence) of scores, one per pair.
        batch_pairs: Sequence of ``(i, j)`` integer index pairs. Pairs are
            expected to be unique within a batch; with duplicates the
            surviving value is unspecified.
    """
    if len(batch_pairs) == 0:
        return

    # Align device and dtype once so the write is a single indexed assignment
    # rather than one (possibly cross-device) copy per element.
    values = torch.as_tensor(batch_scores).detach().to(
        device=edge_scores.device, dtype=edge_scores.dtype
    )
    index = torch.as_tensor(batch_pairs, dtype=torch.long, device=edge_scores.device)
    # Only as many pairs as there are scores are written.
    index = index[: values.shape[0]]
    edge_scores[index[:, 0], index[:, 1]] = values
def _top_k_edges(edge_scores, top_k, ood_offset):
    """Turn an (OOD x ID) score matrix into a symmetric top-k edge list.

    For each row ``i`` the ``top_k`` highest-scoring columns ``j`` are
    selected and both ``(ood_offset + i, j)`` and ``(j, ood_offset + i)``
    are emitted.

    Args:
        edge_scores: 2-D score tensor; it is not modified.
        top_k: Number of columns to keep per row (clamped to the number of
            columns).
        ood_offset: Value added to every row index to form its node id.

    Returns:
        Integer array of shape ``(2, E)`` holding the unique edges, sorted
        as produced by ``np.unique(..., axis=0)``.
    """
    num_rows, num_cols = edge_scores.shape
    k = min(top_k, num_cols)
    pairs = []
    for i in range(num_rows):
        # Work on a copy so the raw scores (which get cached) stay intact.
        scores = edge_scores[i].clone()
        # NOTE(review): masking column i for row i looks like a self-loop
        # guard inherited from a same-node-set variant; rows and columns
        # index different node sets here. Kept for output compatibility --
        # confirm whether it is intended.
        if i < num_cols:
            scores[i] = -float('inf')
        for j in torch.topk(scores, k).indices.tolist():
            pairs.append([ood_offset + i, j])
            pairs.append([j, ood_offset + i])
    if not pairs:
        return np.empty((2, 0), dtype=np.int64)
    return np.unique(pairs, axis=0).transpose()


def predict_edges(args):
    """Predict top-k edges between OOD and ID nodes for a dataset.

    Embeddings are read from ``<dataset>/<dataset>_ood_embs.npy`` and
    ``<dataset>/<dataset>_id_embs.npy``. If
    ``<dataset>/<dataset>_edge_scores.npy`` exists it is used as the score
    matrix; otherwise every (OOD, ID) pair is scored with an
    ``EdgePredictor`` whose weights are loaded from
    ``<dataset>/<dataset>_<edge_type>.pth`` and the resulting matrix is
    written to the scores file.

    Args:
        args: Namespace providing ``dataset``, ``top_k``, ``embedding_dim``,
            ``hidden_dim``, ``edge_type``, ``device`` and ``batch_size``.

    Returns:
        Integer array of shape ``(2, E)`` with the unique directed edges.
        ID nodes keep their row index; OOD node ``i`` is numbered
        ``n + i`` where ``n`` is the number of ID nodes. Both code paths
        (cached and freshly computed scores) use this numbering.
    """
    prefix = os.path.join(args.dataset, args.dataset)
    ood_embedding = torch.from_numpy(np.load(prefix + '_ood_embs.npy')).to(torch.float32)
    id_embedding = torch.from_numpy(np.load(prefix + '_id_embs.npy')).to(torch.float32)
    m, n = ood_embedding.shape[0], id_embedding.shape[0]

    scores_path = prefix + '_edge_scores.npy'
    if os.path.exists(scores_path):
        edge_scores = torch.from_numpy(np.load(scores_path))
        return _top_k_edges(edge_scores, args.top_k, n)

    model = EdgePredictor(args.embedding_dim, args.hidden_dim)
    # map_location lets a checkpoint saved on another device load here.
    model.load_state_dict(
        torch.load(prefix + '_' + args.edge_type + '.pth', map_location=args.device)
    )
    model.to(args.device)
    model.eval()

    edge_scores = torch.zeros(m, n)
    for i in trange(m):
        batch_pairs = []
        for j in range(n):
            batch_pairs.append((i, j))
            if len(batch_pairs) == args.batch_size:
                batch_scores = predict_batch(model, ood_embedding, id_embedding, batch_pairs, args.device)
                update_edge_scores(edge_scores, batch_scores, batch_pairs)
                batch_pairs = []
        # Flush the final, partially filled batch of this row.
        if batch_pairs:
            batch_scores = predict_batch(model, ood_embedding, id_embedding, batch_pairs, args.device)
            update_edge_scores(edge_scores, batch_scores, batch_pairs)

    # Cache the unmasked scores so later runs can skip inference.
    np.save(scores_path, edge_scores.numpy())
    return _top_k_edges(edge_scores, args.top_k, n)
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    ``--combine False`` would evaluate to ``True``; this parser reads the
    text instead.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    # Command-line entry point: parse options and run edge prediction.
    parser = argparse.ArgumentParser()
    # Fall back to CPU when CUDA is not available instead of failing later.
    parser.add_argument('--device', type=str,
                        default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--dataset', type=str, default='cora')
    parser.add_argument('--top_k', type=int, default=3)
    parser.add_argument('--embedding_dim', type=int, default=4096)
    parser.add_argument('--batch_size', type=int, default=8192)
    parser.add_argument('--hidden_dim', type=int, default=1024)
    parser.add_argument('--combine', type=_str2bool, default=False)
    parser.add_argument('--edge_type', type=str, default='ood_id')
    args = parser.parse_args()
    predict_edges(args)