-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathchalleng_score.py
More file actions
143 lines (117 loc) · 4.82 KB
/
challeng_score.py
File metadata and controls
143 lines (117 loc) · 4.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import numpy as np
def is_number(x):
    """Return True when ``float(x)`` succeeds and False otherwise.

    Only ``ValueError`` (e.g. ``'abc'``) and ``TypeError`` (e.g. ``None``)
    are treated as "not a number"; any other exception propagates.
    Strings such as ``'nan'`` and ``'inf'`` count as numbers.
    """
    try:
        float(x)
    except (TypeError, ValueError):
        return False
    else:
        return True
def is_finite_number(x):
    """Return whether ``x`` converts to a float that is neither NaN nor infinite.

    Values that ``float`` rejects with ``ValueError`` or ``TypeError`` yield
    ``False``; otherwise the result of ``np.isfinite`` on the converted value
    is returned.
    """
    try:
        as_float = float(x)
    except (TypeError, ValueError):
        return False
    return np.isfinite(as_float)
# Load a table with row and column names.
def load_table(table_file):
    """Load a comma-separated table that has row and column labels.

    The table should have the following form:

        , a, b, c
        a, 1.2, 2.3, 3.4
        b, 4.5, 5.6, 6.7
        c, 7.8, 8.9, 9.0

    The first line holds the column labels (its first cell is ignored) and
    the first cell of every other line holds that row's label. Whitespace
    around every cell is stripped and blank lines are skipped.

    Args:
        table_file: Path of the table to read.

    Returns:
        A tuple ``(rows, cols, values)``: ``rows`` is the list of row labels
        (first column of the data lines, in file order), ``cols`` is the list
        of column labels (header line, in file order), and ``values`` is a
        float64 array of shape ``(len(rows), len(cols))``. Cells that are not
        finite numbers are stored as NaN.

    Raises:
        ValueError: If the table has no data rows or no data columns, or if
            its lines do not all have the same number of cells.
    """
    table = []
    with open(table_file, 'r') as f:
        for line in f:
            # Skip blank lines (e.g. a trailing newline at end of file); they
            # would otherwise become one-cell rows and break the parsing below.
            if not line.strip():
                continue
            table.append([cell.strip() for cell in line.split(',')])

    # Define the numbers of rows and columns and check for errors.
    num_rows = len(table) - 1
    if num_rows < 1:
        raise ValueError('The table {} is empty.'.format(table_file))
    # Check every line, including the header and the last data row.
    row_lengths = set(len(row) - 1 for row in table)
    if len(row_lengths) != 1:
        raise ValueError('The table {} has rows with different lengths.'.format(table_file))
    num_cols = row_lengths.pop()
    if num_cols < 1:
        raise ValueError('The table {} is empty.'.format(table_file))

    # Row labels come from the first column of the data lines and column
    # labels from the header, so rows[i] labels values[i, :] and cols[j]
    # labels values[:, j] -- also for non-square tables.
    rows = [row[0] for row in table[1:]]
    cols = table[0][1:]

    # Find the entries of the table; non-numeric or non-finite cells stay NaN.
    values = np.full((num_rows, num_cols), np.nan, dtype=np.float64)
    for i, row in enumerate(table[1:]):
        for j, value in enumerate(row[1:]):
            if is_finite_number(value):
                values[i, j] = float(value)
    return rows, cols, values
# Load weights.
def load_weights(weight_file):
    """Load the Challenge weight matrix and its class labels.

    Each row/column label may list several equivalent classes separated by
    ``|``; every label is split into a set of those class strings.

    Args:
        weight_file: Path of the comma-separated weight table, in the format
            read by ``load_table``.

    Returns:
        A tuple ``(classes, weights)`` where ``classes`` is a list of sets of
        equivalent class labels and ``weights`` is the value matrix returned
        by ``load_table``.

    Raises:
        ValueError: If the row labels and the column labels of the table
            differ (after splitting into sets).
    """
    # Load the table with the weight matrix.
    rows, cols, values = load_table(weight_file)

    # Split the equivalent classes.
    rows = [set(row.split('|')) for row in rows]
    cols = [set(col.split('|')) for col in cols]
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let a mislabeled matrix through silently.
    if rows != cols:
        raise ValueError(
            'The weight table {} has different row and column labels.'.format(weight_file))

    # Identify the classes and the weight matrix.
    classes = rows
    weights = values
    return classes, weights
