From dd606a1e726c6a49a68dbe83f6b837be411a49b4 Mon Sep 17 00:00:00 2001 From: samland1116 Date: Fri, 5 Dec 2025 13:43:28 -0600 Subject: [PATCH 1/7] adds rcf ForecastLeadTimeBins, corrects misc. formatting --- .../models/calculated_fields/row_level.py | 273 +++++++++++++++++- .../calculated_fields/timeseries_aware.py | 1 + tests/evaluations/test_add_udfs.py | 9 +- 3 files changed, 276 insertions(+), 7 deletions(-) diff --git a/src/teehr/models/calculated_fields/row_level.py b/src/teehr/models/calculated_fields/row_level.py index 5252a155..a0fc5bcd 100644 --- a/src/teehr/models/calculated_fields/row_level.py +++ b/src/teehr/models/calculated_fields/row_level.py @@ -1,8 +1,7 @@ """Classes representing UDFs.""" import calendar -from typing import List, Union -from pydantic import BaseModel as PydanticBaseModel -from pydantic import Field, ConfigDict +from typing import Union +from pydantic import Field import pandas as pd import pyspark.sql.types as T from pyspark.sql.functions import pandas_udf @@ -202,7 +201,8 @@ def apply_to(self, sdf: ps.DataFrame) -> ps.DataFrame: def func(value_time: pd.Series) -> pd.Series: return value_time.dt.month.apply( lambda x: next( - (season for season, months in self.season_months.items() if x in months), + (season for season, + months in self.season_months.items() if x in months), None ) ) @@ -215,7 +215,7 @@ def func(value_time: pd.Series) -> pd.Series: class ForecastLeadTime(CalculatedFieldABC, CalculatedFieldBaseModel): - """Adds the forecast lead time in seconds from a timestamp column. + """Adds the forecast lead time from a timestamp column. Properties ---------- @@ -257,6 +257,266 @@ def func(value_time: pd.Series, return sdf +class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel): + """Adds ID for grouped forecast lead time bins. + + Properties + ---------- + - value_time_field_name: + The name of the column containing the timestamp. + Default: "value_time" + - reference_time_field_name: + The name of the column containing the forecast time. + Default: "reference_time" + - lead_time_field_name: + The name of the column containing the forecast lead time. + Default: "forecast_lead_time" + - output_field_name: + The name of the column to store the lead time bin ID. + Default: "forecast_lead_time_bin" + - bin_size: + Either a single pd.Timedelta for uniform bin sizes, or a dict mapping + threshold pd.Timedelta values to pd.Timedelta bin sizes. Keys + represent the upper bound of each threshold range. Keys can also be + string or pd.Timestamp values, which will be converted to pd.Timedelta + relative to the minimum value_time in the DataFrame. 
+ Default: 5 days + + Example dict 1: { + pd.Timedelta(days=1): pd.Timedelta(hours=6), + pd.Timedelta(days=2): pd.Timedelta(hours=12), + pd.Timedelta(days=3): pd.Timedelta(days=1) + } + Example dict 2: { + "2024-11-20 12:00:00": pd.Timedelta(hours=6), + "2024-11-21 12:00:00": pd.Timedelta(hours=12), + "2024-11-22 12:00:00": pd.Timedelta(days=1) + } + Example dict 3: { + pd.Timestamp("2024-11-20 12:00:00"): pd.Timedelta(hours=6), + pd.Timestamp("2024-11-21 12:00:00"): pd.Timedelta(hours=12), + pd.Timestamp("2024-11-22 12:00:00"): pd.Timedelta(days=1) + } + """ + + value_time_field_name: str = Field( + default="value_time" + ) + reference_time_field_name: str = Field( + default="reference_time" + ) + lead_time_field_name: str = Field( + default="forecast_lead_time" + ) + output_field_name: str = Field( + default="forecast_lead_time_bin" + ) + bin_size: Union[pd.Timedelta, dict] = Field( + default=pd.Timedelta(days=5) + ) + + @staticmethod + def validate_bin_size_dict( + self, + sdf: ps.DataFrame + ) -> Union[pd.Timedelta, dict]: + """Validate and normalize bin_size dict. + + Validates that bin_size dict keys are of correct, uniform type and + converts them to pd.Timedelta if needed. Validates that bin_size dict + values are of type pd.Timedelta. Validates that dict is not empty. + + If bin_size is already a pd.Timedelta or dict with pd.Timedelta keys, + returns it unchanged. Otherwise, converts string/datetime keys to + pd.Timedelta by calculating the difference from the minimum value_time. + """ + # If not a dict or already has Timedelta keys, return as-is + if not isinstance(self.bin_size, dict) and isinstance( + self.bin_size, pd.Timedelta + ): + return self.bin_size + + # raise error if dict is empty + if not self.bin_size: + raise ValueError("bin_size dict cannot be empty") + + # Ensure all keys are of the same type + first_key = next(iter(self.bin_size.keys())) + for key in self.bin_size.keys(): + if not isinstance(key, type(first_key)): + raise TypeError( + "All bin_size dict keys must be of the same type" + ) + + # Ensure all values are pd.Timedelta + for value in self.bin_size.values(): + if not isinstance(value, pd.Timedelta): + raise TypeError( + "All bin_size dict values must be of type pd.Timedelta" + ) + + # If already Timedelta keys, return as-is + if isinstance(first_key, pd.Timedelta): + return self.bin_size + + # Convert string or datetime keys to Timedelta + if isinstance(first_key, (str, pd.Timestamp)): + # Get minimum value_time from the dataframe + min_value_time_row = sdf.select( + self.value_time_field_name + ).orderBy(self.value_time_field_name).first() + + # Extract value_time and convert to datetime + min_value_time = pd.to_datetime( + min_value_time_row[self.value_time_field_name] + ) + + # Convert dict keys from string/datetime to Timedelta + converted_dict = {} + for key, value in self.bin_size.items(): + # Convert key to datetime if it's a string + if isinstance(key, str): + try: + key_datetime = pd.to_datetime(key) + except Exception as e: + raise ValueError( + f"Error converting key '{key}' to datetime: {e}" + ) + else: + key_datetime = key + + # Calculate Timedelta from min_value_time + timedelta_key = key_datetime - min_value_time + converted_dict[timedelta_key] = value + + return converted_dict + + else: + # raise error for unsupported key types + raise TypeError( + "bin_size dict keys must be pd.Timedelta, str, or pd.Timestamp" + ) + + @staticmethod + def add_forecast_lead_time(self, sdf: ps.DataFrame) -> ps.DataFrame: + """Calculate forecast lead time if 
not already present.""" + if self.lead_time_field_name not in sdf.columns: + flt_cf = ForecastLeadTime( + value_time_field_name=self.value_time_field_name, + reference_time_field_name=self.reference_time_field_name, + output_field_name=self.lead_time_field_name + ) + sdf = flt_cf.apply_to(sdf) + return sdf + + @staticmethod + def add_forecast_lead_time_bin( + self, + sdf: ps.DataFrame + ) -> ps.DataFrame: + """Add forecast lead time bin column.""" + + @pandas_udf(returnType=T.StringType()) + def func(lead_time: pd.Series, value_time: pd.Series) -> pd.Series: + if isinstance(self.bin_size, dict): + # Sort thresholds for consistent processing + sorted_thresholds = sorted(self.bin_size.items(), + key=lambda x: x[0].total_seconds()) + + bin_ids = pd.Series("", index=lead_time.index) + prev_thresh_sec = 0 + + for i, (threshold, bin_size) in enumerate(sorted_thresholds): + threshold_seconds = threshold.total_seconds() + bin_size_seconds = bin_size.total_seconds() + + is_last = (i == len(sorted_thresholds) - 1) + + # Mask for values in this threshold range + if is_last: + mask = lead_time.dt.total_seconds() >= prev_thresh_sec + else: + mask = ( + lead_time.dt.total_seconds() >= prev_thresh_sec + ) & ( + lead_time.dt.total_seconds() < threshold_seconds + ) + + if mask.any(): + # Calculate which bin each value belongs to + bins_in_range = ( + (lead_time[mask].dt.total_seconds() - + prev_thresh_sec) + // bin_size_seconds + ).astype(int) + + # create bin ID from actual timestamps + for bin_num in bins_in_range.unique(): + bin_mask = mask & (bins_in_range == bin_num) + + if bin_mask.any(): + # Get the unique value_time values for this bin + bin_value_times = value_time[bin_mask].unique() + + # Convert to datetime and find min/max + bin_value_times_dt = pd.to_datetime( + bin_value_times + ) + start_timestamp = bin_value_times_dt.min() + end_timestamp = start_timestamp + bin_size + + # Format as ISO 8601 timestamp range + bin_id = f"{start_timestamp} - {end_timestamp}" + bin_ids[bin_mask] = bin_id + + prev_thresh_sec = threshold_seconds + + return bin_ids + + else: + # Uniform bin size logic + bin_size_seconds = self.bin_size.total_seconds() + + # Calculate which bin each value belongs to + bin_numbers = ( + lead_time.dt.total_seconds() // bin_size_seconds + ).astype(int) + + bin_ids = pd.Series("", index=lead_time.index) + + # For each unique bin, create bin ID from actual timestamps + for bin_num in bin_numbers.unique(): + bin_mask = bin_numbers == bin_num + + if bin_mask.any(): + # Get the unique value_time values for this bin + bin_value_times = value_time[bin_mask].unique() + + # Convert to datetime if needed and find min/max + bin_value_times_dt = pd.to_datetime(bin_value_times) + start_timestamp = bin_value_times_dt.min() + end_timestamp = start_timestamp + self.bin_size + + # Format as ISO 8601 timestamp range + bin_id = f"{start_timestamp} - {end_timestamp}" + bin_ids[bin_mask] = bin_id + + return bin_ids + + sdf = sdf.withColumn( + self.output_field_name, + func(self.lead_time_field_name, self.value_time_field_name) + ) + return sdf + + def apply_to(self, sdf: ps.DataFrame) -> ps.DataFrame: + """Apply the calculated field to the Spark DataFrame.""" + self.bin_size = self.validate_bin_size_dict(self, sdf) + sdf = self.add_forecast_lead_time(self, sdf) + sdf = self.add_forecast_lead_time_bin(self, sdf) + return sdf + + class ThresholdValueExceeded(CalculatedFieldABC, CalculatedFieldBaseModel): """Adds boolean column indicating if the input value exceeds a threshold. 
@@ -408,6 +668,7 @@ class HourOfYear(CalculatedFieldABC, CalculatedFieldBaseModel): The name of the column to store the month. Default: "hour_of_year" """ + input_field_name: str = Field( default="value_time" ) @@ -461,6 +722,7 @@ class RowLevelCalculatedFields: - NormalizedFlow - Seasons - ForecastLeadTime + - ForecastLeadTimeBins - ThresholdValueExceeded - DayOfYear - HourOfYear @@ -472,6 +734,7 @@ class RowLevelCalculatedFields: NormalizedFlow = NormalizedFlow Seasons = Seasons ForecastLeadTime = ForecastLeadTime + ForecastLeadTimeBins = ForecastLeadTimeBins ThresholdValueExceeded = ThresholdValueExceeded ThresholdValueNotExceeded = ThresholdValueNotExceeded DayOfYear = DayOfYear diff --git a/src/teehr/models/calculated_fields/timeseries_aware.py b/src/teehr/models/calculated_fields/timeseries_aware.py index 97c1cf6c..3623742a 100644 --- a/src/teehr/models/calculated_fields/timeseries_aware.py +++ b/src/teehr/models/calculated_fields/timeseries_aware.py @@ -66,6 +66,7 @@ class AbovePercentileEventDetection(CalculatedFieldABC, CalculatedFieldBaseModel 'unit_name' ] """ + quantile: float = Field( default=0.85 ) diff --git a/tests/evaluations/test_add_udfs.py b/tests/evaluations/test_add_udfs.py index 54008794..8a467ebe 100644 --- a/tests/evaluations/test_add_udfs.py +++ b/tests/evaluations/test_add_udfs.py @@ -55,6 +55,9 @@ def test_add_row_udfs(tmpdir): sdf = rcf.ForecastLeadTime().apply_to(sdf) _ = sdf.toPandas() + sdf = rcf.ForecastLeadTimeBins().apply_to(sdf) + _ = sdf.toPandas() + sdf = rcf.ThresholdValueExceeded( threshold_field_name="year_2_discharge" ).apply_to(sdf) @@ -276,7 +279,7 @@ def test_add_timeseries_udfs(tmpdir): skip_event_id=True ) sdf = ped.apply_to(sdf) - num_event_timesteps = sdf.filter(sdf.event_above == True).count() + num_event_timesteps = sdf.filter(sdf.event_above).count() assert num_event_timesteps == 14823 # test percentile event detection (return quantile value) @@ -340,7 +343,9 @@ def test_location_event_detection(tmpdir): ped = tcf.AbovePercentileEventDetection() sdf = ev.metrics.add_calculated_fields(ped).query( - group_by=["configuration_name", "primary_location_id", "event_above_id"], + group_by=["configuration_name", + "primary_location_id", + "event_above_id"], include_metrics=[ teehr.Signatures.Maximum( input_field_names=["primary_value"], From 170ef7f9ab4c72b64b42ad5efe264191dbc92df9 Mon Sep 17 00:00:00 2001 From: samland1116 Date: Fri, 5 Dec 2025 15:02:17 -0600 Subject: [PATCH 2/7] update docs and API formatting w.r.t ForecastLeadTimeBins --- .../notebooks/08_adding_calculated_fields.ipynb | 1 + src/teehr/models/calculated_fields/row_level.py | 17 ++++++++--------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/sphinx/user_guide/notebooks/08_adding_calculated_fields.ipynb b/docs/sphinx/user_guide/notebooks/08_adding_calculated_fields.ipynb index 4080b10f..bf063590 100644 --- a/docs/sphinx/user_guide/notebooks/08_adding_calculated_fields.ipynb +++ b/docs/sphinx/user_guide/notebooks/08_adding_calculated_fields.ipynb @@ -316,6 +316,7 @@ "- NormalizedFlow\n", "- Seasons\n", "- ForecastLeadTime\n", + "- ForecastLeadTimeBins\n", "- ThresholdValueExceeded\n", "- DayOfYear\n", "\n", diff --git a/src/teehr/models/calculated_fields/row_level.py b/src/teehr/models/calculated_fields/row_level.py index a0fc5bcd..b16c348e 100644 --- a/src/teehr/models/calculated_fields/row_level.py +++ b/src/teehr/models/calculated_fields/row_level.py @@ -287,11 +287,13 @@ class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel): 
pd.Timedelta(days=2): pd.Timedelta(hours=12), pd.Timedelta(days=3): pd.Timedelta(days=1) } + Example dict 2: { "2024-11-20 12:00:00": pd.Timedelta(hours=6), "2024-11-21 12:00:00": pd.Timedelta(hours=12), "2024-11-22 12:00:00": pd.Timedelta(days=1) } + Example dict 3: { pd.Timestamp("2024-11-20 12:00:00"): pd.Timedelta(hours=6), pd.Timestamp("2024-11-21 12:00:00"): pd.Timedelta(hours=12), @@ -315,8 +317,7 @@ class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel): default=pd.Timedelta(days=5) ) - @staticmethod - def validate_bin_size_dict( + def _validate_bin_size_dict( self, sdf: ps.DataFrame ) -> Union[pd.Timedelta, dict]: @@ -397,8 +398,7 @@ def validate_bin_size_dict( "bin_size dict keys must be pd.Timedelta, str, or pd.Timestamp" ) - @staticmethod - def add_forecast_lead_time(self, sdf: ps.DataFrame) -> ps.DataFrame: + def _add_forecast_lead_time(self, sdf: ps.DataFrame) -> ps.DataFrame: """Calculate forecast lead time if not already present.""" if self.lead_time_field_name not in sdf.columns: flt_cf = ForecastLeadTime( @@ -409,8 +409,7 @@ def add_forecast_lead_time(self, sdf: ps.DataFrame) -> ps.DataFrame: sdf = flt_cf.apply_to(sdf) return sdf - @staticmethod - def add_forecast_lead_time_bin( + def _add_forecast_lead_time_bin( self, sdf: ps.DataFrame ) -> ps.DataFrame: @@ -511,9 +510,9 @@ def func(lead_time: pd.Series, value_time: pd.Series) -> pd.Series: def apply_to(self, sdf: ps.DataFrame) -> ps.DataFrame: """Apply the calculated field to the Spark DataFrame.""" - self.bin_size = self.validate_bin_size_dict(self, sdf) - sdf = self.add_forecast_lead_time(self, sdf) - sdf = self.add_forecast_lead_time_bin(self, sdf) + self.bin_size = self._validate_bin_size_dict(self, sdf) + sdf = self._add_forecast_lead_time(self, sdf) + sdf = self._add_forecast_lead_time_bin(self, sdf) return sdf From 7a5fab756dda96f44f0b68016698c6348804c5c7 Mon Sep 17 00:00:00 2001 From: samland1116 Date: Mon, 8 Dec 2025 13:55:20 -0600 Subject: [PATCH 3/7] adds staticmethod decorators to class methods --- src/teehr/models/calculated_fields/row_level.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/teehr/models/calculated_fields/row_level.py b/src/teehr/models/calculated_fields/row_level.py index b16c348e..6f0ea5be 100644 --- a/src/teehr/models/calculated_fields/row_level.py +++ b/src/teehr/models/calculated_fields/row_level.py @@ -317,10 +317,9 @@ class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel): default=pd.Timedelta(days=5) ) + @ staticmethod def _validate_bin_size_dict( - self, - sdf: ps.DataFrame - ) -> Union[pd.Timedelta, dict]: + self, sdf: ps.DataFrame) -> Union[pd.Timedelta, dict]: """Validate and normalize bin_size dict. 
Validates that bin_size dict keys are of correct, uniform type and @@ -398,6 +397,7 @@ def _validate_bin_size_dict( "bin_size dict keys must be pd.Timedelta, str, or pd.Timestamp" ) + @staticmethod def _add_forecast_lead_time(self, sdf: ps.DataFrame) -> ps.DataFrame: """Calculate forecast lead time if not already present.""" if self.lead_time_field_name not in sdf.columns: @@ -409,6 +409,7 @@ def _add_forecast_lead_time(self, sdf: ps.DataFrame) -> ps.DataFrame: sdf = flt_cf.apply_to(sdf) return sdf + @staticmethod def _add_forecast_lead_time_bin( self, sdf: ps.DataFrame From e23cccdeca3d265516182f957d947cc44e5675cf Mon Sep 17 00:00:00 2001 From: samland1116 Date: Tue, 9 Dec 2025 14:51:46 -0600 Subject: [PATCH 4/7] refactors ForecastLeadTimeBins --- .../models/calculated_fields/row_level.py | 451 +++++++++++------- 1 file changed, 287 insertions(+), 164 deletions(-) diff --git a/src/teehr/models/calculated_fields/row_level.py b/src/teehr/models/calculated_fields/row_level.py index 6f0ea5be..2df04a85 100644 --- a/src/teehr/models/calculated_fields/row_level.py +++ b/src/teehr/models/calculated_fields/row_level.py @@ -275,30 +275,107 @@ class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel): The name of the column to store the lead time bin ID. Default: "forecast_lead_time_bin" - bin_size: - Either a single pd.Timedelta for uniform bin sizes, or a dict mapping - threshold pd.Timedelta values to pd.Timedelta bin sizes. Keys - represent the upper bound of each threshold range. Keys can also be - string or pd.Timestamp values, which will be converted to pd.Timedelta - relative to the minimum value_time in the DataFrame. - Default: 5 days - - Example dict 1: { - pd.Timedelta(days=1): pd.Timedelta(hours=6), - pd.Timedelta(days=2): pd.Timedelta(hours=12), - pd.Timedelta(days=3): pd.Timedelta(days=1) - } + Defines how forecast lead times are binned. Three input formats are + supported: + + 1. **Single pd.Timedelta** (uniform binning): + Creates equal-width bins of the specified duration. + + Example: + pd.Timedelta(hours=6) + + Output bin IDs: + "PT0H_PT6H", "PT6H_PT12H", "PT12H_PT18H", ... + + 2. **List of dicts** (variable binning with auto-generated IDs): + Creates bins with custom ranges. Bin IDs are auto-generated as + ISO 8601 duration ranges. + + Example: + [ + {'start_inclusive': pd.Timedelta(hours=0), + 'end_exclusive': pd.Timedelta(hours=6)}, + {'start_inclusive': pd.Timedelta(hours=6), + 'end_exclusive': pd.Timedelta(hours=12)}, + {'start_inclusive': pd.Timedelta(hours=12), + 'end_exclusive': pd.Timedelta(days=1)}, + {'start_inclusive': pd.Timedelta(days=1), + 'end_exclusive': pd.Timedelta(days=2)}, + ] + + Output bin IDs: + "PT0H_PT6H", "PT6H_PT12H", "PT12H_P1D", "P1D_P2D" + + 3. **Dict of dicts** (variable binning with custom IDs): + Creates bins with custom ranges and user-defined bin identifiers. 
+ + Example: + { + 'short_range': {'start_inclusive': pd.Timedelta(hours=0), + 'end_exclusive': pd.Timedelta(hours=6)}, + 'medium_range': {'start_inclusive': pd.Timedelta(hours=6), + 'end_exclusive': pd.Timedelta(days=1)}, + 'long_range': {'start_inclusive': pd.Timedelta(days=1), + 'end_exclusive': pd.Timedelta(days=3)}, + } + + Output bin IDs: + "short_range", "medium_range", "long_range" + + Default: pd.Timedelta(days=5) - Example dict 2: { - "2024-11-20 12:00:00": pd.Timedelta(hours=6), - "2024-11-21 12:00:00": pd.Timedelta(hours=12), - "2024-11-22 12:00:00": pd.Timedelta(days=1) - } + Notes + ----- + - Bin ranges are [start_inclusive, end_exclusive), except for the final + bin which is inclusive of all remaining lead times. + - If the maximum lead time in the data exceeds the last user-defined bin, + an overflow bin is automatically created: + - For auto-generated IDs: Uses ISO 8601 duration format + - For custom IDs: Appends "overflow" as the bin ID + - Bin IDs use ISO 8601 duration format (e.g., "PT6H" for 6 hours, "P1DT12H" + for 1 day and 12 hours) for auto-generated bins. + - Custom bin IDs can use any string format. + + Examples + -------- + Uniform 6-hour bins: + + .. code-block:: python + + fcst_bins = ForecastLeadTimeBins(bin_size=pd.Timedelta(hours=6)) + # Creates bins: PT0H_PT6H, PT6H_PT12H, PT12H_PT18H, ... + + Variable bins with auto-generated IDs: + + .. code-block:: python + + fcst_bins = ForecastLeadTimeBins( + bin_size=[ + {'start_inclusive': pd.Timedelta(hours=0), + 'end_exclusive': pd.Timedelta(hours=6)}, + {'start_inclusive': pd.Timedelta(hours=6), + 'end_exclusive': pd.Timedelta(days=1)}, + {'start_inclusive': pd.Timedelta(days=1), + 'end_exclusive': pd.Timedelta(days=3)}, + ] + ) + # Creates bins: PT0H_PT6H, PT6H_P1D, P1D_P3D - Example dict 3: { - pd.Timestamp("2024-11-20 12:00:00"): pd.Timedelta(hours=6), - pd.Timestamp("2024-11-21 12:00:00"): pd.Timedelta(hours=12), - pd.Timestamp("2024-11-22 12:00:00"): pd.Timedelta(days=1) - } + Variable bins with custom IDs: + + .. code-block:: python + + fcst_bins = ForecastLeadTimeBins( + bin_size={ + 'nowcast': {'start_inclusive': pd.Timedelta(hours=0), + 'end_exclusive': pd.Timedelta(hours=6)}, + 'short_term': {'start_inclusive': pd.Timedelta(hours=6), + 'end_exclusive': pd.Timedelta(days=1)}, + 'medium_term': {'start_inclusive': pd.Timedelta(days=1), + 'end_exclusive': pd.Timedelta(days=5)}, + } + ) + # Creates bins: nowcast, short_term, medium_term """ value_time_field_name: str = Field( @@ -313,89 +390,109 @@ class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel): output_field_name: str = Field( default="forecast_lead_time_bin" ) - bin_size: Union[pd.Timedelta, dict] = Field( + bin_size: Union[pd.Timedelta, list, dict] = Field( default=pd.Timedelta(days=5) ) - @ staticmethod - def _validate_bin_size_dict( - self, sdf: ps.DataFrame) -> Union[pd.Timedelta, dict]: - """Validate and normalize bin_size dict. + @staticmethod + def _validate_bin_size_dict(self) -> Union[pd.Timedelta, list, dict]: + """Validate and normalize bin_size input. - Validates that bin_size dict keys are of correct, uniform type and - converts them to pd.Timedelta if needed. Validates that bin_size dict - values are of type pd.Timedelta. Validates that dict is not empty. 
+ Validates and converts bin_size to a standardized format: + - Single pd.Timedelta: returns as-is + - List of dicts: validates structure and converts to internal format + - Dict of dicts: validates structure and keeps custom bin IDs - If bin_size is already a pd.Timedelta or dict with pd.Timedelta keys, - returns it unchanged. Otherwise, converts string/datetime keys to - pd.Timedelta by calculating the difference from the minimum value_time. + Returns a normalized structure for internal processing. """ - # If not a dict or already has Timedelta keys, return as-is - if not isinstance(self.bin_size, dict) and isinstance( - self.bin_size, pd.Timedelta - ): + # Single Timedelta - return as-is + if isinstance(self.bin_size, pd.Timedelta): return self.bin_size - # raise error if dict is empty - if not self.bin_size: - raise ValueError("bin_size dict cannot be empty") - - # Ensure all keys are of the same type - first_key = next(iter(self.bin_size.keys())) - for key in self.bin_size.keys(): - if not isinstance(key, type(first_key)): - raise TypeError( - "All bin_size dict keys must be of the same type" - ) - - # Ensure all values are pd.Timedelta - for value in self.bin_size.values(): - if not isinstance(value, pd.Timedelta): - raise TypeError( - "All bin_size dict values must be of type pd.Timedelta" - ) + # List of dicts format + if isinstance(self.bin_size, list): + if not self.bin_size: + raise ValueError("bin_size list cannot be empty") - # If already Timedelta keys, return as-is - if isinstance(first_key, pd.Timedelta): - return self.bin_size - - # Convert string or datetime keys to Timedelta - if isinstance(first_key, (str, pd.Timestamp)): - # Get minimum value_time from the dataframe - min_value_time_row = sdf.select( - self.value_time_field_name - ).orderBy(self.value_time_field_name).first() - - # Extract value_time and convert to datetime - min_value_time = pd.to_datetime( - min_value_time_row[self.value_time_field_name] - ) + # Validate each dict has required keys + for i, bin_dict in enumerate(self.bin_size): + if not isinstance(bin_dict, dict): + raise TypeError( + f"Item {i} in bin_size list must be a dict" + ) - # Convert dict keys from string/datetime to Timedelta - converted_dict = {} + required_keys = {'start_inclusive', 'end_exclusive'} + if not required_keys.issubset(bin_dict.keys()): + raise ValueError( + f"Item {i} missing required keys: {required_keys}" + ) + + # Validate that values are Timedelta + if not isinstance(bin_dict['start_inclusive'], pd.Timedelta): + raise TypeError( + f"Item {i} 'start_inclusive' must be pd.Timedelta" + ) + if not isinstance(bin_dict['end_exclusive'], pd.Timedelta): + raise TypeError( + f"Item {i} 'end_exclusive' must be pd.Timedelta" + ) + + # Convert to internal format: list of tuples (start, end, bin_id) + # For list format, bin_id is None + normalized = [] + for bin_dict in self.bin_size: + start = bin_dict['start_inclusive'] + end = bin_dict['end_exclusive'] + normalized.append((start, end, None)) + + return normalized + + # Dict of dicts format + if isinstance(self.bin_size, dict): + if not self.bin_size: + raise ValueError("bin_size dict cannot be empty") + + # Validate structure for key, value in self.bin_size.items(): - # Convert key to datetime if it's a string - if isinstance(key, str): - try: - key_datetime = pd.to_datetime(key) - except Exception as e: - raise ValueError( - f"Error converting key '{key}' to datetime: {e}" - ) - else: - key_datetime = key - - # Calculate Timedelta from min_value_time - timedelta_key = 
key_datetime - min_value_time - converted_dict[timedelta_key] = value - - return converted_dict - - else: - # raise error for unsupported key types - raise TypeError( - "bin_size dict keys must be pd.Timedelta, str, or pd.Timestamp" - ) + if not isinstance(key, str): + raise TypeError( + f"Dict keys must be strings (custom bin IDs), got \ + {type(key)}" + ) + + if not isinstance(value, dict): + raise TypeError( + "Dict values must be dicts with bin specification" + ) + + required_keys = {'start_inclusive', 'end_exclusive'} + if not required_keys.issubset(value.keys()): + raise ValueError( + f"Bin '{key}' missing required keys. Must have: \ + {required_keys}" + ) + + if not isinstance(value['start_inclusive'], pd.Timedelta): + raise TypeError( + f"Bin '{key}' 'start_inclusive' must be pd.Timedelta" + ) + if not isinstance(value['end_exclusive'], pd.Timedelta): + raise TypeError( + f"Bin '{key}' 'end_exclusive' must be pd.Timedelta" + ) + + # Convert to internal format: list of tuples + normalized = [] + for custom_id, bin_dict in self.bin_size.items(): + start = bin_dict['start_inclusive'] + end = bin_dict['end_exclusive'] + normalized.append((start, end, custom_id)) + + return normalized + + raise TypeError( + "bin_size must be pd.Timedelta, list of dicts, or dict of dicts" + ) @staticmethod def _add_forecast_lead_time(self, sdf: ps.DataFrame) -> ps.DataFrame: @@ -416,102 +513,128 @@ def _add_forecast_lead_time_bin( ) -> ps.DataFrame: """Add forecast lead time bin column.""" - @pandas_udf(returnType=T.StringType()) - def func(lead_time: pd.Series, value_time: pd.Series) -> pd.Series: - if isinstance(self.bin_size, dict): - # Sort thresholds for consistent processing - sorted_thresholds = sorted(self.bin_size.items(), - key=lambda x: x[0].total_seconds()) - - bin_ids = pd.Series("", index=lead_time.index) - prev_thresh_sec = 0 - - for i, (threshold, bin_size) in enumerate(sorted_thresholds): - threshold_seconds = threshold.total_seconds() - bin_size_seconds = bin_size.total_seconds() + def _timedelta_to_iso_duration(td: pd.Timedelta) -> str: + """Convert pd.Timedelta to ISO 8601 duration string.""" + iso_str = td.isoformat() + # Remove trailing 0M0S, 0S, etc. 
for cleaner output + iso_str = iso_str.replace( + '0M0S', '').replace( + '0S', '').replace( + '0M', '') + # Handle edge case where we removed everything after 'T' + if iso_str.endswith('T'): + iso_str = iso_str[:-1] + 'T0S' + return iso_str - is_last = (i == len(sorted_thresholds) - 1) - - # Mask for values in this threshold range - if is_last: - mask = lead_time.dt.total_seconds() >= prev_thresh_sec - else: - mask = ( - lead_time.dt.total_seconds() >= prev_thresh_sec - ) & ( - lead_time.dt.total_seconds() < threshold_seconds - ) - - if mask.any(): - # Calculate which bin each value belongs to - bins_in_range = ( - (lead_time[mask].dt.total_seconds() - - prev_thresh_sec) - // bin_size_seconds - ).astype(int) - - # create bin ID from actual timestamps - for bin_num in bins_in_range.unique(): - bin_mask = mask & (bins_in_range == bin_num) - - if bin_mask.any(): - # Get the unique value_time values for this bin - bin_value_times = value_time[bin_mask].unique() - - # Convert to datetime and find min/max - bin_value_times_dt = pd.to_datetime( - bin_value_times - ) - start_timestamp = bin_value_times_dt.min() - end_timestamp = start_timestamp + bin_size - - # Format as ISO 8601 timestamp range - bin_id = f"{start_timestamp} - {end_timestamp}" - bin_ids[bin_mask] = bin_id - - prev_thresh_sec = threshold_seconds - - return bin_ids - - else: - # Uniform bin size logic + @pandas_udf(returnType=T.StringType()) + def func(lead_time: pd.Series) -> pd.Series: + # Single Timedelta - uniform binning + if isinstance(self.bin_size, pd.Timedelta): bin_size_seconds = self.bin_size.total_seconds() - # Calculate which bin each value belongs to bin_numbers = ( lead_time.dt.total_seconds() // bin_size_seconds ).astype(int) bin_ids = pd.Series("", index=lead_time.index) - # For each unique bin, create bin ID from actual timestamps for bin_num in bin_numbers.unique(): bin_mask = bin_numbers == bin_num if bin_mask.any(): - # Get the unique value_time values for this bin - bin_value_times = value_time[bin_mask].unique() + start_td = pd.Timedelta( + seconds=bin_num * bin_size_seconds + ) + end_td = pd.Timedelta( + seconds=(bin_num + 1) * bin_size_seconds + ) - # Convert to datetime if needed and find min/max - bin_value_times_dt = pd.to_datetime(bin_value_times) - start_timestamp = bin_value_times_dt.min() - end_timestamp = start_timestamp + self.bin_size + # Convert to ISO duration format + start_iso = _timedelta_to_iso_duration(start_td) + end_iso = _timedelta_to_iso_duration(end_td) + bin_id = f"{start_iso}_{end_iso}" - # Format as ISO 8601 timestamp range - bin_id = f"{start_timestamp} - {end_timestamp}" bin_ids[bin_mask] = bin_id return bin_ids + # List/Dict format - dynamic binning with explicit ranges + # self.bin_size is now a list of tuples: (start, end, bin_id) + # bin_id is None for auto-generated, or a string for custom + else: + bin_ids = pd.Series("", index=lead_time.index) + lead_time_seconds = lead_time.dt.total_seconds() + + # Check if we need to add an overflow bin + max_lead_time = lead_time.max() + last_bin_end = self.bin_size[-1][1] + + # Create working copy of bin_size + bins_to_use = [] + + # Convert all bins, generating ISO format for None bin_ids + for start_td, end_td, bin_id in self.bin_size: + if bin_id is None: + # Auto-generated: create ISO duration format + start_iso = _timedelta_to_iso_duration(start_td) + end_iso = _timedelta_to_iso_duration(end_td) + final_bin_id = f"{start_iso}_{end_iso}" + else: + # Custom ID: use as-is + final_bin_id = bin_id + + bins_to_use.append((start_td, 
end_td, final_bin_id)) + + # If max lead time exceeds last bin, create overflow bin + if max_lead_time >= last_bin_end: + overflow_start = last_bin_end + overflow_end = max_lead_time + + # Determine overflow bin_id + if self.bin_size[-1][2] is None: + # Auto-generated format: use ISO duration strings + start_iso = _timedelta_to_iso_duration(overflow_start) + end_iso = _timedelta_to_iso_duration(overflow_end) + overflow_bin_id = f"{start_iso}_{end_iso}" + else: + # Custom ID format: append suffix + overflow_bin_id = "overflow" + + bins_to_use.append( + (overflow_start, overflow_end, overflow_bin_id) + ) + + for i, (start_td, end_td, bin_id) in enumerate(bins_to_use): + start_seconds = start_td.total_seconds() + end_seconds = end_td.total_seconds() + + # Check if this is the last bin (including overflow bin) + is_last_bin = (i == len(bins_to_use) - 1) + + if is_last_bin: + # Last bin is inclusive of end_exclusive + mask = lead_time_seconds >= start_seconds + else: + # All other bins are [start, end) + mask = ( + (lead_time_seconds >= start_seconds) & + (lead_time_seconds < end_seconds) + ) + + if mask.any(): + bin_ids[mask] = bin_id + + return bin_ids + sdf = sdf.withColumn( self.output_field_name, - func(self.lead_time_field_name, self.value_time_field_name) + func(self.lead_time_field_name) ) return sdf def apply_to(self, sdf: ps.DataFrame) -> ps.DataFrame: """Apply the calculated field to the Spark DataFrame.""" - self.bin_size = self._validate_bin_size_dict(self, sdf) + self.bin_size = self._validate_bin_size_dict(self) sdf = self._add_forecast_lead_time(self, sdf) sdf = self._add_forecast_lead_time_bin(self, sdf) return sdf From b12700605fa095cf430803c955a3e3b72cc60e6d Mon Sep 17 00:00:00 2001 From: samland1116 Date: Wed, 10 Dec 2025 12:42:36 -0600 Subject: [PATCH 5/7] adds new tests for ForecastLeadTimeBins --- tests/data/setup_v0_4_ensemble_study.py | 120 ++++++++++++++++++ tests/evaluations/test_add_udfs.py | 154 ++++++++++++++++++++++++ 2 files changed, 274 insertions(+) create mode 100644 tests/data/setup_v0_4_ensemble_study.py diff --git a/tests/data/setup_v0_4_ensemble_study.py b/tests/data/setup_v0_4_ensemble_study.py new file mode 100644 index 00000000..e954c421 --- /dev/null +++ b/tests/data/setup_v0_4_ensemble_study.py @@ -0,0 +1,120 @@ +from pathlib import Path +from teehr import Configuration +from teehr.models.filters import TableFilter +from teehr.evaluation.evaluation import Evaluation +from teehr import SignatureTimeseriesGenerators as sts +from teehr import BenchmarkForecastGenerators as bm + +TEST_STUDY_DATA_DIR_v0_4 = Path(Path.home(), "repos", "teehr", "tests", "data", "test_study") + + +def setup_v0_4_ensemble_study(tmpdir): + """Create a test evaluation with ensemble forecasts using teehr.""" + usgs_location = Path( + TEST_STUDY_DATA_DIR_v0_4, "geo", "USGS_PlatteRiver_location.parquet" + ) + + secondary_filename = "MEFP.MBRFC.DNVC2LOCAL.SQIN.xml" + secondary_filepath = Path( + TEST_STUDY_DATA_DIR_v0_4, + "timeseries", + secondary_filename + ) + primary_filepath = Path( + TEST_STUDY_DATA_DIR_v0_4, + "timeseries", + "usgs_hefs_06711565.parquet" + ) + + ev = Evaluation(dir_path=tmpdir) + ev.enable_logging() + ev.clone_template() + + ev.locations.load_spatial( + in_path=usgs_location + ) + ev.location_crosswalks.load_csv( + in_path=Path( + TEST_STUDY_DATA_DIR_v0_4, "geo", "hefs_usgs_crosswalk.csv" + ) + ) + ev.configurations.add( + Configuration( + name="MEFP", + type="secondary", + description="MBRFC HEFS Data" + ) + ) + ev.configurations.add( + Configuration( + 
name="usgs_observations", + type="primary", + description="USGS observed test data" + ) + ) + constant_field_values = { + "unit_name": "ft^3/s", + "variable_name": "streamflow_hourly_inst", + } + ev.secondary_timeseries.load_fews_xml( + in_path=secondary_filepath, + constant_field_values=constant_field_values + ) + ev.primary_timeseries.load_parquet( + in_path=primary_filepath + ) + + # Calculate annual hourly normals from USGS observations. + input_ts = TableFilter() + input_ts.table_name = "primary_timeseries" + + ts_normals = sts.Normals() + ts_normals.temporal_resolution = "hour_of_year" # the default + ts_normals.summary_statistic = "mean" # the default + + ev.generate.signature_timeseries( + method=ts_normals, + input_table_filter=input_ts, + start_datetime="2024-11-19 12:00:00", + end_datetime="2024-11-21 13:00:00", + timestep="1 hour", + fillna=False + ).write() + + # Add reference forecast based on climatology. + ev.configurations.add( + [ + Configuration( + name="benchmark_forecast_hourly_normals", + type="secondary", + description="Reference forecast based on USGS climatology summarized by hour of year" # noqa + ) + ] + ) + ref_fcst = bm.ReferenceForecast() + ref_fcst.aggregate_reference_timeseries = True + + reference_ts = TableFilter() + reference_ts.table_name = "primary_timeseries" + reference_ts.filters = [ + "variable_name = 'streamflow_hour_of_year_mean'", + "unit_name = 'ft^3/s'" + ] + + template_ts = TableFilter() + template_ts.table_name = "secondary_timeseries" + template_ts.filters = [ + "variable_name = 'streamflow_hourly_inst'", + "unit_name = 'ft^3/s'", + "member = '1993'" + ] + ev.generate.benchmark_forecast( + method=ref_fcst, + reference_table_filter=reference_ts, + template_table_filter=template_ts, + output_configuration_name="benchmark_forecast_hourly_normals" + ).write(destination_table="secondary_timeseries") + + ev.joined_timeseries.create(execute_scripts=False) + + return ev diff --git a/tests/evaluations/test_add_udfs.py b/tests/evaluations/test_add_udfs.py index 8a467ebe..62235957 100644 --- a/tests/evaluations/test_add_udfs.py +++ b/tests/evaluations/test_add_udfs.py @@ -14,6 +14,7 @@ from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent)) from data.setup_v0_3_study import setup_v0_3_study # noqa +from data.setup_v0_4_ensemble_study import setup_v0_4_ensemble_study # noqa def test_add_row_udfs_null_reference(tmpdir): @@ -128,6 +129,159 @@ def test_add_row_udfs(tmpdir): ev.spark.stop() +def test_forecast_lead_time_bins(tmpdir): + """Test ForecastLeadTimeBins UDF.""" + ev = setup_v0_4_ensemble_study(tmpdir) + + # test with single bin size + fcst_bins_static = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins( + bin_size=pd.Timedelta(hours=6) + ) + sdf = ev.joined_timeseries.add_calculated_fields([ + fcst_bins_static, + ]).to_sdf() + sorted_sdf = sdf.orderBy( + "primary_location_id", + "configuration_name", + "member", + "reference_time", + "value_time" + ) + assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 9 + + # try with dynamic bin sizes that DO encompass full lead time range + bin = [ + {'start_inclusive': pd.Timedelta(hours=0), + 'end_exclusive': pd.Timedelta(hours=6)}, + {'start_inclusive': pd.Timedelta(hours=6), + 'end_exclusive': pd.Timedelta(hours=12)}, + {'start_inclusive': pd.Timedelta(hours=12), + 'end_exclusive': pd.Timedelta(hours=18)}, + {'start_inclusive': pd.Timedelta(hours=18), + 'end_exclusive': pd.Timedelta(days=1)}, + {'start_inclusive': pd.Timedelta(days=1), + 'end_exclusive': 
pd.Timedelta(days=1, hours=12)}, + {'start_inclusive': pd.Timedelta(days=1, hours=12), + 'end_exclusive': pd.Timedelta(days=2)}, + {'start_inclusive': pd.Timedelta(days=2), + 'end_exclusive': pd.Timedelta(days=3)}, + ] + fcst_bins_dynamic = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins( + bin_size=bin, + ) + sdf = ev.joined_timeseries.add_calculated_fields([ + fcst_bins_dynamic, + ]).to_sdf() + sorted_sdf = sdf.orderBy( + "primary_location_id", + "configuration_name", + "member", + "reference_time", + "value_time" + ) + assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 7 + + # try with dynamic bin sizes that DO NOT encompass full lead time range + bin = [ + {'start_inclusive': pd.Timedelta(hours=0), + 'end_exclusive': pd.Timedelta(hours=6)}, + {'start_inclusive': pd.Timedelta(hours=6), + 'end_exclusive': pd.Timedelta(hours=12)}, + {'start_inclusive': pd.Timedelta(hours=12), + 'end_exclusive': pd.Timedelta(hours=18)}, + {'start_inclusive': pd.Timedelta(hours=18), + 'end_exclusive': pd.Timedelta(days=1)}, + {'start_inclusive': pd.Timedelta(days=1), + 'end_exclusive': pd.Timedelta(days=1, hours=12)}, + ] + fcst_bins_dynamic = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins( + bin_size=bin, + ) + sdf = ev.joined_timeseries.add_calculated_fields([ + fcst_bins_dynamic, + ]).to_sdf() + sorted_sdf = sdf.orderBy( + "primary_location_id", + "configuration_name", + "member", + "reference_time", + "value_time" + ) + assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 6 + assert 'P1DT12H_P2DT0H' in [row['forecast_lead_time_bin'] for row in + sorted_sdf.select( + 'forecast_lead_time_bin' + ).distinct().collect()] + + # try with dynamic bin sizes w/ string dict keys that DO encompass full + # lead time range + bin = { + 'bin_1': {'start_inclusive': pd.Timedelta(hours=0), + 'end_exclusive': pd.Timedelta(hours=6)}, + 'bin_2': {'start_inclusive': pd.Timedelta(hours=6), + 'end_exclusive': pd.Timedelta(hours=12)}, + 'bin_3': {'start_inclusive': pd.Timedelta(hours=12), + 'end_exclusive': pd.Timedelta(hours=18)}, + 'bin_4': {'start_inclusive': pd.Timedelta(hours=18), + 'end_exclusive': pd.Timedelta(days=1)}, + 'bin_5': {'start_inclusive': pd.Timedelta(days=1), + 'end_exclusive': pd.Timedelta(days=1, hours=12)}, + 'bin_6': {'start_inclusive': pd.Timedelta(days=1, hours=12), + 'end_exclusive': pd.Timedelta(days=2)}, + 'bin_7': {'start_inclusive': pd.Timedelta(days=2), + 'end_exclusive': pd.Timedelta(days=3)}, + } + fcst_bins_dynamic = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins( + bin_size=bin + ) + sdf = ev.joined_timeseries.add_calculated_fields([ + fcst_bins_dynamic, + ]).to_sdf() + sorted_sdf = sdf.orderBy( + "primary_location_id", + "configuration_name", + "member", + "reference_time", + "value_time" + ) + assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 7 + + # try with dynamic bin sizes w/ string dict keys that DO NOT encompass + # full lead time range + bin = { + 'bin_1': {'start_inclusive': pd.Timedelta(hours=0), + 'end_exclusive': pd.Timedelta(hours=6)}, + 'bin_2': {'start_inclusive': pd.Timedelta(hours=6), + 'end_exclusive': pd.Timedelta(hours=12)}, + 'bin_3': {'start_inclusive': pd.Timedelta(hours=12), + 'end_exclusive': pd.Timedelta(hours=18)}, + 'bin_4': {'start_inclusive': pd.Timedelta(hours=18), + 'end_exclusive': pd.Timedelta(days=1)}, + 'bin_5': {'start_inclusive': pd.Timedelta(days=1), + 'end_exclusive': pd.Timedelta(days=1, hours=12)}, + 'bin_6': {'start_inclusive': pd.Timedelta(days=1, hours=12), + 
'end_exclusive': pd.Timedelta(days=2)}, + } + fcst_bins_dynamic = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins( + bin_size=bin + ) + sdf = ev.joined_timeseries.add_calculated_fields([ + fcst_bins_dynamic, + ]).to_sdf() + sorted_sdf = sdf.orderBy( + "primary_location_id", + "configuration_name", + "member", + "reference_time", + "value_time" + ) + assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 7 + assert 'overflow' in [row['forecast_lead_time_bin'] for row in + sorted_sdf.select( + 'forecast_lead_time_bin' + ).distinct().collect()] + + def test_add_timeseries_udfs(tmpdir): """Test adding a timeseries aware UDF.""" # utilize e0_2_location_example from s3 to satisfy baseflow POR reqs From 3f1bccaf54748dbd661acbce19d75bea41b52b53 Mon Sep 17 00:00:00 2001 From: samlamont Date: Wed, 10 Dec 2025 14:26:11 -0500 Subject: [PATCH 6/7] update tests --- tests/data/setup_v0_4_ensemble_study.py | 2 +- tests/evaluations/test_add_udfs.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/data/setup_v0_4_ensemble_study.py b/tests/data/setup_v0_4_ensemble_study.py index e954c421..094ea6f5 100644 --- a/tests/data/setup_v0_4_ensemble_study.py +++ b/tests/data/setup_v0_4_ensemble_study.py @@ -5,7 +5,7 @@ from teehr import SignatureTimeseriesGenerators as sts from teehr import BenchmarkForecastGenerators as bm -TEST_STUDY_DATA_DIR_v0_4 = Path(Path.home(), "repos", "teehr", "tests", "data", "test_study") +TEST_STUDY_DATA_DIR_v0_4 = Path(Path.cwd(), "tests", "data", "test_study") def setup_v0_4_ensemble_study(tmpdir): diff --git a/tests/evaluations/test_add_udfs.py b/tests/evaluations/test_add_udfs.py index 62235957..083d016e 100644 --- a/tests/evaluations/test_add_udfs.py +++ b/tests/evaluations/test_add_udfs.py @@ -539,6 +539,12 @@ def test_location_event_detection(tmpdir): dir=tempdir ) ) + test_forecast_lead_time_bins( + tempfile.mkdtemp( + prefix="1b-", + dir=tempdir + ) + ) test_add_timeseries_udfs( tempfile.mkdtemp( prefix="2-", From b3591d971090e3513b142cb3a864ff20163ae373 Mon Sep 17 00:00:00 2001 From: samland1116 Date: Wed, 10 Dec 2025 16:20:40 -0600 Subject: [PATCH 7/7] adds additional input type support, support for mixed input types --- .../models/calculated_fields/row_level.py | 204 ++++++++++++------ tests/evaluations/test_add_udfs.py | 34 +++ 2 files changed, 173 insertions(+), 65 deletions(-) diff --git a/src/teehr/models/calculated_fields/row_level.py b/src/teehr/models/calculated_fields/row_level.py index 2df04a85..fb379965 100644 --- a/src/teehr/models/calculated_fields/row_level.py +++ b/src/teehr/models/calculated_fields/row_level.py @@ -3,6 +3,7 @@ from typing import Union from pydantic import Field import pandas as pd +from datetime import timedelta import pyspark.sql.types as T from pyspark.sql.functions import pandas_udf import pyspark.sql as ps @@ -275,32 +276,37 @@ class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel): The name of the column to store the lead time bin ID. Default: "forecast_lead_time_bin" - bin_size: - Defines how forecast lead times are binned. Three input formats are - supported: + Defines how forecast lead times are binned. Accepts pd.Timedelta, + datetime.timedelta, or timedelta strings (e.g., '6 hours', '1 day'). + Three input formats are supported: - 1. **Single pd.Timedelta** (uniform binning): + 1. **Single timedelta** (uniform binning): Creates equal-width bins of the specified duration. 
- Example: + Examples: pd.Timedelta(hours=6) + timedelta(hours=6) + '6 hours' + '6h' Output bin IDs: "PT0H_PT6H", "PT6H_PT12H", "PT12H_PT18H", ... 2. **List of dicts** (variable binning with auto-generated IDs): Creates bins with custom ranges. Bin IDs are auto-generated as - ISO 8601 duration ranges. + ISO 8601 duration ranges. Values can be pd.Timedelta, + datetime.timedelta, or timedelta strings. - Example: + Examples: [ {'start_inclusive': pd.Timedelta(hours=0), 'end_exclusive': pd.Timedelta(hours=6)}, - {'start_inclusive': pd.Timedelta(hours=6), - 'end_exclusive': pd.Timedelta(hours=12)}, - {'start_inclusive': pd.Timedelta(hours=12), - 'end_exclusive': pd.Timedelta(days=1)}, - {'start_inclusive': pd.Timedelta(days=1), - 'end_exclusive': pd.Timedelta(days=2)}, + {'start_inclusive': '6 hours', + 'end_exclusive': '12 hours'}, + {'start_inclusive': timedelta(hours=12), + 'end_exclusive': '1 day'}, + {'start_inclusive': '1 day', + 'end_exclusive': '2 days'}, ] Output bin IDs: @@ -308,15 +314,17 @@ class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel): 3. **Dict of dicts** (variable binning with custom IDs): Creates bins with custom ranges and user-defined bin identifiers. + Values can be pd.Timedelta, datetime.timedelta, or timedelta + strings. - Example: + Examples: { - 'short_range': {'start_inclusive': pd.Timedelta(hours=0), - 'end_exclusive': pd.Timedelta(hours=6)}, + 'short_range': {'start_inclusive': '0 hours', + 'end_exclusive': '6 hours'}, 'medium_range': {'start_inclusive': pd.Timedelta(hours=6), - 'end_exclusive': pd.Timedelta(days=1)}, - 'long_range': {'start_inclusive': pd.Timedelta(days=1), - 'end_exclusive': pd.Timedelta(days=3)}, + 'end_exclusive': timedelta(days=1)}, + 'long_range': {'start_inclusive': '1 day', + 'end_exclusive': '3 days'}, } Output bin IDs: @@ -326,6 +334,12 @@ class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel): Notes ----- + - Timedelta values can be specified as: + - pd.Timedelta objects (e.g., pd.Timedelta(hours=6)) + - datetime.timedelta objects (e.g., timedelta(hours=6)) + - Strings (e.g., '6 hours', '1 day', '1d 12h', 'PT6H') + - All timedelta inputs are internally converted to pd.Timedelta for + processing. - Bin ranges are [start_inclusive, end_exclusive), except for the final bin which is inclusive of all remaining lead times. - If the maximum lead time in the data exceeds the last user-defined bin, @@ -338,41 +352,50 @@ class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel): Examples -------- - Uniform 6-hour bins: + Uniform 6-hour bins using different input types: .. code-block:: python + # Using pd.Timedelta fcst_bins = ForecastLeadTimeBins(bin_size=pd.Timedelta(hours=6)) - # Creates bins: PT0H_PT6H, PT6H_PT12H, PT12H_PT18H, ... - Variable bins with auto-generated IDs: + # Using datetime.timedelta + from datetime import timedelta + fcst_bins = ForecastLeadTimeBins(bin_size=timedelta(hours=6)) + + # Using string + fcst_bins = ForecastLeadTimeBins(bin_size='6 hours') + + # All create bins: PT0H_PT6H, PT6H_PT12H, PT12H_PT18H, ... + + Variable bins with auto-generated IDs using mixed types: .. 
code-block:: python fcst_bins = ForecastLeadTimeBins( bin_size=[ - {'start_inclusive': pd.Timedelta(hours=0), - 'end_exclusive': pd.Timedelta(hours=6)}, + {'start_inclusive': '0 hours', + 'end_exclusive': '6 hours'}, {'start_inclusive': pd.Timedelta(hours=6), - 'end_exclusive': pd.Timedelta(days=1)}, - {'start_inclusive': pd.Timedelta(days=1), - 'end_exclusive': pd.Timedelta(days=3)}, + 'end_exclusive': '1 day'}, + {'start_inclusive': timedelta(days=1), + 'end_exclusive': '3 days'}, ] ) # Creates bins: PT0H_PT6H, PT6H_P1D, P1D_P3D - Variable bins with custom IDs: + Variable bins with custom IDs using strings: .. code-block:: python fcst_bins = ForecastLeadTimeBins( bin_size={ - 'nowcast': {'start_inclusive': pd.Timedelta(hours=0), - 'end_exclusive': pd.Timedelta(hours=6)}, - 'short_term': {'start_inclusive': pd.Timedelta(hours=6), - 'end_exclusive': pd.Timedelta(days=1)}, - 'medium_term': {'start_inclusive': pd.Timedelta(days=1), - 'end_exclusive': pd.Timedelta(days=5)}, + 'nowcast': {'start_inclusive': '0 hours', + 'end_exclusive': '6 hours'}, + 'short_term': {'start_inclusive': '6 hours', + 'end_exclusive': '1 day'}, + 'medium_term': {'start_inclusive': '1 day', + 'end_exclusive': '5 days'}, } ) # Creates bins: nowcast, short_term, medium_term @@ -390,7 +413,7 @@ class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel): output_field_name: str = Field( default="forecast_lead_time_bin" ) - bin_size: Union[pd.Timedelta, list, dict] = Field( + bin_size: Union[pd.Timedelta, timedelta, str, list, dict] = Field( default=pd.Timedelta(days=5) ) @@ -405,9 +428,37 @@ def _validate_bin_size_dict(self) -> Union[pd.Timedelta, list, dict]: Returns a normalized structure for internal processing. """ - # Single Timedelta - return as-is - if isinstance(self.bin_size, pd.Timedelta): - return self.bin_size + def _to_pd_timedelta(value, field_name, context): + """Convert datetime.timedelta or string to pd.Timedelta.""" + if isinstance(value, pd.Timedelta): + return value + elif isinstance(value, timedelta): + return pd.Timedelta(value) + elif isinstance(value, str): + try: + temp = pd.Timedelta(value) + if temp < pd.Timedelta(seconds=1) and \ + temp != pd.Timedelta(0): + raise ValueError( + "Timedelta must be at least 1 second" + ) + return temp + except ValueError as e: + raise ValueError( + f"{context} '{field_name}' has invalid timedelta" + f" string: '{value}'. " + f"Error: {e}" + ) + else: + raise TypeError( + f"{context} '{field_name}' must be pd.Timedelta," + " datetime.timedelta, or a valid timedelta string," + f" got {type(value)}" + ) + + # Single Timedelta - convert if needed + if isinstance(self.bin_size, (pd.Timedelta, timedelta, str)): + return _to_pd_timedelta(self.bin_size, 'bin_size', 'bin_size') # List of dicts format if isinstance(self.bin_size, list): @@ -424,25 +475,36 @@ def _validate_bin_size_dict(self) -> Union[pd.Timedelta, list, dict]: required_keys = {'start_inclusive', 'end_exclusive'} if not required_keys.issubset(bin_dict.keys()): raise ValueError( - f"Item {i} missing required keys: {required_keys}" + f"Item {i} missing required keys. 
" + f"Must have: {required_keys}" ) - # Validate that values are Timedelta - if not isinstance(bin_dict['start_inclusive'], pd.Timedelta): - raise TypeError( - f"Item {i} 'start_inclusive' must be pd.Timedelta" - ) - if not isinstance(bin_dict['end_exclusive'], pd.Timedelta): - raise TypeError( - f"Item {i} 'end_exclusive' must be pd.Timedelta" - ) + # Validate and convert values to pd.Timedelta + start = _to_pd_timedelta( + bin_dict['start_inclusive'], + 'start_inclusive', + f"Item {i}" + ) + end = _to_pd_timedelta( + bin_dict['end_exclusive'], + 'end_exclusive', + f"Item {i}" + ) # Convert to internal format: list of tuples (start, end, bin_id) - # For list format, bin_id is None + # For list format, bin_id is None (will be auto-generated as ISO) normalized = [] for bin_dict in self.bin_size: - start = bin_dict['start_inclusive'] - end = bin_dict['end_exclusive'] + start = _to_pd_timedelta( + bin_dict['start_inclusive'], + 'start_inclusive', + 'bin_dict' + ) + end = _to_pd_timedelta( + bin_dict['end_exclusive'], + 'end_exclusive', + 'bin_dict' + ) normalized.append((start, end, None)) return normalized @@ -456,8 +518,8 @@ def _validate_bin_size_dict(self) -> Union[pd.Timedelta, list, dict]: for key, value in self.bin_size.items(): if not isinstance(key, str): raise TypeError( - f"Dict keys must be strings (custom bin IDs), got \ - {type(key)}" + f"Dict keys must be strings (custom bin IDs), got " + f"{type(key)}" ) if not isinstance(value, dict): @@ -468,30 +530,42 @@ def _validate_bin_size_dict(self) -> Union[pd.Timedelta, list, dict]: required_keys = {'start_inclusive', 'end_exclusive'} if not required_keys.issubset(value.keys()): raise ValueError( - f"Bin '{key}' missing required keys. Must have: \ - {required_keys}" + f"Bin '{key}' missing required keys. 
Must have: " + f"{required_keys}" ) - if not isinstance(value['start_inclusive'], pd.Timedelta): - raise TypeError( - f"Bin '{key}' 'start_inclusive' must be pd.Timedelta" - ) - if not isinstance(value['end_exclusive'], pd.Timedelta): - raise TypeError( - f"Bin '{key}' 'end_exclusive' must be pd.Timedelta" - ) + # Validate and convert to pd.Timedelta + _to_pd_timedelta( + value['start_inclusive'], + 'start_inclusive', + f"Bin '{key}'" + ) + _to_pd_timedelta( + value['end_exclusive'], + 'end_exclusive', + f"Bin '{key}'" + ) # Convert to internal format: list of tuples normalized = [] for custom_id, bin_dict in self.bin_size.items(): - start = bin_dict['start_inclusive'] - end = bin_dict['end_exclusive'] + start = _to_pd_timedelta( + bin_dict['start_inclusive'], + 'start_inclusive', + f"Bin '{custom_id}'" + ) + end = _to_pd_timedelta( + bin_dict['end_exclusive'], + 'end_exclusive', + f"Bin '{custom_id}'" + ) normalized.append((start, end, custom_id)) return normalized raise TypeError( - "bin_size must be pd.Timedelta, list of dicts, or dict of dicts" + "bin_size must be pd.Timedelta, datetime.timedelta, " + "a valid timedelta string, list of dicts, or dict of dicts" ) @staticmethod diff --git a/tests/evaluations/test_add_udfs.py b/tests/evaluations/test_add_udfs.py index 083d016e..acb0c6d4 100644 --- a/tests/evaluations/test_add_udfs.py +++ b/tests/evaluations/test_add_udfs.py @@ -9,6 +9,7 @@ import numpy as np import baseflow import pandas as pd +from datetime import timedelta import sys from pathlib import Path @@ -281,6 +282,39 @@ def test_forecast_lead_time_bins(tmpdir): 'forecast_lead_time_bin' ).distinct().collect()] + # try mixed type dynamic bin sizes w/ string dict keys that DO encompass + # the full lead time range + bin = { + 'bin_1': {'start_inclusive': '0 hours', + 'end_exclusive': '6 hours'}, + 'bin_2': {'start_inclusive': pd.Timedelta('6 hours'), + 'end_exclusive': pd.Timedelta(hours=12)}, + 'bin_3': {'start_inclusive': timedelta(hours=12), + 'end_exclusive': timedelta(hours=18)}, + 'bin_4': {'start_inclusive': '18 hours', + 'end_exclusive': pd.Timedelta('1 days')}, + 'bin_5': {'start_inclusive': '1 days', + 'end_exclusive': timedelta(days=1, hours=12)}, + 'bin_6': {'start_inclusive': pd.Timedelta(days=1, hours=12), + 'end_exclusive': '2 days'}, + 'bin_7': {'start_inclusive': timedelta(days=2), + 'end_exclusive': '3 days'}, + } + fcst_bins_dynamic = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins( + bin_size=bin + ) + sdf = ev.joined_timeseries.add_calculated_fields([ + fcst_bins_dynamic, + ]).to_sdf() + sorted_sdf = sdf.orderBy( + "primary_location_id", + "configuration_name", + "member", + "reference_time", + "value_time" + ) + assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 7 + def test_add_timeseries_udfs(tmpdir): """Test adding a timeseries aware UDF."""
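
For reference, a minimal usage sketch of the new ForecastLeadTimeBins calculated field, combining the ensemble test fixture added in PATCH 5 with the final API from PATCH 7. The import of the fixture assumes the tests directory is on sys.path (as done in test_add_udfs.py), and the temporary output directory is a hypothetical placeholder; everything else follows the names used in the patches.

.. code-block:: python

    import teehr
    # Assumes tests/ is on sys.path, as in test_add_udfs.py.
    from data.setup_v0_4_ensemble_study import setup_v0_4_ensemble_study

    # Build the ensemble evaluation used by the tests above
    # (the output directory is a hypothetical placeholder).
    ev = setup_v0_4_ensemble_study("/tmp/teehr_flt_bins_demo")

    # Named lead-time bins; per PATCH 7, bounds may be pd.Timedelta,
    # datetime.timedelta, or timedelta strings.
    flt_bins = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins(
        bin_size={
            "short_range": {"start_inclusive": "0 hours",
                            "end_exclusive": "6 hours"},
            "medium_range": {"start_inclusive": "6 hours",
                             "end_exclusive": "1 day"},
            "long_range": {"start_inclusive": "1 day",
                           "end_exclusive": "5 days"},
        }
    )

    # Apply the calculated field and inspect the resulting bin IDs.
    sdf = ev.joined_timeseries.add_calculated_fields([flt_bins]).to_sdf()
    sdf.select("forecast_lead_time_bin").distinct().show()

Lead times past the last bin fall into the automatically added "overflow" bin, as exercised by the tests in PATCH 5.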