Skip to content

Commit f38af64

Browse files
authored
Merge pull request #35 from BuildingEnergySimulationTools/25-improve-select-function
25 improve select function
2 parents 5584ab2 + f62b3ec commit f38af64

File tree

8 files changed

+167
-54
lines changed

8 files changed

+167
-54
lines changed

docs/api_reference/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ API Reference
44
.. toctree::
55
:maxdepth: 2
66

7+
utils
78
plumbing
89
processing
910
regressor

docs/api_reference/utils.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Utils Modules
2+
===============
3+
4+
Tide's utility functions and classes.
5+
Mostly for handling tags, generating trees, or finding and selecting data gaps.
6+
7+
.. autofunction:: tide.utils.tide_request

tests/test_processing.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,7 @@ def test_replace_tag(self):
993993
def test_add_fourier_pairs(self):
994994
test_df = pd.DataFrame(
995995
data=np.arange(24).astype("float64"),
996-
index=pd.date_range("2009-01-01 00:00:00", freq="H", periods=24, tz="UTC"),
996+
index=pd.date_range("2009-01-01 00:00:00", freq="h", periods=24, tz="UTC"),
997997
columns=["feat_1"],
998998
)
999999

@@ -1036,14 +1036,14 @@ def test_add_fourier_pairs(self):
10361036
"1 days 00:00:00_order_2_Sine",
10371037
"1 days 00:00:00_order_2_Cosine",
10381038
],
1039-
index=pd.date_range("2009-01-01 00:00:00", freq="H", periods=24, tz="UTC"),
1039+
index=pd.date_range("2009-01-01 00:00:00", freq="h", periods=24, tz="UTC"),
10401040
)
10411041

10421042
pd.testing.assert_frame_equal(res, ref_df)
10431043

10441044
test_df_phi = pd.DataFrame(
10451045
data=np.arange(24),
1046-
index=pd.date_range("2009-01-01 06:00:00", freq="H", periods=24),
1046+
index=pd.date_range("2009-01-01 06:00:00", freq="h", periods=24),
10471047
columns=["feat_1"],
10481048
)
10491049
test_df_phi = test_df_phi.tz_localize("UTC")
@@ -1053,7 +1053,7 @@ def test_add_fourier_pairs(self):
10531053

10541054
test_df = pd.DataFrame(
10551055
data=np.arange(24).astype("float64"),
1056-
index=pd.date_range("2009-01-01 00:00:00", freq="H", periods=24, tz="UTC"),
1056+
index=pd.date_range("2009-01-01 00:00:00", freq="h", periods=24, tz="UTC"),
10571057
columns=["feat_1__°C__building__room"],
10581058
)
10591059

tests/test_utils.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
data_columns_to_tree,
1111
get_data_col_names_from_root,
1212
get_data_level_values,
13-
parse_request_to_col_names,
13+
tide_request,
1414
timedelta_to_int,
1515
NamedList,
1616
_get_series_bloc,
@@ -58,7 +58,7 @@ def test_columns_parser(self):
5858
assert all(col in DF_COLUMNS.columns for col in col_names)
5959

6060
def test_parse_request_to_col_names(self):
61-
res = parse_request_to_col_names(DF_COLUMNS)
61+
res = tide_request(DF_COLUMNS)
6262
assert res == [
6363
"name_1__°C__bloc1",
6464
"name_1__°C__bloc2",
@@ -69,10 +69,13 @@ def test_parse_request_to_col_names(self):
6969
"name4__DIMENSIONLESS__bloc4",
7070
]
7171

72-
res = parse_request_to_col_names(DF_COLUMNS, "name_1__°C__bloc1")
72+
res = tide_request(DF_COLUMNS, "name_1__°C__bloc1")
7373
assert res == ["name_1__°C__bloc1"]
7474

