diff --git a/docs/sphinx/user_guide/notebooks/08_adding_calculated_fields.ipynb b/docs/sphinx/user_guide/notebooks/08_adding_calculated_fields.ipynb
index 4080b10f..bf063590 100644
--- a/docs/sphinx/user_guide/notebooks/08_adding_calculated_fields.ipynb
+++ b/docs/sphinx/user_guide/notebooks/08_adding_calculated_fields.ipynb
@@ -316,6 +316,7 @@
     "- NormalizedFlow\n",
     "- Seasons\n",
     "- ForecastLeadTime\n",
+    "- ForecastLeadTimeBins\n",
     "- ThresholdValueExceeded\n",
     "- DayOfYear\n",
     "\n",
diff --git a/src/teehr/models/calculated_fields/row_level.py b/src/teehr/models/calculated_fields/row_level.py
index 5252a155..fb379965 100644
--- a/src/teehr/models/calculated_fields/row_level.py
+++ b/src/teehr/models/calculated_fields/row_level.py
@@ -1,9 +1,9 @@
 """Classes representing UDFs."""
 import calendar
-from typing import List, Union
-from pydantic import BaseModel as PydanticBaseModel
-from pydantic import Field, ConfigDict
+from typing import Union
+from pydantic import Field
 import pandas as pd
+from datetime import timedelta
 import pyspark.sql.types as T
 from pyspark.sql.functions import pandas_udf
 import pyspark.sql as ps
@@ -202,7 +202,8 @@ def apply_to(self, sdf: ps.DataFrame) -> ps.DataFrame:
         def func(value_time: pd.Series) -> pd.Series:
             return value_time.dt.month.apply(
                 lambda x: next(
-                    (season for season, months in self.season_months.items() if x in months),
+                    (season for season,
+                     months in self.season_months.items() if x in months),
                     None
                 )
             )
@@ -215,7 +216,7 @@
 
 
 class ForecastLeadTime(CalculatedFieldABC, CalculatedFieldBaseModel):
-    """Adds the forecast lead time in seconds from a timestamp column.
+    """Adds the forecast lead time from a timestamp column.
 
     Properties
     ----------
