From db1b9fd75a2056c7feeadaa08c46135013a35afc Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 30 Jan 2026 09:37:43 -0700 Subject: [PATCH 1/4] Fix show data-table command The CLI command stride data-tables show was not filtering the table for the project's country. --- src/stride/project.py | 22 ++++++++++++++++++---- tests/test_project.py | 19 +++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/stride/project.py b/src/stride/project.py index e927b29..f8525e4 100644 --- a/src/stride/project.py +++ b/src/stride/project.py @@ -651,16 +651,30 @@ def get_energy_projection(self, scenario: str | None = None) -> DuckDBPyRelation ) def show_data_table(self, scenario: str, data_table_id: str, limit: int = 20) -> None: - """Print a limited number of rows of the data table to the console.""" + """Print a limited number of rows of the data table to the console. + + Data is filtered by the project's country if the table has a geography column. + """ table = make_dsgrid_data_table_name(scenario, data_table_id) - self._show_table(table, limit=limit) + self._show_table(table, limit=limit, filter_by_country=True) - def _show_table(self, table: str, limit: int = 20) -> None: - rel = self._con.sql(f"SELECT * FROM {table} LIMIT ?", params=(limit,)) + def _show_table(self, table: str, limit: int = 20, filter_by_country: bool = False) -> None: + if filter_by_country and self._table_has_geography_column(table): + rel = self._con.sql( + f"SELECT * FROM {table} WHERE geography = ? LIMIT ?", + params=(self._config.country, limit), + ) + else: + rel = self._con.sql(f"SELECT * FROM {table} LIMIT ?", params=(limit,)) # DuckDB doesn't seem to provide a way to change the number of rows displayed. # If this is an issue, we could redirect to Pandas and customize the output. print(rel) + def _table_has_geography_column(self, table: str) -> bool: + """Check if a table has a geography column.""" + columns = [x[0] for x in self._con.sql(f"DESCRIBE {table}").fetchall()] + return "geography" in columns + def get_table_overrides(self) -> dict[str, list[str]]: """Return a dictionary of tables being overridden for each scenario.""" overrides: dict[str, list[str]] = defaultdict(list) diff --git a/tests/test_project.py b/tests/test_project.py index 47fc3c7..b2b9626 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -49,6 +49,25 @@ def test_show_data_table(default_project: Project) -> None: assert result.exit_code == 0 +def test_show_data_table_filters_by_country(default_project: Project) -> None: + """Test that data-tables show filters by the project's country. + + The test project uses country_1, so gdp data should only show country_1, + not country_2 (which also exists in the global-test dataset). + """ + project = default_project + runner = CliRunner() + # The test project is configured for country_1 + assert project.config.country == "country_1" + # Show gdp data - should only show country_1 data + result = runner.invoke(cli, ["data-tables", "show", str(project.path), "gdp", "-l", "100"]) + assert result.exit_code == 0 + # country_1 should appear in the output (it's the project's country) + assert "country_1" in result.stdout + # country_2 should NOT appear - the data should be filtered by project country + assert "country_2" not in result.stdout + + def test_show_calculated_table(default_project: Project) -> None: project = default_project runner = CliRunner() From 572fa53b837aa8c2555aa192265c3091a7949717 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 30 Jan 2026 12:31:34 -0700 Subject: [PATCH 2/4] Filter correct years --- src/stride/project.py | 53 ++++++++++++++++++++++++++++++++----------- tests/test_project.py | 46 ++++++++++++++++++++++++++++++------- 2 files changed, 78 insertions(+), 21 deletions(-) diff --git a/src/stride/project.py b/src/stride/project.py index f8525e4..6b73ad7 100644 --- a/src/stride/project.py +++ b/src/stride/project.py @@ -653,27 +653,54 @@ def get_energy_projection(self, scenario: str | None = None) -> DuckDBPyRelation def show_data_table(self, scenario: str, data_table_id: str, limit: int = 20) -> None: """Print a limited number of rows of the data table to the console. - Data is filtered by the project's country if the table has a geography column. + Data is filtered by the project's configuration: + - geography column filtered by project's country + - model_year column filtered by project's model years + - weather_year column filtered by project's weather year """ table = make_dsgrid_data_table_name(scenario, data_table_id) - self._show_table(table, limit=limit, filter_by_country=True) - - def _show_table(self, table: str, limit: int = 20, filter_by_country: bool = False) -> None: - if filter_by_country and self._table_has_geography_column(table): - rel = self._con.sql( - f"SELECT * FROM {table} WHERE geography = ? LIMIT ?", - params=(self._config.country, limit), - ) + self._show_table(table, limit=limit, filter_by_project=True) + + def _show_table(self, table: str, limit: int = 20, filter_by_project: bool = False) -> None: + if filter_by_project: + columns = self._get_table_columns(table) + conditions = [] + params: list[Any] = [] + + if "geography" in columns: + conditions.append("geography = ?") + params.append(self._config.country) + + if "model_year" in columns: + model_years = self._config.list_model_years() + placeholders = ", ".join("?" for _ in model_years) + conditions.append(f"model_year IN ({placeholders})") + params.extend(model_years) + + if "weather_year" in columns: + # weather_year may be stored as string or int depending on the table + conditions.append("(weather_year = ? OR weather_year = ?)") + params.append(str(self._config.weather_year)) + params.append(self._config.weather_year) + + if conditions: + where_clause = " AND ".join(conditions) + params.append(limit) + rel = self._con.sql( + f"SELECT * FROM {table} WHERE {where_clause} LIMIT ?", + params=params, + ) + else: + rel = self._con.sql(f"SELECT * FROM {table} LIMIT ?", params=(limit,)) else: rel = self._con.sql(f"SELECT * FROM {table} LIMIT ?", params=(limit,)) # DuckDB doesn't seem to provide a way to change the number of rows displayed. # If this is an issue, we could redirect to Pandas and customize the output. print(rel) - def _table_has_geography_column(self, table: str) -> bool: - """Check if a table has a geography column.""" - columns = [x[0] for x in self._con.sql(f"DESCRIBE {table}").fetchall()] - return "geography" in columns + def _get_table_columns(self, table: str) -> list[str]: + """Get the list of column names for a table.""" + return [x[0] for x in self._con.sql(f"DESCRIBE {table}").fetchall()] def get_table_overrides(self) -> dict[str, list[str]]: """Return a dictionary of tables being overridden for each scenario.""" diff --git a/tests/test_project.py b/tests/test_project.py index b2b9626..c8bf446 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -49,23 +49,53 @@ def test_show_data_table(default_project: Project) -> None: assert result.exit_code == 0 -def test_show_data_table_filters_by_country(default_project: Project) -> None: - """Test that data-tables show filters by the project's country. +def test_show_data_table_filters_by_project_config(default_project: Project) -> None: + """Test that data-tables show filters by the project's configuration. - The test project uses country_1, so gdp data should only show country_1, - not country_2 (which also exists in the global-test dataset). + The test project uses country_1, model years 2025-2050 (step 5), and weather_year 2018. + Data should be filtered to only show matching records. """ project = default_project runner = CliRunner() - # The test project is configured for country_1 + + # Verify project configuration assert project.config.country == "country_1" - # Show gdp data - should only show country_1 data + assert project.config.start_year == 2025 + assert project.config.end_year == 2050 + assert project.config.step_year == 5 + assert project.config.weather_year == 2018 + model_years = project.config.list_model_years() + assert model_years == [2025, 2030, 2035, 2040, 2045, 2050] + + # Test GDP table - should filter by country and model_year result = runner.invoke(cli, ["data-tables", "show", str(project.path), "gdp", "-l", "100"]) assert result.exit_code == 0 - # country_1 should appear in the output (it's the project's country) + # country_1 should appear (project's country) + assert "country_1" in result.stdout + # country_2 should NOT appear (filtered out) + assert "country_2" not in result.stdout + # Only project's model years should appear + for year in model_years: + assert str(year) in result.stdout + # Years outside the range should not appear (e.g., 1990, 2000, 2010) + assert "1990" not in result.stdout + assert "2000" not in result.stdout + assert "2010" not in result.stdout + + # Test weather_bait table - should filter by country and weather_year + result = runner.invoke( + cli, ["data-tables", "show", str(project.path), "weather_bait", "-l", "100"] + ) + assert result.exit_code == 0 + # country_1 should appear (project's country) assert "country_1" in result.stdout - # country_2 should NOT appear - the data should be filtered by project country + # country_2 should NOT appear (filtered out) assert "country_2" not in result.stdout + # Only project's weather_year (2018) should appear + assert "2018" in result.stdout + # Other weather years should not appear + assert "1980" not in result.stdout + assert "2020" not in result.stdout def test_show_calculated_table(default_project: Project) -> None: From b0b57dcb8d4fc031982a699f7ca1217512d216c1 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 30 Jan 2026 16:00:34 -0700 Subject: [PATCH 3/4] Fix handling of years --- src/stride/dsgrid_integration.py | 29 +++++++++++++++-------------- src/stride/project.py | 4 +--- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/stride/dsgrid_integration.py b/src/stride/dsgrid_integration.py index df91bc0..10fa926 100644 --- a/src/stride/dsgrid_integration.py +++ b/src/stride/dsgrid_integration.py @@ -4,15 +4,14 @@ from pathlib import Path from typing import Any -from stride.models import Scenario - +import duckdb +import pyarrow as pa from chronify.exceptions import InvalidParameter from dsgrid.dimension.base_models import DatasetDimensionRequirements from dsgrid.config.mapping_tables import MappingTableModel from dsgrid.config.registration_models import DimensionType from dsgrid.query.models import DimensionReferenceModel, make_dataset_query from dsgrid.utils.files import load_json_file -import duckdb from dsgrid.query.query_submitter import ( DatasetQuerySubmitter, ) @@ -21,6 +20,8 @@ from dsgrid.registry.registry_manager import RegistryManager from loguru import logger +from stride.models import Scenario + def deploy_to_dsgrid_registry( registry_path: Path, @@ -332,17 +333,17 @@ def _query_and_create_table( ) # Use Arrow transfer instead of toPandas() for better performance arrow_table = df.relation.arrow() # noqa: F841 - # Convert model_year from string to integer if needed - if "model_year" in arrow_table.schema.names: - import pyarrow as pa # type: ignore[import-untyped] - - field_index = arrow_table.schema.get_field_index("model_year") - if pa.types.is_string(arrow_table.schema.field(field_index).type): - arrow_table = arrow_table.set_column( - field_index, - "model_year", - arrow_table.column("model_year").cast(pa.int64()), - ) + + # Convert year columns from string to integer if needed + for year_col in ("model_year", "weather_year"): + if year_col in arrow_table.schema.names: + field_index = arrow_table.schema.get_field_index(year_col) + if pa.types.is_string(arrow_table.schema.field(field_index).type): + arrow_table = arrow_table.set_column( + field_index, + year_col, + arrow_table.column(year_col).cast(pa.int64()), + ) con.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM arrow_table") logger.info("Created table {} from mapped dataset.", table_name) diff --git a/src/stride/project.py b/src/stride/project.py index 6b73ad7..acfac21 100644 --- a/src/stride/project.py +++ b/src/stride/project.py @@ -678,9 +678,7 @@ def _show_table(self, table: str, limit: int = 20, filter_by_project: bool = Fal params.extend(model_years) if "weather_year" in columns: - # weather_year may be stored as string or int depending on the table - conditions.append("(weather_year = ? OR weather_year = ?)") - params.append(str(self._config.weather_year)) + conditions.append("weather_year = ?") params.append(self._config.weather_year) if conditions: From 44de2006913f9868fa5eeb5bdc4aa5426b8ed08e Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Fri, 30 Jan 2026 16:33:49 -0700 Subject: [PATCH 4/4] Fix mypy --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 0d921a0..fd7a4d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,6 +83,7 @@ module = [ "plotly.*", "dash.*", "dash_bootstrap_components.*", + "pyarrow", ] [[tool.mypy.overrides]]