-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
132 lines (97 loc) · 3.98 KB
/
utils.py
File metadata and controls
132 lines (97 loc) · 3.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu Feb 15 15:04:21 2018
@author: Samuele Garda
"""
import os
import matplotlib.pyplot as plt
import itertools
import numpy as np
from tweets_classification import HierarchicalClassifier
from sklearn.linear_model import LogisticRegression
from sklearn import svm
# TAKEN FROM :
# http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html#sphx-glr-auto-examples-model-selection-plot-confusion-matrix-py
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Draw a confusion matrix on the current matplotlib figure.
    Normalization can be applied by setting `normalize=True`.

    The function only draws (image, colorbar, ticks, per-cell text and axis
    labels); it neither shows nor saves the figure and returns nothing.

    :params:
        cm (np.ndarray) : confusion matrix, rows = true labels, columns = predicted labels
        classes (sequence) : class names used as tick labels on both axes
        normalize (bool) : normalize counts by # of class instance (row sums)
        title (string) : plot title
        cmap (matplotlib.colors.LinearSegmentedColormap) : color map for image
    """
    cm = np.asarray(cm)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # A class without any true instance has a zero row sum. Keep such a
        # row at 0 instead of dividing by zero, which would fill it with NaN
        # and also turn the text-colour threshold below into NaN.
        cm = np.divide(cm.astype('float'), row_totals,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_totals != 0)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()

    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Normalized cells are fractions, raw cells are integer counts.
    fmt = '.2f' if normalize else 'd'
    # Cells darker than half of the maximum get white text for readability.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # x is the column (predicted), y is the row (true).
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
def store_hyperparameters(clf, text):
    """
    Append a line describing the hyperparameters of a classifier to `text`.

    A `HierarchicalClassifier` contributes one line per classifier found in
    its `clfs` attribute; any other classifier contributes at most one line.
    Classifiers that are neither linear SVM, logistic regression nor SVC add
    nothing. The list is modified in place and nothing is returned.

    :params:
        clf (classifier) : classifier
        text (list) : list to which append infos
    """
    def describe(model):
        # Linear models expose only the regularization strength C.
        if isinstance(model, (svm.LinearSVC, LogisticRegression)):
            text.append("Hyperparameters: C = {}\n".format(model.C))
        # Kernel SVM additionally reports gamma.
        elif isinstance(model, svm.SVC):
            text.append("Hyperparameters: C = {}, gamma = {}\n".format(model.C, model.gamma))

    # Treat a plain classifier as a one-element collection so that both cases
    # go through the same loop.
    members = clf.clfs if isinstance(clf, HierarchicalClassifier) else [clf]
    for member in members:
        describe(member)
def by_class_error_analysis(df,y_true,y_pred,limit,error,out_path):
    """
    Write to file randomly selected False Positive or False Negative. For multiclass FP,FN are estimated in one-vs-all.

    The output file is `error.FP` or `error.FN` inside `out_path` and is
    overwritten on every call. For each label found in `y_true` a header line
    (upper-cased label) is written, followed by either `No FP`/`No FN` or one
    line per selected misclassified tweet. When a label has more than `limit`
    errors, `limit` distinct ones are drawn at random.

    :params:
        df (pandas.DataFrame) : data set having `toks` column; rows are matched to `y_true`/`y_pred` by position
        y_true (array) : original labels
        y_pred (array) : predicted labels
        limit (int) : # of FP/FN to be printed
        error (str) : type of misclassification. Choices : `FP` (false positive), `FN` (false negative)
        out_path (str) : folder where errors will be saved
    :raises:
        AssertionError : if `error` is not one of the allowed choices
    """
    errors = ['FP','FN']
    # Raised explicitly rather than through `assert` so the check is not
    # stripped by `python -O`; the exception type is the one `assert` raised.
    if error not in errors:
        raise AssertionError("Invalid error choice! Received `{}` : choose from `{}`! ".format(error,errors))

    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    unique_labels = np.unique(y_true)

    # `with` guarantees the file is closed even if writing fails midway.
    with open(os.path.join(out_path, 'error.{}'.format(error)),'w+') as out_file:
        for label in unique_labels:
            out_file.write("{}\n".format(str(label).upper()))

            # One-vs-all view of the current label.
            if error == 'FP':
                mask = (y_true != label) & (y_pred == label)
            else:
                mask = (y_true == label) & (y_pred != label)
            error_idx = np.where(mask)[0]  # positions of the misclassified items

            if len(error_idx) < 1:
                out_file.write("No {}\n".format(error))
                continue

            if len(error_idx) > limit:
                # replace=False: never report the same tweet twice.
                error_idx = np.random.choice(error_idx, size=limit, replace=False)

            for e_idx in error_idx:
                # e_idx is a position in y_true/y_pred, so the matching row is
                # fetched positionally; a label lookup breaks on any
                # DataFrame whose index is not 0..n-1.
                tokens = df.toks.iloc[e_idx]
                out_file.write("Tweet : `{}` - true : `{}` - pred : `{}`\n".format(' '.join(tokens),y_true[e_idx],y_pred[e_idx]))