diff --git a/docs/api_reference/processing.rst b/docs/api_reference/processing.rst index bfb940f..7e85222 100644 --- a/docs/api_reference/processing.rst +++ b/docs/api_reference/processing.rst @@ -87,6 +87,10 @@ The processing module provides transformers for data processing and manipulation :members: :show-inheritance: +.. autoclass:: tide.processing.KeepColumns + :members: + :show-inheritance: + .. autoclass:: tide.processing.ReplaceTag :members: :show-inheritance: diff --git a/tests/test_processing.py b/tests/test_processing.py index 39dbbb2..337fe28 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -35,6 +35,7 @@ ProjectSolarRadOnSurfaces, FillOtherColumns, DropColumns, + KeepColumns, ReplaceTag, AddFourierPairs, ) @@ -962,7 +963,7 @@ def test_drop_columns(self): index=pd.date_range("2009", freq="h", periods=2, tz="UTC"), ) - col_dropper = DropColumns() + col_dropper = KeepColumns() col_dropper.fit(df) res = col_dropper.transform(df.copy()) pd.testing.assert_frame_equal(df, res) @@ -980,6 +981,30 @@ def test_drop_columns(self): assert res.shape == (2, 0) check_feature_names_out(col_dropper, res) + def test_keep_columns(self): + df = pd.DataFrame( + {"a": [1, 2], "b": [1, 2], "c": [1, 2]}, + index=pd.date_range("2009", freq="h", periods=2, tz="UTC"), + ) + + col_keeper = KeepColumns() + col_keeper.fit(df) + res = col_keeper.transform(df.copy()) + pd.testing.assert_frame_equal(df, res) + check_feature_names_out(col_keeper, res) + + col_keeper = KeepColumns(columns="a") + col_keeper.fit(df) + res = col_keeper.transform(df.copy()) + pd.testing.assert_frame_equal(df[["a"]], res) + check_feature_names_out(col_keeper, res) + + col_keeper = KeepColumns(columns=["a|b", "c"]) + col_keeper.fit(df) + res = col_keeper.transform(df.copy()) + pd.testing.assert_frame_equal(df, res) + check_feature_names_out(col_keeper, res) + def test_replace_tag(self): df = pd.DataFrame( {"energy_1__Wh": [1.0, 2.0], "energy_2__Whr__bloc": [3.0, 4.0]}, diff --git a/tide/processing.py b/tide/processing.py index a4660ab..a4d6124 100644 --- a/tide/processing.py +++ b/tide/processing.py @@ -2714,6 +2714,87 @@ def _transform_implementation(self, X: pd.Series | pd.DataFrame): ) +class KeepColumns(BaseProcessing): + """ + A transformer that keeps specified columns from a pandas DataFrame. + + It is particularly useful at the final step of data preprocessing. + When only some columns are passed to a model + + Parameters + ---------- + columns : str | list[str], optional (default=None) + The column name or a list of column names to be dropped. + If None, no columns are dropped and the DataFrame is returned unchanged. + Example: 'temp__°C' or ['temp__°C', 'humid__%'] or '°C|%' + + Attributes + ---------- + feature_names_in_ : list[str] + Names of input columns (set during fit). + feature_names_out_ : list[str] + Names of output columns (input columns minus dropped columns). + + Examples + -------- + >>> import pandas as pd + >>> # Create DataFrame with DateTimeIndex + >>> dates = pd.date_range( + ... start="2024-01-01 00:00:00", end="2024-01-01 00:02:00", freq="1min" + ... ).tz_localize("UTC") + >>> df = pd.DataFrame( + ... { + ... "temp__°C": [20, 21, 22], + ... "humid__%": [45, 50, 55], + ... "press__Pa": [1000, 1010, 1020], + ... }, + ... index=dates, + ... ) + >>> # Keep a single column + >>> keeper = KeepColumns(columns="temp__°C") + >>> result = keeper.fit_transform(df) + >>> print(result) + temp__°C + 2024-01-01 00:00:00+00:00 20 + 2024-01-01 00:01:00+00:00 21 + 2024-01-01 00:02:00+00:00 22 + >>> # Keep multiple columns + >>> keeper_multi = KeepColumns(columns="°C|%") + >>> result_multi = keeper_multi.fit_transform(df) + >>> print(result_multi) + temp__°C humid__% + 2024-01-01 00:00:00+00:00 20 45 + 2024-01-01 00:01:00+00:00 21 50 + 2024-01-01 00:02:00+00:00 22 55 + + Notes + ----- + - If a specified column doesn't exist in the DataFrame, it will be silently + ignored + - The order of selected columns is preserved + - If no columns are specified (columns=None), the DataFrame is returned + unchanged + + Returns + ------- + pd.DataFrame + The DataFrame with specified columns removed. The output maintains + the same DateTimeIndex as the input, with only the specified columns + removed. + """ + + def __init__(self, columns: str | list[str] = None): + self.columns = columns + BaseProcessing.__init__(self) + + def _fit_implementation(self, X: pd.Series | pd.DataFrame, y=None): + self.required_columns = tide_request(X, self.columns) + self.feature_names_out_ = self.required_columns + + def _transform_implementation(self, X: pd.Series | pd.DataFrame): + return X[self.feature_names_out_] + + class ReplaceTag(BaseProcessing): """A transformer that replaces components of Tide tag names with new values.