-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmetrics.py
More file actions
39 lines (33 loc) · 1 KB
/
metrics.py
File metadata and controls
39 lines (33 loc) · 1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""
Source: https://sklearn-crfsuite.readthedocs.io/en/latest/_modules/sklearn_crfsuite/metrics.html
"""
from functools import wraps
from itertools import chain
def flatten(y):
    """
    Concatenate a sequence of sequences into one flat list.

    Items keep their original order: every item of the first inner
    sequence, then every item of the second, and so on.

    >>> flatten([[1,2], [3,4]])
    [1, 2, 3, 4]
    """
    return [item for seq in y for item in seq]
def _flattens_y(func):
    """
    Decorator that flattens ``y_true`` and ``y_pred`` before calling *func*.

    The decorated function receives ``flatten(y_true)`` and
    ``flatten(y_pred)`` in place of the original nested arguments. Any
    further positional or keyword arguments are forwarded unchanged.
    ``functools.wraps`` preserves *func*'s name and docstring on the wrapper.
    """
    @wraps(func)
    def _flat_call(y_true, y_pred, *args, **kwargs):
        # Arguments are evaluated left to right: y_true is flattened first.
        return func(flatten(y_true), flatten(y_pred), *args, **kwargs)

    return _flat_call
@_flattens_y
def flat_classification_report(y_true, y_pred, labels=None, **kwargs):
    """
    Return classification report for sequence items.

    The ``_flattens_y`` decorator flattens ``y_true`` and ``y_pred`` (lists
    of per-sequence label lists) into flat item lists before this body runs.

    :param y_true: gold labels, one list per sequence.
    :param y_pred: predicted labels, one list per sequence.
    :param labels: optional subset/order of labels to include in the report.
        It is forwarded to scikit-learn as the ``labels`` keyword.
    :param kwargs: extra keyword arguments forwarded unchanged to
        ``sklearn.metrics.classification_report``.
    :returns: the report produced by ``sklearn.metrics.classification_report``.
    """
    # Imported lazily so sklearn is only required when this helper is used.
    from sklearn import metrics
    # ``labels`` must be passed by keyword. Recent scikit-learn releases make
    # every parameter after ``y_pred`` keyword-only, so a positional third
    # argument raises TypeError.
    return metrics.classification_report(y_true, y_pred, labels=labels, **kwargs)
@_flattens_y
def flat_f1_score(y_true, y_pred, **kwargs):
    """
    Compute the F1 score over all items of all sequences.

    By the time this body runs, ``_flattens_y`` has already concatenated the
    per-sequence label lists into flat item lists. Every keyword argument is
    passed through untouched to ``sklearn.metrics.f1_score``.
    """
    # Local import keeps scikit-learn an optional dependency of the module.
    from sklearn.metrics import f1_score
    score = f1_score(y_true, y_pred, **kwargs)
    return score