-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfeature_selection.py
More file actions
51 lines (41 loc) · 1.86 KB
/
feature_selection.py
File metadata and controls
51 lines (41 loc) · 1.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""
Explore scikit-learn's feature selection capabilities.
"""
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import VarianceThreshold, SelectKBest, f_classif, SelectFdr, SelectFpr, RFECV, RFE
# from load_data import LoadData
import pandas as pd
def _select_features(df, selector=VarianceThreshold, **kwargs):
    """Fit a scikit-learn selector on a DataFrame and keep what it supports.

    The selector is fitted on ``df`` alone (no labels are passed to
    ``fit``), so it must be a selector that can be fitted without a target.

    Arguments:
        df {pandas.DataFrame} -- The DataFrame whose columns are the
            candidate features.

    Keyword Arguments:
        selector {scikit-learn feature selector} -- The selector class to
            instantiate (default: {VarianceThreshold})
        **kwargs -- Forwarded unchanged to the selector's constructor.

    Returns:
        pandas.DataFrame -- ``df`` restricted to the columns that the
            fitted selector reports through ``get_support``.
    """
    unfitted = selector(**kwargs)
    # Use the object returned by ``fit`` rather than assuming it is the
    # same instance that was constructed.
    fitted = unfitted.fit(df)
    return _supported_cols(df, fitted)
def variance(df, threshold=0.125):
    """Select the columns of ``df`` using a ``VarianceThreshold`` selector.

    Arguments:
        df {pandas.DataFrame} -- The DataFrame whose columns are the
            candidate features.

    Keyword Arguments:
        threshold {float} -- Passed to ``VarianceThreshold`` as its
            ``threshold`` argument (default: {0.125})

    Returns:
        The result of ``_select_features`` for this selector, i.e. the
        columns of ``df`` that the fitted selector supports.
    """
    return _select_features(
        df,
        selector=VarianceThreshold,
        threshold=threshold,
    )
def univariate(features, labels, method=SelectKBest, metric=f_classif, **kwargs):
    """Select feature columns with a univariate scikit-learn selector.

    Arguments:
        features {pandas.DataFrame} -- The candidate feature columns.
        labels -- Label data; every row is collapsed into a single string
            by ``_squash_columns`` before fitting.

    Keyword Arguments:
        method {scikit-learn feature selector} -- The selector class; it is
            called with ``metric`` as its first positional argument
            (default: {SelectKBest})
        metric {callable} -- The scoring function handed to ``method``
            (default: {f_classif})
        **kwargs -- Forwarded unchanged to the ``method`` constructor.

    Returns:
        pandas.DataFrame -- ``features`` restricted to the columns that
            the fitted selector supports.
    """
    target = _squash_columns(labels)
    scorer = method(metric, **kwargs)
    fitted = scorer.fit(features, target)
    return _supported_cols(features, fitted)
def elimination(features, labels, classifier, eliminator=RFE, **kwargs):
    """Select feature columns with an estimator-driven eliminator.

    Arguments:
        features {pandas.DataFrame} -- The candidate feature columns.
        labels -- Label data; every row is collapsed into a single string
            by ``_squash_columns`` before fitting.
        classifier -- The estimator instance handed to ``eliminator`` as
            its first positional argument.

    Keyword Arguments:
        eliminator {scikit-learn feature selector} -- The elimination
            class to instantiate (default: {RFE})
        **kwargs -- Forwarded unchanged to the ``eliminator`` constructor.

    Returns:
        pandas.DataFrame -- ``features`` restricted to the columns that
            the fitted eliminator supports.
    """
    target = _squash_columns(labels)
    ranker = eliminator(classifier, **kwargs)
    fitted = ranker.fit(features, target)
    return _supported_cols(features, fitted)
def _squash_columns(labels):
    """Collapse every row of ``labels`` into one comma-separated string.

    ``labels`` is first wrapped in a ``pandas.DataFrame``. Within each row,
    missing values are replaced by the text ``"none"``, every value is
    converted to ``str``, and the values are joined with ``","``.

    Arguments:
        labels -- Anything accepted by the ``pandas.DataFrame`` constructor.

    Returns:
        The result of the row-wise ``DataFrame.apply``: one joined string
        per input row.
    """

    def _join_row(row):
        # Fill before the string conversion so that missing entries become
        # "none" instead of the text form of NaN/None.
        return ",".join(row.fillna("none").astype(str))

    frame = pd.DataFrame(labels)
    return frame.apply(_join_row, axis=1)
def _supported_cols(features, selected):
    """Return the columns of ``features`` that a fitted selector supports.

    Arguments:
        features {pandas.DataFrame} -- The DataFrame to take columns from.
        selected -- An object exposing ``get_support(indices=True)``, which
            yields the positions of the supported columns.

    Returns:
        ``features`` indexed by the labels of the supported columns.
    """
    kept_positions = selected.get_support(indices=True)
    # Translate positions into column labels before indexing the frame.
    kept_labels = features.columns[kept_positions]
    return features[kept_labels]
if __name__ == "__main__":
    pass
    # Example usage, left disabled: it depends on the LoadData import that
    # is also commented out at the top of this file.
    # data = LoadData()
    # print(elimination(data.proteomic, data.clinical, RandomForestClassifier()))