-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
106 lines (87 loc) · 4.12 KB
/
evaluate.py
File metadata and controls
106 lines (87 loc) · 4.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
import numpy as np
def calculate_average(file_path, dict):
    """Parse one result log and accumulate per-graph best results.

    Each usable line has (at least) 5 whitespace-separated fields, presumably:
        <graph_id> <acc_mces> <mse_sim> <pred_sim> <true_sim>   -- TODO confirm
    Extra trailing fields are ignored; shorter/blank lines are skipped.

    Args:
        file_path: path to a result log file.
        dict: mapping graph_id -> (acc, mse, pred_sim, true_sim), updated in
            place with the best (highest-acc) entry seen for each graph id.
            (Parameter name kept for caller compatibility.)

    Returns:
        (mean_acc, mean_mse, dict). On a file with no usable lines returns
        (0, 0, dict) — a 3-tuple in every case, so callers can always unpack.
    """
    results = dict  # local alias; avoids re-using the shadowed builtin name
    total_acc_mces = 0.0
    total_mse_sim = 0.0
    count = 0
    with open(file_path, 'r') as file:
        for line in file:
            fields = line.split()[:5]  # drop any trailing extra columns
            if len(fields) != 5:
                continue  # malformed or blank line
            graph_id = int(fields[0])
            acc, mse, pred_sim, true_sim = (float(x) for x in fields[1:])
            total_acc_mces += acc
            total_mse_sim += mse
            count += 1
            if graph_id not in results:
                results[graph_id] = (acc, mse, pred_sim, true_sim)
            elif acc > results[graph_id][0]:
                # New best accuracy wins; keep the smaller mse seen so far.
                # (acc > stored acc, so max() of the original was redundant.)
                results[graph_id] = (acc, min(results[graph_id][1], mse),
                                     pred_sim, true_sim)
    if count == 0:
        # Bug fix: original returned a bare 0 here, which broke the 3-value
        # unpacking in calculate_average_multi.
        return 0, 0, results
    return total_acc_mces / count, total_mse_sim / count, results
def evaluate_retrieval(dict):
    """Compute retrieval metrics over consecutive blocks of 100 queries.

    Entries are sorted by graph id; each consecutive run of 100 entries is
    treated as one retrieval task (presumably 100 candidates per query —
    TODO confirm against the log producer). Trailing entries that do not
    fill a complete block of 100 are ignored.

    Args:
        dict: mapping graph_id -> (acc, mse, pred_sim, true_sim), as built
            by calculate_average. (Parameter name kept for compatibility.)

    Returns:
        (mean MRR, mean P@10, mean MAP) over the blocks, or (0, 0, 0) when
        there are fewer than 100 entries (previously a ZeroDivisionError).
    """
    ordered = {k: dict[k] for k in sorted(dict.keys())}
    # Hoisted out of the loop: the original rebuilt both full lists on every
    # iteration, which was accidentally O(n^2).
    pred_all = [v[2] for v in ordered.values()]
    true_all = [v[3] for v in ordered.values()]
    num = len(ordered) // 100
    if num == 0:
        return 0, 0, 0
    mrr_total = 0.0
    p10_total = 0.0
    map_total = 0.0
    for i in range(num):
        pred_sim = pred_all[i * 100:(i + 1) * 100]
        true_sim = true_all[i * 100:(i + 1) * 100]
        pred = np.array(pred_sim)
        target = np.array(true_sim)
        # First index achieving the maximum true similarity is treated as the
        # single relevant item for MRR.
        chosen_index = np.where(target == target.max())[0][0]
        # 1-based rank of that item in the predicted ordering.
        rank = np.where(np.argsort(pred)[::-1] == chosen_index)[0][0] + 1
        mrr_total += 1 / rank
        # Binarize relevance at 0.5 for P@10 and MAP.
        p10_total += precision_at_k(pred, (target > 0.5).astype(np.int8))
        map_total += calculate_map(torch.tensor(pred_sim).unsqueeze(0),
                                   (torch.tensor(true_sim) > 0.5).unsqueeze(0))
    return mrr_total / num, p10_total / num, map_total / num
