forked from iyempissy/privGnn
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
79 lines (68 loc) · 2.15 KB
/
utils.py
File metadata and controls
79 lines (68 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from __future__ import absolute_import
import os
import sys
import errno
import pickle
import shutil
import json
import os.path as osp
import numpy as np
import torch
import scipy
sys.path.append('.')
from sklearn.decomposition import PCA, KernelPCA
def mkdir_if_missing(directory):
    """Create ``directory`` (and any missing parents) if it does not exist.

    An empty string refers to the current working directory, which always
    exists, so it is treated as a no-op instead of being handed to
    ``os.makedirs`` (which raises ``FileNotFoundError`` for ``''``). This
    matters because ``osp.dirname('file.ext')`` returns ``''``.

    Args:
        directory: Path of the directory to create.

    Raises:
        OSError: If creation fails for any reason other than the directory
            already existing (e.g. permission denied).
    """
    if directory == '':
        return
    if not osp.exists(directory):
        try:
            os.makedirs(directory)
        except OSError as e:
            # The directory may appear between the exists() check and
            # makedirs() (e.g. created concurrently); that is harmless, so
            # only genuine failures are re-raised.
            if e.errno != errno.EEXIST:
                raise
def pca(teacher, student, n_components=200):
    """Project ``teacher`` and ``student`` onto principal axes fitted on ``teacher``.

    A scikit-learn ``PCA`` is fitted on ``teacher`` only; both inputs are then
    right-multiplied by the transposed component matrix. Unlike
    ``PCA.transform``, the fitted mean is NOT subtracted first, so the results
    are uncentered projections (behavior kept as-is).

    Args:
        teacher: 2-D array-like used to fit the PCA (rows are samples, as
            required by ``PCA.fit``).
        student: 2-D array-like with the same number of columns as ``teacher``.
        n_components: Number of principal components to keep. Defaults to
            200, the value previously hard-coded.

    Returns:
        Tuple ``(student_projected, teacher_projected)``; note the order is
        the reverse of the argument order.
    """
    model = PCA(n_components=n_components)
    model.fit(teacher)
    # components_ is (n_components, n_features); its transpose maps
    # feature space onto the retained principal axes.
    projection = model.components_.T
    teacher = np.dot(teacher, projection)
    student = np.dot(student, projection)
    return student, teacher
def Hamming_Score(y_true, y_pred, torch=True, cate=False):
    """Return the mean per-sample fraction of positions where labels agree.

    For each sample ``i`` the score is
    ``count(y_true[i] == y_pred[i]) / len(y_true[i])``; the mean of those
    per-sample scores is returned (``np.mean`` of an empty list, i.e. ``nan``,
    when there are no samples).

    Args:
        y_true: Indexable collection of ground-truth label vectors.
        y_pred: Predicted label vectors, indexed like ``y_true``.
        torch: True means ``y_pred`` holds torch tensors (each row is cast
            to double before comparing); False means numpy arrays. Note that
            this parameter shadows the ``torch`` module inside the function.
        cate: Unused; kept so existing callers keep working.

    Returns:
        The mean per-sample agreement ratio.
    """
    acc_list = []
    for i in range(len(y_true)):
        if torch:
            summary = y_true[i] == y_pred[i].double()
            # Tensor.sum() works for tensors on any device; the previous
            # ``.numpy()`` conversion raised for CUDA tensors.
            num = int(summary.sum())
        else:
            summary = y_true[i] == y_pred[i]
            num = np.sum(summary)
        acc_list.append(num / float(len(y_true[i])))
    return np.mean(acc_list)
def hamming_precision(y_true, y_pred, torch=True, cate=True):
    """Return the mean per-sample Jaccard index of the positive (== 1) labels.

    Despite the name, this is not precision. For each sample, with ``T`` and
    ``P`` the index sets where ``y_true[i]`` / ``y_pred[i]`` equal 1, the score
    is ``|T & P| / |T | P|``, and 1 when both sets are empty. The mean over
    samples is returned (``nan`` when there are no samples).

    Args:
        y_true: Ground-truth label matrix, one row per sample (torch tensor
            when ``torch`` is True, otherwise a numpy array).
        y_pred: Predicted label matrix in the same format as ``y_true``.
        torch: True means the inputs are torch tensors; they are detached and
            moved to CPU before conversion to numpy, so grad-tracking or CUDA
            tensors are accepted. This parameter shadows the ``torch`` module
            inside the function.
        cate: Unused; kept so existing callers keep working.

    Returns:
        The mean per-sample Jaccard score.
    """
    if torch:
        y_true = y_true.detach().cpu().numpy()
        y_pred = y_pred.detach().cpu().numpy()
    acc_list = []
    for i, row_true in enumerate(y_true):
        set_true = set(np.where(row_true == 1)[0])
        set_pred = set(np.where(y_pred[i] == 1)[0])
        if not set_true and not set_pred:
            # Nothing positive in either: count as a perfect match.
            acc_list.append(1)
        else:
            acc_list.append(
                len(set_true & set_pred) / float(len(set_true | set_pred))
            )
    return np.mean(acc_list)
def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
    """Serialize ``state`` to ``fpath`` and optionally keep a "best" copy.

    Args:
        state: Object passed unchanged to ``torch.save``.
        is_best: If True, the written file is also copied to
            ``best_model.pth.tar`` in the same directory as ``fpath``.
        fpath: Destination path; its parent directory is created if missing.
    """
    dirname = osp.dirname(fpath)
    # A bare filename (such as the default) has dirname '' -- the current
    # directory -- and os.makedirs('') raises, so only create real dirs.
    if dirname:
        mkdir_if_missing(dirname)
    torch.save(state, fpath)
    if is_best:
        shutil.copy(fpath, osp.join(dirname, 'best_model.pth.tar'))