# Compute a modified confusion matrix for multi-class, multi-label tasks.
def compute_modified_confusion_matrix(labels, outputs):
    """Compute a binary multi-class, multi-label confusion matrix.

    Rows of the result correspond to the labels and columns to the outputs.
    For each recording, every pair (positive label ``j``, positive output
    ``k``) adds ``1 / n`` to ``A[j, k]``, where ``n`` is the number of classes
    that are positive in the labels or the outputs of that recording (at
    least 1). Entries are interpreted by truthiness, so any non-zero value
    counts as positive.

    Args:
        labels: Array-like of shape ``(num_recordings, num_classes)``.
        outputs: Array-like with the same shape as ``labels``.

    Returns:
        A float64 array of shape ``(num_classes, num_classes)``.

    Raises:
        ValueError: If ``labels`` is not 2-D or ``outputs`` has a different
            shape.
    """
    labels = np.asarray(labels, dtype=bool)
    outputs = np.asarray(outputs, dtype=bool)
    if labels.ndim != 2:
        raise ValueError('labels must be a 2-D array of shape (num_recordings, num_classes).')
    if outputs.shape != labels.shape:
        raise ValueError('outputs must have the same shape as labels.')

    # Number of positive labels and/or outputs per recording; never below 1
    # so recordings with nothing positive do not divide by zero.
    normalization = np.maximum(np.sum(labels | outputs, axis=1), 1).astype(np.float64)

    # Vectorized form of: for each recording i and each positive label j,
    # add 1/normalization[i] to A[j, k] for every positive output k.
    weighted_labels = labels / normalization[:, np.newaxis]
    return weighted_labels.T @ outputs.astype(np.float64)
# Compute the evaluation metric for the Challenge.
def compute_challenge_metric(weights, labels, outputs, classes, sinus_rhythm):
    """Return the normalized Challenge score of ``outputs`` against ``labels``.

    Three scores are computed, each as the NaN-ignoring sum of ``weights``
    times a modified confusion matrix:

    * observed: the model's ``outputs``;
    * correct: a model that always outputs exactly ``labels``;
    * inactive: a model that always outputs only the sinus rhythm class.

    The result is ``(observed - inactive) / (correct - inactive)``, or
    ``0.0`` when the correct and inactive scores are equal.

    Args:
        weights: Weight matrix multiplied element-wise with each confusion matrix.
        labels: Reference labels of shape ``(num_recordings, num_classes)``.
        outputs: Model outputs with the same shape as ``labels``.
        classes: Sequence of classes; ``classes.index(sinus_rhythm)`` gives the
            column used for the sinus rhythm class.
        sinus_rhythm: The entry of ``classes`` that represents sinus rhythm.

    Raises:
        ValueError: If ``sinus_rhythm`` is not in ``classes``.
    """
    n_recordings, n_classes = np.shape(labels)
    if sinus_rhythm not in classes:
        raise ValueError('The sinus rhythm class is not available.')
    sinus_column = classes.index(sinus_rhythm)

    def weighted_score(predictions):
        # NaN weights are skipped by nansum, i.e. they contribute nothing.
        confusion = compute_modified_confusion_matrix(labels, predictions)
        return np.nansum(weights * confusion)

    observed_score = weighted_score(outputs)
    correct_score = weighted_score(labels)

    always_sinus = np.zeros((n_recordings, n_classes), dtype=np.bool_)
    always_sinus[:, sinus_column] = True
    inactive_score = weighted_score(always_sinus)

    if correct_score == inactive_score:
        return 0.0
    return float(observed_score - inactive_score) / float(correct_score - inactive_score)
def evaluate_model(labels, binary_outputs, *, weights_file='weights.csv', sinus_rhythm=None):
    """Load the Challenge weights and score ``binary_outputs`` against ``labels``.

    Args:
        labels: Reference labels, one row per recording and one column per
            class. The column order presumably has to match the class order
            in the weight file -- confirm with the caller.
        binary_outputs: Model outputs with the same layout as ``labels``.
        weights_file: Path of the weight table. Defaults to ``'weights.csv'``,
            which is resolved relative to the current working directory.
        sinus_rhythm: Set of SNOMED CT codes identifying the sinus rhythm
            class. Defaults to ``{'426783006'}``.

    Returns:
        A tuple ``(classes, challenge_metric)`` with the classes loaded from
        the weight file and the normalized Challenge score.
    """
    # Identify the weights and the SNOMED CT code for the sinus rhythm class.
    if sinus_rhythm is None:
        sinus_rhythm = set(['426783006'])

    # Load the scored classes and the weights for the Challenge metric.
    print('Loading weights...')
    classes, weights = load_weights(weights_file)

    # Evaluate the model by comparing the labels and outputs.
    print('Evaluating model...')
    print('- Challenge metric...')
    challenge_metric = compute_challenge_metric(weights, labels, binary_outputs, classes, sinus_rhythm)
    print('Done.')

    # Return the results.
    return classes, challenge_metric