@@ -257,6 +258,462 @@ def func(value_time: pd.Series,
         return sdf
 
 
+class ForecastLeadTimeBins(CalculatedFieldABC, CalculatedFieldBaseModel):
+    """Adds ID for grouped forecast lead time bins.
+
+    Properties
+    ----------
+    - value_time_field_name:
+        The name of the column containing the timestamp.
+        Default: "value_time"
+    - reference_time_field_name:
+        The name of the column containing the forecast time.
+        Default: "reference_time"
+    - lead_time_field_name:
+        The name of the column containing the forecast lead time.
+        Default: "forecast_lead_time"
+    - output_field_name:
+        The name of the column to store the lead time bin ID.
+        Default: "forecast_lead_time_bin"
+    - bin_size:
+        Defines how forecast lead times are binned. Accepts pd.Timedelta,
+        datetime.timedelta, or timedelta strings (e.g., '6 hours', '1 day').
+        Three input formats are supported:
+
+        1. **Single timedelta** (uniform binning):
+           Creates equal-width bins of the specified duration.
+
+           Examples:
+               pd.Timedelta(hours=6)
+               timedelta(hours=6)
+               '6 hours'
+               '6h'
+
+           Output bin IDs:
+               "PT0H_PT6H", "PT6H_PT12H", "PT12H_PT18H", ...
+
+        2. **List of dicts** (variable binning with auto-generated IDs):
+           Creates bins with custom ranges. Bin IDs are auto-generated as
+           ISO 8601 duration ranges. Values can be pd.Timedelta,
+           datetime.timedelta, or timedelta strings.
+
+           Examples:
+               [
+                   {'start_inclusive': pd.Timedelta(hours=0),
+                    'end_exclusive': pd.Timedelta(hours=6)},
+                   {'start_inclusive': '6 hours',
+                    'end_exclusive': '12 hours'},
+                   {'start_inclusive': timedelta(hours=12),
+                    'end_exclusive': '1 day'},
+                   {'start_inclusive': '1 day',
+                    'end_exclusive': '2 days'},
+               ]
+
+           Output bin IDs:
+               "PT0H_PT6H", "PT6H_PT12H", "PT12H_P1D", "P1D_P2D"
+
+        3. **Dict of dicts** (variable binning with custom IDs):
+           Creates bins with custom ranges and user-defined bin identifiers.
+           Values can be pd.Timedelta, datetime.timedelta, or timedelta
+           strings.
+
+           Examples:
+               {
+                   'short_range': {'start_inclusive': '0 hours',
+                                   'end_exclusive': '6 hours'},
+                   'medium_range': {'start_inclusive': pd.Timedelta(hours=6),
+                                    'end_exclusive': timedelta(days=1)},
+                   'long_range': {'start_inclusive': '1 day',
+                                  'end_exclusive': '3 days'},
+               }
+
+           Output bin IDs:
+               "short_range", "medium_range", "long_range"
+
+        Default: pd.Timedelta(days=5)
+
+    Notes
+    -----
+    - Timedelta values can be specified as:
+        - pd.Timedelta objects (e.g., pd.Timedelta(hours=6))
+        - datetime.timedelta objects (e.g., timedelta(hours=6))
+        - Strings (e.g., '6 hours', '1 day', '1d 12h', 'PT6H')
+    - All timedelta inputs are internally converted to pd.Timedelta for
+      processing.
+    - Bin ranges are [start_inclusive, end_exclusive), except for the final
+      bin which is inclusive of all remaining lead times.
+    - If the maximum lead time in the data exceeds the last user-defined bin,
+      an overflow bin is automatically created:
+        - For auto-generated IDs: Uses ISO 8601 duration format
+        - For custom IDs: Appends "overflow" as the bin ID
+    - Bin IDs use ISO 8601 duration format (e.g., "PT6H" for 6 hours,
+      "P1DT12H" for 1 day and 12 hours) for auto-generated bins.
+    - Custom bin IDs can use any string format.
+
+    Examples
+    --------
+    Uniform 6-hour bins using different input types:
+
+    .. code-block:: python
+
+        # Using pd.Timedelta
+        fcst_bins = ForecastLeadTimeBins(bin_size=pd.Timedelta(hours=6))
+
+        # Using datetime.timedelta
+        from datetime import timedelta
+        fcst_bins = ForecastLeadTimeBins(bin_size=timedelta(hours=6))
+
+        # Using string
+        fcst_bins = ForecastLeadTimeBins(bin_size='6 hours')
+
+        # All create bins: PT0H_PT6H, PT6H_PT12H, PT12H_PT18H, ...
+
+    Variable bins with auto-generated IDs using mixed types:
+
+    .. code-block:: python
+
+        fcst_bins = ForecastLeadTimeBins(
+            bin_size=[
+                {'start_inclusive': '0 hours',
+                 'end_exclusive': '6 hours'},
+                {'start_inclusive': pd.Timedelta(hours=6),
+                 'end_exclusive': '1 day'},
+                {'start_inclusive': timedelta(days=1),
+                 'end_exclusive': '3 days'},
+            ]
+        )
+        # Creates bins: PT0H_PT6H, PT6H_P1D, P1D_P3D
+
+    Variable bins with custom IDs using strings:
+
+    .. code-block:: python
+
+        fcst_bins = ForecastLeadTimeBins(
+            bin_size={
+                'nowcast': {'start_inclusive': '0 hours',
+                            'end_exclusive': '6 hours'},
+                'short_term': {'start_inclusive': '6 hours',
+                               'end_exclusive': '1 day'},
+                'medium_term': {'start_inclusive': '1 day',
+                                'end_exclusive': '5 days'},
+            }
+        )
+        # Creates bins: nowcast, short_term, medium_term
+    """
+
+    value_time_field_name: str = Field(
+        default="value_time"
+    )
+    reference_time_field_name: str = Field(
+        default="reference_time"
+    )
+    lead_time_field_name: str = Field(
+        default="forecast_lead_time"
+    )
+    output_field_name: str = Field(
+        default="forecast_lead_time_bin"
+    )
+    bin_size: Union[pd.Timedelta, timedelta, str, list, dict] = Field(
+        default=pd.Timedelta(days=5)
+    )
+
+    @staticmethod
+    def _validate_bin_size_dict(self) -> Union[pd.Timedelta, list, dict]:
+        """Validate and normalize bin_size input.
+
+        Validates and converts bin_size to a standardized format:
+        - Single pd.Timedelta: returns as-is
+        - List of dicts: validates structure and converts to internal format
+        - Dict of dicts: validates structure and keeps custom bin IDs
+
+        Returns a normalized structure for internal processing.
+        """
+        def _to_pd_timedelta(value, field_name, context):
+            """Convert datetime.timedelta or string to pd.Timedelta."""
+            if isinstance(value, pd.Timedelta):
+                return value
+            elif isinstance(value, timedelta):
+                return pd.Timedelta(value)
+            elif isinstance(value, str):
+                try:
+                    temp = pd.Timedelta(value)
+                    if temp < pd.Timedelta(seconds=1) and \
+                            temp != pd.Timedelta(0):
+                        raise ValueError(
+                            "Timedelta must be at least 1 second"
+                        )
+                    return temp
+                except ValueError as e:
+                    raise ValueError(
+                        f"{context} '{field_name}' has invalid timedelta"
+                        f" string: '{value}'. "
+                        f"Error: {e}"
+                    )
+            else:
+                raise TypeError(
+                    f"{context} '{field_name}' must be pd.Timedelta,"
+                    " datetime.timedelta, or a valid timedelta string,"
+                    f" got {type(value)}"
+                )
+
+        # Single Timedelta - convert if needed
+        if isinstance(self.bin_size, (pd.Timedelta, timedelta, str)):
+            return _to_pd_timedelta(self.bin_size, 'bin_size', 'bin_size')
+
+        # List of dicts format
+        if isinstance(self.bin_size, list):
+            if not self.bin_size:
+                raise ValueError("bin_size list cannot be empty")
+
+            # Validate each dict has required keys
+            for i, bin_dict in enumerate(self.bin_size):
+                if not isinstance(bin_dict, dict):
+                    raise TypeError(
+                        f"Item {i} in bin_size list must be a dict"
+                    )
+
+                required_keys = {'start_inclusive', 'end_exclusive'}
+                if not required_keys.issubset(bin_dict.keys()):
+                    raise ValueError(
+                        f"Item {i} missing required keys. "
+                        f"Must have: {required_keys}"
+                    )
+
+                # Validate and convert values to pd.Timedelta
+                start = _to_pd_timedelta(
+                    bin_dict['start_inclusive'],
+                    'start_inclusive',
+                    f"Item {i}"
+                )
+                end = _to_pd_timedelta(
+                    bin_dict['end_exclusive'],
+                    'end_exclusive',
+                    f"Item {i}"
+                )
+
+            # Convert to internal format: list of tuples (start, end, bin_id)
+            # For list format, bin_id is None (will be auto-generated as ISO)
+            normalized = []
+            for bin_dict in self.bin_size:
+                start = _to_pd_timedelta(
+                    bin_dict['start_inclusive'],
+                    'start_inclusive',
+                    'bin_dict'
+                )
+                end = _to_pd_timedelta(
+                    bin_dict['end_exclusive'],
+                    'end_exclusive',
+                    'bin_dict'
+                )
+                normalized.append((start, end, None))
+
+            return normalized
+
+        # Dict of dicts format
+        if isinstance(self.bin_size, dict):
+            if not self.bin_size:
+                raise ValueError("bin_size dict cannot be empty")
+
+            # Validate structure
+            for key, value in self.bin_size.items():
+                if not isinstance(key, str):
+                    raise TypeError(
+                        f"Dict keys must be strings (custom bin IDs), got "
+                        f"{type(key)}"
+                    )
+
+                if not isinstance(value, dict):
+                    raise TypeError(
+                        "Dict values must be dicts with bin specification"
+                    )
+
+                required_keys = {'start_inclusive', 'end_exclusive'}
+                if not required_keys.issubset(value.keys()):
+                    raise ValueError(
+                        f"Bin '{key}' missing required keys. Must have: "
+                        f"{required_keys}"
+                    )
+
+                # Validate and convert to pd.Timedelta
+                _to_pd_timedelta(
+                    value['start_inclusive'],
+                    'start_inclusive',
+                    f"Bin '{key}'"
+                )
+                _to_pd_timedelta(
+                    value['end_exclusive'],
+                    'end_exclusive',
+                    f"Bin '{key}'"
+                )
+
+            # Convert to internal format: list of tuples
+            normalized = []
+            for custom_id, bin_dict in self.bin_size.items():
+                start = _to_pd_timedelta(
+                    bin_dict['start_inclusive'],
+                    'start_inclusive',
+                    f"Bin '{custom_id}'"
+                )
+                end = _to_pd_timedelta(
+                    bin_dict['end_exclusive'],
+                    'end_exclusive',
+                    f"Bin '{custom_id}'"
+                )
+                normalized.append((start, end, custom_id))
+
+            return normalized
+
+        raise TypeError(
+            "bin_size must be pd.Timedelta, datetime.timedelta, "
+            "a valid timedelta string, list of dicts, or dict of dicts"
+        )
+
+    @staticmethod
+    def _add_forecast_lead_time(self, sdf: ps.DataFrame) -> ps.DataFrame:
+        """Calculate forecast lead time if not already present."""
+        if self.lead_time_field_name not in sdf.columns:
+            flt_cf = ForecastLeadTime(
+                value_time_field_name=self.value_time_field_name,
+                reference_time_field_name=self.reference_time_field_name,
+                output_field_name=self.lead_time_field_name
+            )
+            sdf = flt_cf.apply_to(sdf)
+        return sdf
+
+    @staticmethod
+    def _add_forecast_lead_time_bin(
+        self,
+        sdf: ps.DataFrame
+    ) -> ps.DataFrame:
+        """Add forecast lead time bin column."""
+
+        def _timedelta_to_iso_duration(td: pd.Timedelta) -> str:
+            """Convert pd.Timedelta to ISO 8601 duration string."""
+            iso_str = td.isoformat()
+            # Remove trailing 0M0S, 0S, etc. for cleaner output
+            iso_str = iso_str.replace(
+                '0M0S', '').replace(
+                '0S', '').replace(
+                '0M', '')
+            # Handle edge case where we removed everything after 'T'
+            if iso_str.endswith('T'):
+                iso_str = iso_str[:-1] + 'T0S'
+            return iso_str
+
+        @pandas_udf(returnType=T.StringType())
+        def func(lead_time: pd.Series) -> pd.Series:
+            # Single Timedelta - uniform binning
+            if isinstance(self.bin_size, pd.Timedelta):
+                bin_size_seconds = self.bin_size.total_seconds()
+
+                bin_numbers = (
+                    lead_time.dt.total_seconds() // bin_size_seconds
+                ).astype(int)
+
+                bin_ids = pd.Series("", index=lead_time.index)
+
+                for bin_num in bin_numbers.unique():
+                    bin_mask = bin_numbers == bin_num
+
+                    if bin_mask.any():
+                        start_td = pd.Timedelta(
+                            seconds=bin_num * bin_size_seconds
+                        )
+                        end_td = pd.Timedelta(
+                            seconds=(bin_num + 1) * bin_size_seconds
+                        )
+
+                        # Convert to ISO duration format
+                        start_iso = _timedelta_to_iso_duration(start_td)
+                        end_iso = _timedelta_to_iso_duration(end_td)
+                        bin_id = f"{start_iso}_{end_iso}"
+
+                        bin_ids[bin_mask] = bin_id
+
+                return bin_ids
+
+            # List/Dict format - dynamic binning with explicit ranges
+            # self.bin_size is now a list of tuples: (start, end, bin_id)
+            # bin_id is None for auto-generated, or a string for custom
+            else:
+                bin_ids = pd.Series("", index=lead_time.index)
+                lead_time_seconds = lead_time.dt.total_seconds()
+
+                # Check if we need to add an overflow bin
+                max_lead_time = lead_time.max()
+                last_bin_end = self.bin_size[-1][1]
+
+                # Create working copy of bin_size
+                bins_to_use = []
+
+                # Convert all bins, generating ISO format for None bin_ids
+                for start_td, end_td, bin_id in self.bin_size:
+                    if bin_id is None:
+                        # Auto-generated: create ISO duration format
+                        start_iso = _timedelta_to_iso_duration(start_td)
+                        end_iso = _timedelta_to_iso_duration(end_td)
+                        final_bin_id = f"{start_iso}_{end_iso}"
+                    else:
+                        # Custom ID: use as-is
+                        final_bin_id = bin_id
+
+                    bins_to_use.append((start_td, end_td, final_bin_id))
+
+                # If max lead time exceeds last bin, create overflow bin
+                if max_lead_time >= last_bin_end:
+                    overflow_start = last_bin_end
+                    overflow_end = max_lead_time
+
+                    # Determine overflow bin_id
+                    if self.bin_size[-1][2] is None:
+                        # Auto-generated format: use ISO duration strings
+                        start_iso = _timedelta_to_iso_duration(overflow_start)
+                        end_iso = _timedelta_to_iso_duration(overflow_end)
+                        overflow_bin_id = f"{start_iso}_{end_iso}"
+                    else:
+                        # Custom ID format: append suffix
+                        overflow_bin_id = "overflow"
+
+                    bins_to_use.append(
+                        (overflow_start, overflow_end, overflow_bin_id)
+                    )
+
+                for i, (start_td, end_td, bin_id) in enumerate(bins_to_use):
+                    start_seconds = start_td.total_seconds()
+                    end_seconds = end_td.total_seconds()
+
+                    # Check if this is the last bin (including overflow bin)
+                    is_last_bin = (i == len(bins_to_use) - 1)
+
+                    if is_last_bin:
+                        # Last bin is inclusive of end_exclusive
+                        mask = lead_time_seconds >= start_seconds
+                    else:
+                        # All other bins are [start, end)
+                        mask = (
+                            (lead_time_seconds >= start_seconds) &
+                            (lead_time_seconds < end_seconds)
+                        )
+
+                    if mask.any():
+                        bin_ids[mask] = bin_id
+
+                return bin_ids
+
+        sdf = sdf.withColumn(
+            self.output_field_name,
+            func(self.lead_time_field_name)
+        )
+        return sdf
+
+    def apply_to(self, sdf: ps.DataFrame) -> ps.DataFrame:
+        """Apply the calculated field to the Spark DataFrame."""
+        self.bin_size = self._validate_bin_size_dict(self)
+        sdf = self._add_forecast_lead_time(self, sdf)
+        sdf = self._add_forecast_lead_time_bin(self, sdf)
+        return sdf
+
+
 class ThresholdValueExceeded(CalculatedFieldABC, CalculatedFieldBaseModel):
     """Adds boolean column indicating if the input value exceeds a threshold.
 
@@ -408,6 +865,7 @@ class HourOfYear(CalculatedFieldABC, CalculatedFieldBaseModel):
         The name of the column to store the month.
         Default: "hour_of_year"
     """
+
     input_field_name: str = Field(
         default="value_time"
    )
@@ -461,6 +919,7 @@ class RowLevelCalculatedFields:
     - NormalizedFlow
     - Seasons
     - ForecastLeadTime
+    - ForecastLeadTimeBins
     - ThresholdValueExceeded
     - DayOfYear
     - HourOfYear
@@ -472,6 +931,7 @@ class RowLevelCalculatedFields:
     NormalizedFlow = NormalizedFlow
     Seasons = Seasons
     ForecastLeadTime = ForecastLeadTime
+    ForecastLeadTimeBins = ForecastLeadTimeBins
     ThresholdValueExceeded = ThresholdValueExceeded
     ThresholdValueNotExceeded = ThresholdValueNotExceeded
     DayOfYear = DayOfYear
diff --git a/src/teehr/models/calculated_fields/timeseries_aware.py b/src/teehr/models/calculated_fields/timeseries_aware.py
index 97c1cf6c..3623742a 100644
--- a/src/teehr/models/calculated_fields/timeseries_aware.py
+++ b/src/teehr/models/calculated_fields/timeseries_aware.py
@@ -66,6 +66,7 @@ class AbovePercentileEventDetection(CalculatedFieldABC, CalculatedFieldBaseModel
             'unit_name'
         ]
     """
+
     quantile: float = Field(
         default=0.85
    )
diff --git a/tests/data/setup_v0_4_ensemble_study.py b/tests/data/setup_v0_4_ensemble_study.py
new file mode 100644
index 00000000..094ea6f5
--- /dev/null
+++ b/tests/data/setup_v0_4_ensemble_study.py
@@ -0,0 +1,120 @@
+from pathlib import Path
+from teehr import Configuration
+from teehr.models.filters import TableFilter
+from teehr.evaluation.evaluation import Evaluation
+from teehr import SignatureTimeseriesGenerators as sts
+from teehr import BenchmarkForecastGenerators as bm
+
+TEST_STUDY_DATA_DIR_v0_4 = Path(Path.cwd(), "tests", "data", "test_study")
+
+
+def setup_v0_4_ensemble_study(tmpdir):
+    """Create a test evaluation with ensemble forecasts using teehr."""
+    usgs_location = Path(
+        TEST_STUDY_DATA_DIR_v0_4, "geo", "USGS_PlatteRiver_location.parquet"
+    )
+
+    secondary_filename = "MEFP.MBRFC.DNVC2LOCAL.SQIN.xml"
+    secondary_filepath = Path(
+        TEST_STUDY_DATA_DIR_v0_4,
+        "timeseries",
+        secondary_filename
+    )
+    primary_filepath = Path(
+        TEST_STUDY_DATA_DIR_v0_4,
+        "timeseries",
+        "usgs_hefs_06711565.parquet"
+    )
+
+    ev = Evaluation(dir_path=tmpdir)
+    ev.enable_logging()
+    ev.clone_template()
+
+    ev.locations.load_spatial(
+        in_path=usgs_location
+    )
+    ev.location_crosswalks.load_csv(
+        in_path=Path(
+            TEST_STUDY_DATA_DIR_v0_4, "geo", "hefs_usgs_crosswalk.csv"
+        )
+    )
+    ev.configurations.add(
+        Configuration(
+            name="MEFP",
+            type="secondary",
+            description="MBRFC HEFS Data"
+        )
+    )
+    ev.configurations.add(
+        Configuration(
+            name="usgs_observations",
+            type="primary",
+            description="USGS observed test data"
+        )
+    )
+    constant_field_values = {
+        "unit_name": "ft^3/s",
+        "variable_name": "streamflow_hourly_inst",
+    }
+    ev.secondary_timeseries.load_fews_xml(
+        in_path=secondary_filepath,
+        constant_field_values=constant_field_values
+    )
+    ev.primary_timeseries.load_parquet(
+        in_path=primary_filepath
+    )
+
+    # Calculate annual hourly normals from USGS observations.
+    input_ts = TableFilter()
+    input_ts.table_name = "primary_timeseries"
+
+    ts_normals = sts.Normals()
+    ts_normals.temporal_resolution = "hour_of_year"  # the default
+    ts_normals.summary_statistic = "mean"  # the default
+
+    ev.generate.signature_timeseries(
+        method=ts_normals,
+        input_table_filter=input_ts,
+        start_datetime="2024-11-19 12:00:00",
+        end_datetime="2024-11-21 13:00:00",
+        timestep="1 hour",
+        fillna=False
+    ).write()
+
+    # Add reference forecast based on climatology.
+    ev.configurations.add(
+        [
+            Configuration(
+                name="benchmark_forecast_hourly_normals",
+                type="secondary",
+                description="Reference forecast based on USGS climatology summarized by hour of year"  # noqa
+            )
+        ]
+    )
+    ref_fcst = bm.ReferenceForecast()
+    ref_fcst.aggregate_reference_timeseries = True
+
+    reference_ts = TableFilter()
+    reference_ts.table_name = "primary_timeseries"
+    reference_ts.filters = [
+        "variable_name = 'streamflow_hour_of_year_mean'",
+        "unit_name = 'ft^3/s'"
+    ]
+
+    template_ts = TableFilter()
+    template_ts.table_name = "secondary_timeseries"
+    template_ts.filters = [
+        "variable_name = 'streamflow_hourly_inst'",
+        "unit_name = 'ft^3/s'",
+        "member = '1993'"
+    ]
+    ev.generate.benchmark_forecast(
+        method=ref_fcst,
+        reference_table_filter=reference_ts,
+        template_table_filter=template_ts,
+        output_configuration_name="benchmark_forecast_hourly_normals"
+    ).write(destination_table="secondary_timeseries")
+
+    ev.joined_timeseries.create(execute_scripts=False)
+
+    return ev
diff --git a/tests/evaluations/test_add_udfs.py b/tests/evaluations/test_add_udfs.py
index 54008794..acb0c6d4 100644
--- a/tests/evaluations/test_add_udfs.py
+++ b/tests/evaluations/test_add_udfs.py
@@ -9,11 +9,13 @@
 import numpy as np
 import baseflow
 import pandas as pd
+from datetime import timedelta
 import sys
 from pathlib import Path
 sys.path.insert(0, str(Path(__file__).parent.parent))
 
 from data.setup_v0_3_study import setup_v0_3_study  # noqa
+from data.setup_v0_4_ensemble_study import setup_v0_4_ensemble_study  # noqa
 
 
 def test_add_row_udfs_null_reference(tmpdir):
@@ -55,6 +57,9 @@ def test_add_row_udfs(tmpdir):
     sdf = rcf.ForecastLeadTime().apply_to(sdf)
     _ = sdf.toPandas()
 
+    sdf = rcf.ForecastLeadTimeBins().apply_to(sdf)
+    _ = sdf.toPandas()
+
     sdf = rcf.ThresholdValueExceeded(
         threshold_field_name="year_2_discharge"
     ).apply_to(sdf)
@@ -125,6 +130,192 @@
     ev.spark.stop()
 
 
+def test_forecast_lead_time_bins(tmpdir):
+    """Test ForecastLeadTimeBins UDF."""
+    ev = setup_v0_4_ensemble_study(tmpdir)
+
+    # test with single bin size
+    fcst_bins_static = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins(
+        bin_size=pd.Timedelta(hours=6)
+    )
+    sdf = ev.joined_timeseries.add_calculated_fields([
+        fcst_bins_static,
+    ]).to_sdf()
+    sorted_sdf = sdf.orderBy(
+        "primary_location_id",
+        "configuration_name",
+        "member",
+        "reference_time",
+        "value_time"
+    )
+    assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 9
+
+    # try with dynamic bin sizes that DO encompass full lead time range
+    bin = [
+        {'start_inclusive': pd.Timedelta(hours=0),
+         'end_exclusive': pd.Timedelta(hours=6)},
+        {'start_inclusive': pd.Timedelta(hours=6),
+         'end_exclusive': pd.Timedelta(hours=12)},
+        {'start_inclusive': pd.Timedelta(hours=12),
+         'end_exclusive': pd.Timedelta(hours=18)},
+        {'start_inclusive': pd.Timedelta(hours=18),
+         'end_exclusive': pd.Timedelta(days=1)},
+        {'start_inclusive': pd.Timedelta(days=1),
+         'end_exclusive': pd.Timedelta(days=1, hours=12)},
+        {'start_inclusive': pd.Timedelta(days=1, hours=12),
+         'end_exclusive': pd.Timedelta(days=2)},
+        {'start_inclusive': pd.Timedelta(days=2),
+         'end_exclusive': pd.Timedelta(days=3)},
+    ]
+    fcst_bins_dynamic = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins(
+        bin_size=bin,
+    )
+    sdf = ev.joined_timeseries.add_calculated_fields([
+        fcst_bins_dynamic,
+    ]).to_sdf()
+    sorted_sdf = sdf.orderBy(
+        "primary_location_id",
+        "configuration_name",
+        "member",
+        "reference_time",
+        "value_time"
+    )
+    assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 7
+
+    # try with dynamic bin sizes that DO NOT encompass full lead time range
+    bin = [
+        {'start_inclusive': pd.Timedelta(hours=0),
+         'end_exclusive': pd.Timedelta(hours=6)},
+        {'start_inclusive': pd.Timedelta(hours=6),
+         'end_exclusive': pd.Timedelta(hours=12)},
+        {'start_inclusive': pd.Timedelta(hours=12),
+         'end_exclusive': pd.Timedelta(hours=18)},
+        {'start_inclusive': pd.Timedelta(hours=18),
+         'end_exclusive': pd.Timedelta(days=1)},
+        {'start_inclusive': pd.Timedelta(days=1),
+         'end_exclusive': pd.Timedelta(days=1, hours=12)},
+    ]
+    fcst_bins_dynamic = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins(
+        bin_size=bin,
+    )
+    sdf = ev.joined_timeseries.add_calculated_fields([
+        fcst_bins_dynamic,
+    ]).to_sdf()
+    sorted_sdf = sdf.orderBy(
+        "primary_location_id",
+        "configuration_name",
+        "member",
+        "reference_time",
+        "value_time"
+    )
+    assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 6
+    assert 'P1DT12H_P2DT0H' in [row['forecast_lead_time_bin'] for row in
+                                sorted_sdf.select(
+                                    'forecast_lead_time_bin'
+                                ).distinct().collect()]
+
+    # try with dynamic bin sizes w/ string dict keys that DO encompass full
+    # lead time range
+    bin = {
+        'bin_1': {'start_inclusive': pd.Timedelta(hours=0),
+                  'end_exclusive': pd.Timedelta(hours=6)},
+        'bin_2': {'start_inclusive': pd.Timedelta(hours=6),
+                  'end_exclusive': pd.Timedelta(hours=12)},
+        'bin_3': {'start_inclusive': pd.Timedelta(hours=12),
+                  'end_exclusive': pd.Timedelta(hours=18)},
+        'bin_4': {'start_inclusive': pd.Timedelta(hours=18),
+                  'end_exclusive': pd.Timedelta(days=1)},
+        'bin_5': {'start_inclusive': pd.Timedelta(days=1),
+                  'end_exclusive': pd.Timedelta(days=1, hours=12)},
+        'bin_6': {'start_inclusive': pd.Timedelta(days=1, hours=12),
+                  'end_exclusive': pd.Timedelta(days=2)},
+        'bin_7': {'start_inclusive': pd.Timedelta(days=2),
+                  'end_exclusive': pd.Timedelta(days=3)},
+    }
+    fcst_bins_dynamic = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins(
+        bin_size=bin
+    )
+    sdf = ev.joined_timeseries.add_calculated_fields([
+        fcst_bins_dynamic,
+    ]).to_sdf()
+    sorted_sdf = sdf.orderBy(
+        "primary_location_id",
+        "configuration_name",
+        "member",
+        "reference_time",
+        "value_time"
+    )
+    assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 7
+
+    # try with dynamic bin sizes w/ string dict keys that DO NOT encompass
+    # full lead time range
+    bin = {
+        'bin_1': {'start_inclusive': pd.Timedelta(hours=0),
+                  'end_exclusive': pd.Timedelta(hours=6)},
+        'bin_2': {'start_inclusive': pd.Timedelta(hours=6),
+                  'end_exclusive': pd.Timedelta(hours=12)},
+        'bin_3': {'start_inclusive': pd.Timedelta(hours=12),
+                  'end_exclusive': pd.Timedelta(hours=18)},
+        'bin_4': {'start_inclusive': pd.Timedelta(hours=18),
+                  'end_exclusive': pd.Timedelta(days=1)},
+        'bin_5': {'start_inclusive': pd.Timedelta(days=1),
+                  'end_exclusive': pd.Timedelta(days=1, hours=12)},
+        'bin_6': {'start_inclusive': pd.Timedelta(days=1, hours=12),
+                  'end_exclusive': pd.Timedelta(days=2)},
+    }
+    fcst_bins_dynamic = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins(
+        bin_size=bin
+    )
+    sdf = ev.joined_timeseries.add_calculated_fields([
+        fcst_bins_dynamic,
+    ]).to_sdf()
+    sorted_sdf = sdf.orderBy(
+        "primary_location_id",
+        "configuration_name",
+        "member",
+        "reference_time",
+        "value_time"
+    )
+    assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 7
+    assert 'overflow' in [row['forecast_lead_time_bin'] for row in
+                          sorted_sdf.select(
+                              'forecast_lead_time_bin'
+                          ).distinct().collect()]
+
+    # try mixed type dynamic bin sizes w/ string dict keys that DO encompass
+    # the full lead time range
+    bin = {
+        'bin_1': {'start_inclusive': '0 hours',
+                  'end_exclusive': '6 hours'},
+        'bin_2': {'start_inclusive': pd.Timedelta('6 hours'),
+                  'end_exclusive': pd.Timedelta(hours=12)},
+        'bin_3': {'start_inclusive': timedelta(hours=12),
+                  'end_exclusive': timedelta(hours=18)},
+        'bin_4': {'start_inclusive': '18 hours',
+                  'end_exclusive': pd.Timedelta('1 days')},
+        'bin_5': {'start_inclusive': '1 days',
+                  'end_exclusive': timedelta(days=1, hours=12)},
+        'bin_6': {'start_inclusive': pd.Timedelta(days=1, hours=12),
+                  'end_exclusive': '2 days'},
+        'bin_7': {'start_inclusive': timedelta(days=2),
+                  'end_exclusive': '3 days'},
+    }
+    fcst_bins_dynamic = teehr.RowLevelCalculatedFields.ForecastLeadTimeBins(
+        bin_size=bin
+    )
+    sdf = ev.joined_timeseries.add_calculated_fields([
+        fcst_bins_dynamic,
+    ]).to_sdf()
+    sorted_sdf = sdf.orderBy(
+        "primary_location_id",
+        "configuration_name",
+        "member",
+        "reference_time",
+        "value_time"
+    )
+    assert sorted_sdf.select('forecast_lead_time_bin').distinct().count() == 7
+
+
 def test_add_timeseries_udfs(tmpdir):
     """Test adding a timeseries aware UDF."""
     # utilize e0_2_location_example from s3 to satisfy baseflow POR reqs
@@ -276,7 +467,7 @@ def test_add_timeseries_udfs(tmpdir):
         skip_event_id=True
     )
     sdf = ped.apply_to(sdf)
-    num_event_timesteps = sdf.filter(sdf.event_above == True).count()
+    num_event_timesteps = sdf.filter(sdf.event_above).count()
     assert num_event_timesteps == 14823
 
     # test percentile event detection (return quantile value)
@@ -340,7 +531,9 @@ def test_location_event_detection(tmpdir):
     ped = tcf.AbovePercentileEventDetection()
 
     sdf = ev.metrics.add_calculated_fields(ped).query(
-        group_by=["configuration_name", "primary_location_id", "event_above_id"],
+        group_by=["configuration_name",
+                  "primary_location_id",
+                  "event_above_id"],
         include_metrics=[
             teehr.Signatures.Maximum(
                 input_field_names=["primary_value"],
@@ -380,6 +573,12 @@ def test_location_event_detection(tmpdir):
             dir=tempdir
         )
    )
+    test_forecast_lead_time_bins(
+        tempfile.mkdtemp(
+            prefix="1b-",
+            dir=tempdir
+        )
+    )
     test_add_timeseries_udfs(
         tempfile.mkdtemp(
             prefix="2-",