-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtransformers.py
More file actions
52 lines (38 loc) · 1.71 KB
/
transformers.py
File metadata and controls
52 lines (38 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# Custom sklearn transformers
from sklearn.base import BaseEstimator, TransformerMixin
class MismatchedColumnsHandler(BaseEstimator, TransformerMixin):
    """
    Align an input DataFrame with the columns the model expects.

    Columns listed in ``default_values`` but absent from the input are added
    and filled with their default value; columns not listed are dropped. The
    output columns always follow the key order of ``default_values``.

    Note: Almost a SimpleImputer but with the ability to remove extra columns.
    """

    def __init__(self, default_values):
        """
        Args:
            default_values: dictionary mapping each column the model expects
                to the value used when that column is missing from the input.
                Its key order defines the output column order.
        """
        self.default_values = default_values

    def fit(self, X=None, y=None):
        """Return self unchanged; nothing is learned from the data."""
        return self

    def transform(self, X):
        """
        Return ``X`` restricted to the expected columns, in the expected order.

        Args:
            X: DataFrame to align. It is never modified in place.

        Returns:
            ``X`` itself when its columns already match exactly (same names
            in the same order); otherwise a new DataFrame.
        """
        expected_cols = list(self.default_values)

        # Fast path only on an exact, ordered match: comparing as sets would
        # let a correctly-named but mis-ordered frame through unchanged.
        if list(X.columns) == expected_cols:
            return X

        # Missing columns: expected by the model but not in the input.
        present = set(X.columns)
        missing_cols = [col for col in expected_cols if col not in present]
        if missing_cols:
            X = X.copy()  # Avoid changing the original dataframe
            for col in missing_cols:
                X[col] = self.default_values[col]

        # Keep just the expected columns, in the expected order.
        return X[expected_cols]
if __name__ == '__main__':
    # Minimal self-test, run only when the file is executed as a script.
    import pandas as pd

    # Test MismatchedColumnsHandler: column '1' is extra and must be dropped,
    # column '4' is missing and must be added with its default value.
    t = MismatchedColumnsHandler(default_values={'2': -1, '3': -1, '4': 4})
    X = pd.DataFrame({'1': [1], '2': [2], '3': [3]})
    expected = pd.DataFrame({'2': [2], '3': [3], '4': [4]})

    # Transform once and reuse the result for both the printout and the check.
    result = t.transform(X)
    print(result)
    assert result.equals(expected), 'MismatchedColumnsHandler failed'