def calculate_average_multi(file_paths):
    """Merge per-graph best results across several log files and average them.

    Args:
        file_paths: iterable of result-log paths, fed to calculate_average
            one after another into a single shared mapping.

    Returns:
        (mean best acc, mean best mse, merged dict). Returns (0, 0, {}) when
        no file yielded any usable line (previously a ZeroDivisionError).
    """
    merged = {}
    for file_path in file_paths:
        _, _, merged = calculate_average(file_path, merged)
    if not merged:
        return 0, 0, merged
    acc_values = [v[0] for v in merged.values()]
    mse_values = [v[1] for v in merged.values()]
    return sum(acc_values) / len(merged), sum(mse_values) / len(merged), merged
def precision_at_k(y_pred, y_true, k=10):
    """Precision@k: fraction of the k highest-scored items that are relevant.

    Args:
        y_pred: 1-D array of predicted scores.
        y_true: 1-D array of binary relevance labels, aligned with y_pred.
        k: cutoff rank (default 10); the divisor is always k.

    Returns:
        Relevant-count among the top-k predictions divided by k.
    """
    ranking = np.argsort(y_pred)[::-1]  # indices, highest score first
    hits = y_true[ranking[:k]]
    return np.sum(hits) / k
def calculate_precision_at_k(pred_relevance, true_relevance, k):
    """Batched precision@k with torch.

    Args:
        pred_relevance: (batch, num_docs) tensor of predicted scores.
        true_relevance: (batch, num_docs) tensor of relevance labels.
        k: cutoff rank; clipped to num_docs when larger, but the divisor
            stays at the requested k.

    Returns:
        (batch,) tensor of precision@k values.
    """
    num_docs = pred_relevance.size(1)
    top_idx = torch.topk(pred_relevance, k=min(k, num_docs), dim=1).indices
    hits = torch.gather(true_relevance, 1, top_idx)
    return hits.sum(dim=1) / k
def calculate_ap(pred_relevance, true_relevance):
    """Average precision (AP) per row of a batch.

    Args:
        pred_relevance: (batch, num_docs) tensor of predicted scores.
        true_relevance: (batch, num_docs) tensor of binary relevance
            (0/1 or bool).

    Returns:
        (batch,) tensor of AP scores; rows with no relevant documents get 0.

    Note:
        The original computed precision@k with a separate top-k call for
        every rank k (O(n^2)). AP only needs precision along the predicted
        ordering, so a single sort plus a cumulative sum gives the same
        result in O(n log n).
    """
    num_relevant = torch.sum(true_relevance, dim=1)
    num_docs = pred_relevance.size(1)
    # Relevance reordered by descending predicted score.
    _, pred_indices = torch.sort(pred_relevance, dim=1, descending=True)
    sorted_relevance = torch.gather(true_relevance, 1, pred_indices)
    # precision@k for every rank k in one pass: cumulative hits / rank.
    ranks = torch.arange(1, num_docs + 1,
                         dtype=pred_relevance.dtype,
                         device=pred_relevance.device)
    precisions = torch.cumsum(sorted_relevance, dim=1) / ranks
    # Sum precision only at the ranks where a relevant doc appears;
    # clamp avoids 0/0 for rows with no relevant documents.
    ap = torch.sum(precisions * sorted_relevance, dim=1) / torch.clamp(num_relevant, min=1)
    return ap
def calculate_map(pred_relevance, true_relevance):
    """Mean average precision: per-row AP averaged over the batch dimension."""
    return torch.mean(calculate_ap(pred_relevance, true_relevance))
if __name__ == "__main__":
    # Dataset selection: uncomment the log set to aggregate.
    # file_paths = ['result_MCF-7_1.log', 'result_MCF-7_2.log', 'result_MCF-7_3.log']
    file_paths = ['result_AIDS_1.log', 'result_AIDS_2.log', 'result_AIDS_3.log']
    # file_paths = ['result_MOLHIV_1.log', 'result_MOLHIV_2.log', 'result_MOLHIV_3.log']
    avg_acc, avg_mse, best_results = calculate_average_multi(file_paths)
    print(f"The average acc is: {avg_acc}, the average mse is: {avg_mse}")
    # evaluate retrieval
    # mrr, p10, map = evaluate_retrieval(best_results)
    # print(f"MRR: {mrr:.5f}, P@10: {p10:.5f}, MAP: {map:.5f}")