75-
res = parse_request_to_col_names(
75+
res = tide_request(DF_COLUMNS, ["name_1__°C__bloc1"])
76+
assert res == ["name_1__°C__bloc1"]
77+
78+
res = tide_request(
7679
DF_COLUMNS,
7780
[
7881
"name_1__°C__bloc1",
@@ -84,18 +87,28 @@ def test_parse_request_to_col_names(self):
8487
"name_1__°C__bloc2",
8588
]
8689

87-
res = parse_request_to_col_names(DF_COLUMNS, "°C")
90+
res = tide_request(DF_COLUMNS, "°C")
8891
assert res == ["name_1__°C__bloc1", "name_1__°C__bloc2"]
8992

90-
res = parse_request_to_col_names(DF_COLUMNS, "OTHER")
93+
res = tide_request(DF_COLUMNS, "OTHER")
9194
assert res == ["name_2", "name_3__kWh/m²", "name_5__kWh"]
9295

93-
res = parse_request_to_col_names(DF_COLUMNS, "DIMENSIONLESS__bloc2")
96+
res = tide_request(DF_COLUMNS, "DIMENSIONLESS__bloc2")
9497
assert res == ["name_2__DIMENSIONLESS__bloc2"]
9598

96-
res = parse_request_to_col_names(DF_COLUMNS, "kWh")
99+
res = tide_request(DF_COLUMNS, "kWh")
97100
assert res == ["name_5__kWh"]
98101

102+
res = tide_request(DF_COLUMNS, "kWh|°C")
103+
assert res == ["name_5__kWh", "name_1__°C__bloc1", "name_1__°C__bloc2"]
104+
105+
res = tide_request(DF_COLUMNS, ["kWh|°C", "name_5__kWh"])
106+
assert res == [
107+
"name_5__kWh",
108+
"name_1__°C__bloc1",
109+
"name_1__°C__bloc2",
110+
]
111+
99112
def test_get_data_level_names(self):
100113
root = data_columns_to_tree(DF_COLUMNS.columns)
101114
res = get_data_level_values(root, "name")

tide/plot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from tide.utils import (
88
check_and_return_dt_index_df,
9-
parse_request_to_col_names,
9+
tide_request,
1010
data_columns_to_tree,
1111
get_data_level_values,
1212
get_data_blocks,
@@ -63,7 +63,7 @@ def get_cols_axis_maps_and_labels(
6363
col_axes_map = {}
6464
axes_col_map = {}
6565
for i, tag in enumerate(y_tags):
66-
selected_cols = parse_request_to_col_names(columns, tag)
66+
selected_cols = tide_request(columns, tag)
6767
axes_col_map["y" if i == 0 else f"y{i + 1}"] = selected_cols
6868
for col in selected_cols:
6969
col_axes_map[col] = {"yaxis": "y"} if i == 0 else {"yaxis": f"y{i + 1}"}

tide/plumbing.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sklearn.compose import ColumnTransformer
99

1010
from tide.utils import (
11-
parse_request_to_col_names,
11+
tide_request,
1212
check_and_return_dt_index_df,
1313
data_columns_to_tree,
1414
get_data_level_values,
@@ -62,7 +62,7 @@ def _get_column_wise_transformer(
6262
) -> ColumnTransformer | None:
6363
col_trans_list = []
6464
for req, proc_list in proc_dict.items():
65-
requested_col = parse_request_to_col_names(data_columns, req)
65+
requested_col = tide_request(data_columns, req)
6666
if not requested_col:
6767
pass
6868
else:
@@ -358,7 +358,7 @@ def select(
358358
pd.Index
359359
Selected column names
360360
"""
361-
return parse_request_to_col_names(self.data, select)
361+
return tide_request(self.data, select)
362362

363363
def get_pipeline(
364364
self,
@@ -438,7 +438,7 @@ def get_pipeline(
438438
"""
439439
if self.data is None:
440440
raise ValueError("data is required to build a pipeline")
441-
selection = parse_request_to_col_names(self.data, select)
441+
selection = tide_request(self.data, select)
442442
if steps is None or self.pipe_dict is None:
443443
dict_to_pipe = None
444444
else:
@@ -541,7 +541,7 @@ def get_corrected_data(
541541
"""
542542
if self.data is None:
543543
raise ValueError("Cannot get corrected data. data are missing")
544-
select = parse_request_to_col_names(self.data, select)
544+
select = tide_request(self.data, select)
545545
data = self.data.loc[
546546
start or self.data.index[0] : stop or self.data.index[-1], select
547547
].copy()
@@ -834,9 +834,7 @@ def plot(
834834
# for example) So we just process the whole data hoping to find the result
835835
# after.
836836
select_corr = (
837-
self.data.columns
838-
if not parse_request_to_col_names(self.data, select)
839-
else select
837+
self.data.columns if not tide_request(self.data, select) else select
840838
)
841839

842840
data_1 = self.get_corrected_data(select_corr, start, stop, steps, verbose)

tide/processing.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
get_data_blocks,
1414
get_outer_timestamps,
1515
check_and_return_dt_index_df,
16-
parse_request_to_col_names,
16+
tide_request,
1717
ensure_list,
1818
)
1919
from tide.regressors import SkSTLForecast, SkProphet
@@ -1269,9 +1269,7 @@ def _fit_implementation(self, X: pd.Series | pd.DataFrame, y=None):
12691269
if self.tide_format_methods:
12701270
self.columns_methods = []
12711271
for req, method in self.tide_format_methods.items():
1272-
self.columns_methods.append(
1273-
(parse_request_to_col_names(X.columns, req), method)
1274-
)
1272+
self.columns_methods.append((tide_request(X.columns, req), method))
12751273

12761274
return self
12771275

tide/utils.py

Lines changed: 124 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -143,40 +143,136 @@ def get_data_col_names_from_root(data_root):
143143
][-1]
144144

145145

146-
def parse_request_to_col_names(
146+
def find_cols_with_tide_tags(
147+
data_columns: pd.Index | list[str], request: str
148+
) -> list[str]:
149+
request_parts = request.split("__")
150+
151+
if not (1 <= len(request_parts) <= 4):
152+
raise ValueError(
153+
f"Request '{request}' is malformed. "
154+
f"Use 'name__unit__bloc__sub_bloc' format or a "
155+
f"combination of these tags."
156+
)
157+
158+
full_tag_col_map = {
159+
col_name_tag_enrichment(col, get_tags_max_level(data_columns)): col
160+
for col in data_columns
161+
}
162+
163+
def find_exact_match(search_str, target):
164+
pattern = rf"(?:^|__)(?:{re.escape(search_str)})(?:$|__)"
165+
match = re.search(pattern, target)
166+
return match is not None
167+
168+
return [
169+
full_tag_col_map[augmented_col]
170+
for augmented_col in full_tag_col_map.keys()
171+
if all(find_exact_match(part, augmented_col) for part in request_parts)
172+
]
173+
174+
175+
def find_cols_multiple_tag_groups(
176+
data_columns: pd.Index | list[str], request: str
177+
) -> list[str]:
178+
request_parts = request.split("|")
179+
list_to_return = []
180+
for req in request_parts:
181+
list_to_return.extend(find_cols_with_tide_tags(data_columns, req))
182+
return list_to_return
183+
184+
185+
def tide_request(
147186
data_columns: pd.Index | list[str], request: str | pd.Index | list[str] = None
148187
) -> list[str]:
188+
"""
189+
Select columns by matching structured TIDE-style tags.
190+
191+
Filters column names based on a TIDE-style structured tag syntax. Columns are
192+
expected to use a naming convention with double underscores (`__`) separating
193+
tags.
194+
195+
A column name can include up to four hierarchical parts:
196+
'name__unit__bloc__sub_bloc' where each part is optional, but must be separated
197+
with double underscores.
198+
199+
The `request` argument allows searching for columns matching one or more
200+
of these parts using full or partial tag patterns. Multiple tag patterns
201+
can be combined using the pipe `|` character to form OR conditions.
202+
203+
Parameters
204+
----------
205+
data_columns : pandas.Index or list of str
206+
A collection of column names to filter. Each column name should follow
207+
the TIDE format (e.g., "sensor__°C__bloc1").
208+
209+
request : str or list of str or pandas.Index, optional
210+
Tag(s) to match against the column names. Each tag string may be:
211+
212+
- A full structured tag (e.g., "name__°C__bloc2")
213+
- A partial tag (e.g., "°C", "bloc1")
214+
- A group of tags separated by "|" (e.g., "kWh|°C")
215+
216+
If None, all columns from `data_columns` are returned.
217+
218+
Returns
219+
-------
220+
list of str
221+
The list of column names that match any of the provided tag queries.
222+
223+
Notes
224+
-----
225+
- Matching is done per tag part, not substrings. For instance, the query
226+
"bloc1" will match "name__°C__bloc1" but not "bloc11".
227+
- If multiple requests are given, columns are returned if they match
228+
at least one of them (logical OR).
229+
- Tags can include between 1 and 4 parts, split by `__`.
230+
231+
Examples
232+
--------
233+
>>> DF_COLUMNS = [
234+
... "name_1__°C__bloc1",
235+
... "name_1__°C__bloc2",
236+
... "name_2",
237+
... "name_2__DIMENSIONLESS__bloc2",
238+
... "name_3__kWh/m²",
239+
... "name_5__kWh",
240+
... "name4__DIMENSIONLESS__bloc4",
241+
... ]
242+
243+
>>> tide_request(DF_COLUMNS)
244+
['name_1__°C__bloc1', 'name_1__°C__bloc2', 'name_2',
245+
'name_2__DIMENSIONLESS__bloc2', 'name_3__kWh/m²',
246+
'name_5__kWh', 'name4__DIMENSIONLESS__bloc4']
247+
248+
>>> tide_request(DF_COLUMNS, "°C")
249+
['name_1__°C__bloc1', 'name_1__°C__bloc2']
250+
251+
>>> tide_request(DF_COLUMNS, "kWh|°C")
252+
['name_5__kWh', 'name_1__°C__bloc1', 'name_1__°C__bloc2']
253+
254+
>>> # Columns are not selected twice
255+
>>> tide_request(DF_COLUMNS, ["kWh|°C", "name_5__kWh"])
256+
['name_5__kWh', 'name_1__°C__bloc1', 'name_1__°C__bloc2']
257+
"""
258+
149259
if request is None:
150260
return list(data_columns)
151261

152-
elif isinstance(request, pd.Index) or isinstance(request, list):
153-
return [col for col in request if col in data_columns]
262+
elif isinstance(request, str):
263+
request = [request]
154264

155-
else:
156-
request_parts = request.split("__")
157-
158-
if not (1 <= len(request_parts) <= 4):
159-
raise ValueError(
160-
f"Request '{request}' is malformed. "
161-
f"Use 'name__unit__bloc__sub_bloc' format or a "
162-
f"combination of these tags."
163-
)
164-
165-
full_tag_col_map = {
166-
col_name_tag_enrichment(col, get_tags_max_level(data_columns)): col
167-
for col in data_columns
168-
}
169-
170-
def find_exact_match(search_str, target):
171-
pattern = rf"(?:^|__)(?:{re.escape(search_str)})(?:$|__)"
172-
match = re.search(pattern, target)
173-
return match is not None
174-
175-
return [
176-
full_tag_col_map[augmented_col]
177-
for augmented_col in full_tag_col_map.keys()
178-
if all(find_exact_match(part, augmented_col) for part in request_parts)
179-
]
265+
if not (isinstance(request, pd.Index) or isinstance(request, list)):
266+
raise ValueError(
267+
"Invalid request. Was expected an instance of str, pd.Index or List[str]"
268+
f"got {type(request)} instead"
269+
)
270+
271+
list_to_return = []
272+
for req in request:
273+
list_to_return.extend(find_cols_multiple_tag_groups(data_columns, req))
274+
275+
return list(dict.fromkeys(list_to_return))
180276

181277

182278
def data_columns_to_tree(columns: pd.Index | list[str]) -> T:

0 commit comments

Comments
 (0)