Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions tests/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
AddTimeLag,
ApplyExpression,
Resample,
ColumnsCombine,
CombineColumns,
ReplaceThreshold,
DropTimeGradient,
Dropna,
Expand All @@ -34,6 +34,7 @@
AddSolarAngles,
ProjectSolarRadOnSurfaces,
FillOtherColumns,
DropColumns,
)

RESOURCES_PATH = Path(__file__).parent / "resources"
Expand Down Expand Up @@ -458,7 +459,7 @@ def test_pd_combine_columns(self):
index=pd.date_range("2009", freq="h", periods=2),
)

trans = ColumnsCombine(
trans = CombineColumns(
function=np.sum,
columns=["a__°C", "b__°C"],
function_kwargs={"axis": 1},
Expand All @@ -481,7 +482,7 @@ def test_pd_combine_columns(self):

pd.testing.assert_frame_equal(trans.fit_transform(x_in), ref)

trans = ColumnsCombine(
trans = CombineColumns(
function=np.sum,
tide_format_columns="°C",
function_kwargs={"axis": 1},
Expand Down Expand Up @@ -820,3 +821,27 @@ def test_fill_other_columns(self):
np.isnan(res["col_1"])
== np.isnan([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, np.nan, np.nan])
)

def test_drop_columns(self):
df = pd.DataFrame(
{"a": [1, 2], "b": [1, 2], "c": [1, 2]},
index=pd.date_range("2009", freq="h", periods=2),
)

col_dropper = DropColumns()
col_dropper.fit(df)
res = col_dropper.transform(df.copy())

pd.testing.assert_frame_equal(df, res)

col_dropper = DropColumns(columns="a")
col_dropper.fit(df)
res = col_dropper.transform(df.copy())

pd.testing.assert_frame_equal(df[["b", "c"]], res)

col_dropper = DropColumns(columns=["a", "b", "c"])
col_dropper.fit(df)
res = col_dropper.transform(df.copy())

assert res.shape == (2, 0)
30 changes: 29 additions & 1 deletion tide/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -906,7 +906,7 @@ def _transform_implementation(self, X: pd.Series | pd.DataFrame):
return X.apply(gauss_filter)


class ColumnsCombine(BaseProcessing):
class CombineColumns(BaseProcessing):
"""
A class that combines multiple columns in a pandas DataFrame using a specified
function.
Expand Down Expand Up @@ -1576,3 +1576,31 @@ def _transform_implementation(self, X: pd.Series | pd.DataFrame):
if self.drop_filling_columns
else X
)


class DropColumns(BaseProcessing):
"""
Drop specified columns.

Parameters
----------
columns : str or list[str], optional
The column name or a list of column names to be dropped.
If None, no columns are dropped.

"""

def __init__(self, columns: str | list[str] = None):
self.columns = columns
BaseProcessing.__init__(self)

def _fit_implementation(self, X: pd.Series | pd.DataFrame, y=None):
self.required_columns = self.columns
self.removed_columns = self.columns

def _transform_implementation(self, X: pd.Series | pd.DataFrame):
return (
X.drop(self.removed_columns, axis="columns")
if self.columns is not None
else X
)
Loading