2525from tide .meteo import get_oikolab_df
2626
2727
28- def _ensure_list ( item ) :
28+ class TideBaseMixin :
2929 """
30- Ensures the input is returned as a list.
31-
32- Parameters
33- ----------
34- item : any
35- The input item to be converted to a list if it is not already one.
36- If the input is `None`, an empty list is returned.
37-
38- Returns
39- -------
40- list
41- - If `item` is `None`, returns an empty list.
42- - If `item` is already a list, it is returned as is.
43- - Otherwise, wraps the `item` in a list and returns it.
44- """
45- if item is None :
46- return []
47- return item if isinstance (item , list ) else [item ]
48-
49-
50- class BaseProcessing (ABC , TransformerMixin , BaseEstimator ):
51- """
52- Abstract base class for processing pipelines with feature checks and
53- transformation logic.
54-
55- This class is designed to facilitate transformations by checking input data
56- (DataFrame or Series with DatetimeIndex), ensuring the presence
57- of required features, tracking added and removed features, and enabling
58- seamless integration with scikit-learn's API through fit and transform
59- methods.
30+ This class is designed to provide Tide base functionalities including :
31+ - checking features in and out
32+ - checking mandatory features
33+ - Modifying features names according to tide's tags
6034
6135 Parameters
6236 ----------
@@ -79,15 +53,6 @@ class BaseProcessing(ABC, TransformerMixin, BaseEstimator):
7953 columns.
8054 get_feature_names_in():
8155 Returns the names of the features as initially fitted.
82- fit(X, y=None):
83- Fits the transformer to the input data.
84- transform(X):
85- Applies the transformation to the input data.
86- _fit_implementation(X, y=None):
87- Abstract method for the fitting logic. Must be implemented by subclasses.
88- _transform_implementation(X):
89- Abstract method for the transformation logic. Must be implemented by
90- subclasses.
9156 """
9257
9358 def __init__ (
@@ -100,6 +65,15 @@ def __init__(
10065 self .removed_columns = removed_columns
10166 self .added_columns = added_columns
10267
68+ def check_required_features (self , X ):
69+ if self .required_columns is not None :
70+ if not set (self .required_columns ).issubset (X .columns ):
71+ raise ValueError ("One or several required columns are missing" )
72+
73+ def fit_check_features (self , X ):
74+ self .check_required_features (X )
75+ self .feature_names_in_ = list (X .columns )
76+
10377 def get_set_tags_values_columns (self , X , tag_level : int , value : str ):
10478 nb_tags = get_tag_levels (X .columns )
10579 if tag_level > nb_tags - 1 :
@@ -119,15 +93,6 @@ def get_set_tags_values_columns(self, X, tag_level: int, value: str):
11993 def set_tags_values (self , X , tag_level : int , value : str ):
12094 X .columns = self .get_set_tags_values_columns (X , tag_level , value )
12195
122- def check_features (self , X ):
123- if self .required_columns is not None :
124- if not set (self .required_columns ).issubset (X .columns ):
125- raise ValueError ("One or several required columns are missing" )
126-
127- def fit_check_features (self , X ):
128- self .check_features (X )
129- self .feature_names_in_ = list (X .columns )
130-
13196 def get_feature_names_out (self , input_features = None ):
13297 if input_features is None :
13398 check_is_fitted (self , attributes = ["feature_names_in_" ])
@@ -146,14 +111,73 @@ def get_feature_names_in(self):
146111 check_is_fitted (self , attributes = ["feature_names_in_" ])
147112 return self .feature_names_in_
148113
114+
115+ class BaseProcessing (ABC , TransformerMixin , BaseEstimator , TideBaseMixin ):
116+ """
117+ Abstract base class for processing pipelines with feature checks and
118+ transformation logic.
119+
120+ This class is designed to facilitate transformations by checking input data
121+ (DataFrame or Series with DatetimeIndex), ensuring the presence
122+ of required features, tracking added and removed features, and enabling
123+ seamless integration with scikit-learn's API through fit and transform
124+ methods.
125+
126+ Parameters
127+ ----------
128+ required_columns : str or list[str], optional
129+ Column names that must be present in the input data. Defaults to None.
130+ removed_columns : str or list[str], optional
131+ Column that will be removed during the transform process. Defaults to None.
132+ added_columns : str or list[str], optional
133+ Column that will be added to the output feature set during transform
134+ process. Defaults to None.
135+
136+ Methods
137+ -------
138+ check_features(X):
139+ Ensures that the required columns are present in the input DataFrame.
140+ fit_check_features(X):
141+ Checks required columns and stores the initial feature names.
142+ get_feature_names_out():
143+ Computes the final set of feature names, accounting for added and removed
144+ columns.
145+ get_feature_names_in():
146+ Returns the names of the features as initially fitted.
147+ fit(X, y=None):
148+ Fits the transformer to the input data.
149+ transform(X):
150+ Applies the transformation to the input data.
151+ _fit_implementation(X, y=None):
152+ Abstract method for the fitting logic. Must be implemented by subclasses.
153+ _transform_implementation(X):
154+ Abstract method for the transformation logic. Must be implemented by
155+ subclasses.
156+ """
157+
158+ def __init__ (
159+ self ,
160+ required_columns : str | list [str ] = None ,
161+ removed_columns : str | list [str ] = None ,
162+ added_columns : str | list [str ] = None ,
163+ ):
164+ TideBaseMixin .__init__ (
165+ self ,
166+ required_columns = required_columns ,
167+ removed_columns = removed_columns ,
168+ added_columns = added_columns ,
169+ )
170+ TransformerMixin .__init__ (self )
171+ BaseEstimator .__init__ (self )
172+
149173 def fit (self , X : pd .Series | pd .DataFrame , y = None ):
150174 X = check_and_return_dt_index_df (X )
151175 self .fit_check_features (X )
152176 self ._fit_implementation (X , y )
153177 return self
154178
155179 def transform (self , X : pd .Series | pd .DataFrame ):
156- self .check_features (X )
180+ self .check_required_features (X )
157181 X = check_and_return_dt_index_df (X )
158182 return self ._transform_implementation (X )
159183
0 commit comments