-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy path get_sim.py
More file actions
63 lines (51 loc) · 2.18 KB
/
get_sim.py
File metadata and controls
63 lines (51 loc) · 2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import json
import logging
import math
import os
from collections import defaultdict
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from models.baseModel import baseModel
from util.dataloader import relationDataset
# Compute a relation-by-relation similarity matrix with a trained baseModel and
# write it to <output>/kl_prob.json and <output>/kl_prob.txt.

# Command-line interface; parsed into a plain dict so options are read as args['name'].
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', required=True, help='path to directory that contains model file')
parser.add_argument('--input', required=True, help='path to directory that contains input files')
parser.add_argument('--output', required=True, help='path to directory that output files will be stored')
# type=int: without it a value supplied on the command line arrives as a str and
# breaks both eval_step() and the division below (only the default was an int).
parser.add_argument('--sample_num', type=int, default=20,
                    help='how many samples are used when approximating KL divergence')
parser.add_argument('--gpu', default='0', help='ID of the gpu you want to assign')
args = vars(parser.parse_args())
if args['sample_num'] <= 0:
    # Scores are divided by sample_num, so zero/negative values are meaningless.
    parser.error('--sample_num must be a positive integer')

# Set before any CUDA call below so only the requested GPU is visible to torch.
os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu']
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s : %(message)s')

# Creates missing parent directories too, and does not fail if the directory exists.
os.makedirs(args['output'], exist_ok=True)

# info.json provides the model configuration, including 'tot_rel' and 'tot_ent'.
with open(os.path.join(args['model_path'], 'info.json'), 'r') as f:
    info = json.load(f)

logging.info("Loading model")
model = baseModel(info, info['tot_rel'], info['tot_ent'])
model_file = os.path.join(args['model_path'], 'model.pth')
# Fail with an explicit message instead of relying on torch.load's own error.
if not os.path.isfile(model_file):
    raise IOError('Model file not found.')
# Load tensors onto the CPU first; the model is moved to the GPU right after.
# This avoids device-index mismatches when the checkpoint was saved on another GPU.
model.load_state_dict(torch.load(model_file, map_location='cpu'))
model.cuda()

# NOTE(review): `dataset` is not used anywhere below in this script; it is kept in case
# the relationDataset constructor is relied on (e.g. to validate the input files) —
# confirm and remove if not needed.
dataset = relationDataset(os.path.join(args['input'], 'train2id.txt'),
                          os.path.join(args['input'], 'entity2id.txt'),
                          os.path.join(args['input'], 'relation2id.txt'))

logging.info("Calculating similarity")
sample_num = args['sample_num']
scores = []
for i in tqdm(range(model.tot_rel)):
    # Presumably eval_step returns one row of scores accumulated over sample_num
    # samples (hence the division to average) — confirm against baseModel.eval_step.
    score = model.eval_step(i, sample_num)
    scores.append((score / sample_num).tolist())

# Symmetrise: entry (a, b) becomes max(scores[a][b], scores[b][a]).
scores_tensor = torch.Tensor(scores)
scores = torch.max(scores_tensor, scores_tensor.transpose(1, 0)).tolist()

with open(os.path.join(args['output'], "kl_prob.json"), 'w') as f:
    json.dump(scores, f)

# Text dump: one line per row, the diagonal entry skipped, every value followed by a space.
with open(os.path.join(args['output'], 'kl_prob.txt'), 'w') as f:
    for idx1, row in enumerate(scores):
        f.write(''.join(str(num) + ' ' for idx2, num in enumerate(row) if idx2 != idx1))
        f.write('\n')