Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/api_reference/processing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ The processing module provides transformers for data processing and manipulation
:members:
:show-inheritance:

.. autoclass:: tide.processing.KeepColumns
:members:
:show-inheritance:

.. autoclass:: tide.processing.ReplaceTag
:members:
:show-inheritance:
Expand Down
27 changes: 26 additions & 1 deletion tests/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ProjectSolarRadOnSurfaces,
FillOtherColumns,
DropColumns,
KeepColumns,
ReplaceTag,
AddFourierPairs,
)
Expand Down Expand Up @@ -962,7 +963,7 @@ def test_drop_columns(self):
index=pd.date_range("2009", freq="h", periods=2, tz="UTC"),
)

col_dropper = DropColumns()
col_dropper = KeepColumns()
col_dropper.fit(df)
res = col_dropper.transform(df.copy())
pd.testing.assert_frame_equal(df, res)
Expand All @@ -980,6 +981,30 @@ def test_drop_columns(self):
assert res.shape == (2, 0)
check_feature_names_out(col_dropper, res)

def test_keep_columns(self):
df = pd.DataFrame(
{"a": [1, 2], "b": [1, 2], "c": [1, 2]},
index=pd.date_range("2009", freq="h", periods=2, tz="UTC"),
)

col_keeper = KeepColumns()
col_keeper.fit(df)
res = col_keeper.transform(df.copy())
pd.testing.assert_frame_equal(df, res)
check_feature_names_out(col_keeper, res)

col_keeper = KeepColumns(columns="a")
col_keeper.fit(df)
res = col_keeper.transform(df.copy())
pd.testing.assert_frame_equal(df[["a"]], res)
check_feature_names_out(col_keeper, res)

col_keeper = KeepColumns(columns=["a|b", "c"])
col_keeper.fit(df)
res = col_keeper.transform(df.copy())
pd.testing.assert_frame_equal(df, res)
check_feature_names_out(col_keeper, res)

def test_replace_tag(self):
df = pd.DataFrame(
{"energy_1__Wh": [1.0, 2.0], "energy_2__Whr__bloc": [3.0, 4.0]},
Expand Down
81 changes: 81 additions & 0 deletions tide/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2714,6 +2714,87 @@ def _transform_implementation(self, X: pd.Series | pd.DataFrame):
)


class KeepColumns(BaseProcessing):
"""
A transformer that keeps specified columns from a pandas DataFrame.

It is particularly useful at the final step of data preprocessing.
When only some columns are passed to a model

Parameters
----------
columns : str | list[str], optional (default=None)
The column name or a list of column names to be dropped.
If None, no columns are dropped and the DataFrame is returned unchanged.
Example: 'temp__°C' or ['temp__°C', 'humid__%'] or '°C|%'

Attributes
----------
feature_names_in_ : list[str]
Names of input columns (set during fit).
feature_names_out_ : list[str]
Names of output columns (input columns minus dropped columns).

Examples
--------
>>> import pandas as pd
>>> # Create DataFrame with DateTimeIndex
>>> dates = pd.date_range(
... start="2024-01-01 00:00:00", end="2024-01-01 00:02:00", freq="1min"
... ).tz_localize("UTC")
>>> df = pd.DataFrame(
... {
... "temp__°C": [20, 21, 22],
... "humid__%": [45, 50, 55],
... "press__Pa": [1000, 1010, 1020],
... },
... index=dates,
... )
>>> # Keep a single column
>>> keeper = KeepColumns(columns="temp__°C")
>>> result = keeper.fit_transform(df)
>>> print(result)
temp__°C
2024-01-01 00:00:00+00:00 20
2024-01-01 00:01:00+00:00 21
2024-01-01 00:02:00+00:00 22
>>> # Keep multiple columns
>>> keeper_multi = KeepColumns(columns="°C|%")
>>> result_multi = keeper_multi.fit_transform(df)
>>> print(result_multi)
temp__°C humid__%
2024-01-01 00:00:00+00:00 20 45
2024-01-01 00:01:00+00:00 21 50
2024-01-01 00:02:00+00:00 22 55

Notes
-----
- If a specified column doesn't exist in the DataFrame, it will be silently
ignored
- The order of selected columns is preserved
- If no columns are specified (columns=None), the DataFrame is returned
unchanged

Returns
-------
pd.DataFrame
The DataFrame with specified columns removed. The output maintains
the same DateTimeIndex as the input, with only the specified columns
removed.
"""

def __init__(self, columns: str | list[str] = None):
self.columns = columns
BaseProcessing.__init__(self)

def _fit_implementation(self, X: pd.Series | pd.DataFrame, y=None):
self.required_columns = tide_request(X, self.columns)
self.feature_names_out_ = self.required_columns

def _transform_implementation(self, X: pd.Series | pd.DataFrame):
return X[self.feature_names_out_]


class ReplaceTag(BaseProcessing):
"""A transformer that replaces components of Tide tag names with new values.

Expand Down