2 changes: 1 addition & 1 deletion pyproject.toml
@@ -29,7 +29,7 @@ h5py = ">=3.12.1,<4"
pyarrow = ">=15.0.0,<21"
httpx = ">=0.25.1,<1"
pandas = ">=2.2.0,<3"
-pyspark = {extras = ["pandas-on-spark"], version = ">=4,<5"}
+pyspark = {extras = ["pandas-on-spark"], version = "4.0.0"}
dataretrieval = ">=1.0.9,<2"
numba = ">=0.60.0,<1"
arch = ">=7.0.0,<8"
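The pyspark range constraint is tightened to an exact pin. A minimal sketch, assuming the environment was resolved from this pyproject.toml (e.g. via poetry install), of verifying the installed version at runtime; the assertion message is illustrative:

```python
# Minimal check that the resolved environment honors the exact pyspark pin above.
from importlib.metadata import version

assert version("pyspark") == "4.0.0", f"expected pyspark 4.0.0, got {version('pyspark')}"
```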
3 changes: 1 addition & 2 deletions src/teehr/evaluation/evaluation.py
@@ -193,8 +193,7 @@ def fetch(self) -> Fetch:
@property
def metrics(self) -> Metrics:
"""The metrics component class for calculating performance metrics."""
-cls = Metrics(self)
-return cls()
+return Metrics(self)

@property
def units(self) -> UnitTable:
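With this change, `ev.metrics` returns the `Metrics` component itself instead of the result of calling it, so callers invoke the component explicitly. A minimal sketch of the new call pattern, based on the docstring example in `metrics.py`; it assumes an existing, populated evaluation:

```python
import teehr

ev = teehr.Evaluation()  # assumes an evaluation that already contains joined data

# Before this PR: ev.metrics returned an already-initialized instance,
# so callers wrote ev.metrics.query(...).
# After this PR: ev.metrics returns the component and the caller invokes it,
# optionally selecting a table as in the metrics.py docstring:
metrics = ev.metrics()                                    # default table, as in the tests below
# metrics = ev.metrics(table_name="primary_timeseries")   # explicit table, as in the docstring

# .query(...), .add_calculated_fields(...), and .to_pandas() then chain off the
# returned instance exactly as in the updated tests.
```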
2 changes: 2 additions & 0 deletions src/teehr/evaluation/metrics.py
@@ -65,6 +65,8 @@ def __call__(
>>> ev = teehr.Evaluation()
>>> metrics = ev.metrics(table_name="primary_timeseries")
"""
logger.info(f"Initializing Metrics for table: {table_name}.{namespace_name or ''}{'.' if namespace_name else ''}{catalog_name or ''}")

self.table_name = table_name
self.table = self._ev.table(
table_name=table_name,
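The new log line interpolates the table, namespace, and catalog into one dotted string. A small sketch, with hypothetical values, of what the message resolves to; note the trailing dot when no namespace is supplied:

```python
# Reproduces the f-string used in the logger.info call above, with hypothetical values.
def _qualified(table_name, namespace_name=None, catalog_name=None):
    return (
        f"{table_name}.{namespace_name or ''}"
        f"{'.' if namespace_name else ''}{catalog_name or ''}"
    )

print(_qualified("primary_timeseries"))                   # primary_timeseries.
print(_qualified("primary_timeseries", "demo", "local"))  # primary_timeseries.demo.local
```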
1 change: 1 addition & 0 deletions src/teehr/evaluation/tables/base_table.py
@@ -53,6 +53,7 @@ def __call__(
catalog_name: Union[str, None] = None
) -> "Table":
"""Initialize the Table class."""
logger.info(f"Initializing Table for table: {table_name}.{namespace_name or ''}{'.' if namespace_name else ''}{catalog_name or ''}")
self.table_name = table_name
self.sdf = None
tbl_props = TBLPROPERTIES.get(table_name)
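The `Table` component gains the same initialization log. A hedged sketch of the call path that triggers it, mirroring the `self._ev.table(...)` call in `metrics.py` above; the namespace and catalog values are hypothetical:

```python
import teehr

ev = teehr.Evaluation()  # assumes an existing evaluation

# Selecting a table emits the new "Initializing Table for table: ..." log line.
tbl = ev.table(table_name="primary_timeseries")

# With the optional qualifiers (hypothetical names), the logged string would read
# "primary_timeseries.demo.local":
# tbl = ev.table(
#     table_name="primary_timeseries",
#     namespace_name="demo",
#     catalog_name="local",
# )
```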
26 changes: 13 additions & 13 deletions tests/query/test_metrics_query.py
@@ -51,7 +51,7 @@ def test_executing_deterministic_metrics(tmpdir):
# Get the currently available fields to use in the query.
flds = ev.joined_timeseries.field_enum()

-metrics_df = ev.metrics.query(
+metrics_df = ev.metrics().query(
include_metrics=include_nonconditional_metrics,
group_by=[flds.primary_location_id],
order_by=[flds.primary_location_id],
@@ -75,7 +75,7 @@ def test_executing_deterministic_metrics(tmpdir):
if callable(func) and func().attrs.get('requires_threshold_field', True) # noqa
]

-metrics_df = ev.metrics.add_calculated_fields([
+metrics_df = ev.metrics().add_calculated_fields([
tcf.AbovePercentileEventDetection(
skip_event_id=True,
add_quantile_field=True,
@@ -105,7 +105,7 @@ def test_executing_signatures(tmpdir):
# Get the currently available fields to use in the query.
flds = ev.joined_timeseries.field_enum()

-metrics_df = ev.metrics.query(
+metrics_df = ev.metrics().query(
include_metrics=include_all_metrics,
group_by=[flds.primary_location_id],
order_by=[flds.primary_location_id],
@@ -142,7 +142,7 @@ def test_metrics_filter_and_geometry(tmpdir):
)
]

-metrics_df = ev.metrics.query(
+metrics_df = ev.metrics().query(
include_metrics=include_metrics,
group_by=[flds.primary_location_id],
order_by=[flds.primary_location_id],
@@ -161,7 +161,7 @@ def test_metric_chaining(tmpdir):
ev = setup_v0_3_study(tmpdir)

# Test chaining.
-metrics_df = ev.metrics.query(
+metrics_df = ev.metrics().query(
order_by=["primary_location_id", "month"],
group_by=["primary_location_id", "month"],
include_metrics=[
@@ -302,7 +302,7 @@ def test_ensemble_metrics(tmpdir):
crps.reference_configuration = "benchmark_forecast_hourly_normals"

include_metrics = [crps]
-metrics_df = ev.metrics.query(
+metrics_df = ev.metrics().query(
include_metrics=include_metrics,
group_by=[
"primary_location_id",
@@ -341,21 +341,21 @@ def test_metrics_transforms(tmpdir):
mvtd_t.transform = 'log'

# get metrics_df
-metrics_df_tansformed_e = test_eval.metrics.query(
+metrics_df_tansformed_e = test_eval.metrics().query(
group_by=["primary_location_id", "configuration_name"],
include_metrics=[
kge_t_e,
mvtd_t
]
).to_pandas()
-metrics_df_transformed = test_eval.metrics.query(
+metrics_df_transformed = test_eval.metrics().query(
group_by=["primary_location_id", "configuration_name"],
include_metrics=[
kge_t,
mvtd_t
]
).to_pandas()
-metrics_df = test_eval.metrics.query(
+metrics_df = test_eval.metrics().query(
group_by=["primary_location_id", "configuration_name"],
include_metrics=[
kge,
@@ -396,7 +396,7 @@ def test_metrics_transforms(tmpdir):
)

# get metrics df control and assert divide by zero occurs
-metrics_df_e_control = test_eval.metrics.query(
+metrics_df_e_control = test_eval.metrics().query(
group_by=["primary_location_id", "configuration_name"],
include_metrics=[
r2,
@@ -407,7 +407,7 @@ def test_metrics_transforms(tmpdir):
assert np.isnan(metrics_df_e_control.pearson_correlation.values).all()

# get metrics df test and ensure no divide by zero occurs
-metrics_df_e_test = test_eval.metrics.query(
+metrics_df_e_test = test_eval.metrics().query(
group_by=["primary_location_id", "configuration_name"],
include_metrics=[
r2_e,
@@ -436,7 +436,7 @@ def test_metrics_transforms(tmpdir):
)

# get metrics df control and assert divide by zero occurs
-metrics_df_e_control = test_eval.metrics.query(
+metrics_df_e_control = test_eval.metrics().query(
group_by=["primary_location_id", "configuration_name"],
include_metrics=[
r2,
@@ -447,7 +447,7 @@ def test_metrics_transforms(tmpdir):
assert np.isnan(metrics_df_e_control.pearson_correlation.values).all()

# get metrics df test and ensure no divide by zero occurs
-metrics_df_e_test = test_eval.metrics.query(
+metrics_df_e_test = test_eval.metrics().query(
group_by=["primary_location_id", "configuration_name"],
include_metrics=[
r2_e,