diff --git a/pyproject.toml b/pyproject.toml
index 0d921a0..fd7a4d9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -83,6 +83,7 @@ module = [
     "plotly.*",
     "dash.*",
     "dash_bootstrap_components.*",
+    "pyarrow",
 ]
 
 [[tool.mypy.overrides]]
diff --git a/src/stride/dsgrid_integration.py b/src/stride/dsgrid_integration.py
index df91bc0..10fa926 100644
--- a/src/stride/dsgrid_integration.py
+++ b/src/stride/dsgrid_integration.py
@@ -4,15 +4,14 @@
 from pathlib import Path
 from typing import Any
 
-from stride.models import Scenario
-
+import duckdb
+import pyarrow as pa
 from chronify.exceptions import InvalidParameter
 from dsgrid.dimension.base_models import DatasetDimensionRequirements
 from dsgrid.config.mapping_tables import MappingTableModel
 from dsgrid.config.registration_models import DimensionType
 from dsgrid.query.models import DimensionReferenceModel, make_dataset_query
 from dsgrid.utils.files import load_json_file
-import duckdb
 from dsgrid.query.query_submitter import (
     DatasetQuerySubmitter,
 )
@@ -21,6 +20,8 @@ from dsgrid.registry.registry_manager import RegistryManager
 
 from loguru import logger
 
+from stride.models import Scenario
+
 
 def deploy_to_dsgrid_registry(
     registry_path: Path,
@@ -332,17 +333,17 @@ def _query_and_create_table(
     )
     # Use Arrow transfer instead of toPandas() for better performance
     arrow_table = df.relation.arrow()  # noqa: F841
-    # Convert model_year from string to integer if needed
-    if "model_year" in arrow_table.schema.names:
-        import pyarrow as pa  # type: ignore[import-untyped]
-
-        field_index = arrow_table.schema.get_field_index("model_year")
-        if pa.types.is_string(arrow_table.schema.field(field_index).type):
-            arrow_table = arrow_table.set_column(
-                field_index,
-                "model_year",
-                arrow_table.column("model_year").cast(pa.int64()),
-            )
+
+    # Convert year columns from string to integer if needed
+    for year_col in ("model_year", "weather_year"):
+        if year_col in arrow_table.schema.names:
+            field_index = arrow_table.schema.get_field_index(year_col)
+            if pa.types.is_string(arrow_table.schema.field(field_index).type):
+                arrow_table = arrow_table.set_column(
+                    field_index,
+                    year_col,
+                    arrow_table.column(year_col).cast(pa.int64()),
+                )
     con.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM arrow_table")
 
     logger.info("Created table {} from mapped dataset.", table_name)
diff --git a/src/stride/project.py b/src/stride/project.py
index e927b29..acfac21 100644
--- a/src/stride/project.py
+++ b/src/stride/project.py
@@ -651,16 +651,53 @@ def get_energy_projection(self, scenario: str | None = None) -> DuckDBPyRelation
         )
 
     def show_data_table(self, scenario: str, data_table_id: str, limit: int = 20) -> None:
-        """Print a limited number of rows of the data table to the console."""
-        table = make_dsgrid_data_table_name(scenario, data_table_id)
-        self._show_table(table, limit=limit)
+        """Print a limited number of rows of the data table to the console.
 
-    def _show_table(self, table: str, limit: int = 20) -> None:
-        rel = self._con.sql(f"SELECT * FROM {table} LIMIT ?", params=(limit,))
+        Data is filtered by the project's configuration:
+        - geography column filtered by project's country
+        - model_year column filtered by project's model years
+        - weather_year column filtered by project's weather year
+        """
+        table = make_dsgrid_data_table_name(scenario, data_table_id)
+        self._show_table(table, limit=limit, filter_by_project=True)
+
+    def _show_table(self, table: str, limit: int = 20, filter_by_project: bool = False) -> None:
+        """Print up to ``limit`` rows of ``table``, optionally filtered by project config."""
+        conditions: list[str] = []
+        params: list[Any] = []
+        if filter_by_project:
+            columns = self._get_table_columns(table)
+
+            if "geography" in columns:
+                conditions.append("geography = ?")
+                params.append(self._config.country)
+
+            if "model_year" in columns:
+                model_years = self._config.list_model_years()
+                if model_years:
+                    placeholders = ", ".join("?" for _ in model_years)
+                    conditions.append(f"model_year IN ({placeholders})")
+                    params.extend(model_years)
+                else:
+                    # An empty IN () list is not valid SQL; no model years means no rows.
+                    conditions.append("FALSE")
+
+            if "weather_year" in columns:
+                conditions.append("weather_year = ?")
+                params.append(self._config.weather_year)
+
+        # Build the query once so the filtered and unfiltered paths cannot drift apart.
+        where_clause = f" WHERE {' AND '.join(conditions)}" if conditions else ""
+        params.append(limit)
+        rel = self._con.sql(f"SELECT * FROM {table}{where_clause} LIMIT ?", params=params)
         # DuckDB doesn't seem to provide a way to change the number of rows displayed.
         # If this is an issue, we could redirect to Pandas and customize the output.
         print(rel)
 
+    def _get_table_columns(self, table: str) -> list[str]:
+        """Get the list of column names for a table."""
+        return [x[0] for x in self._con.sql(f"DESCRIBE {table}").fetchall()]
+
     def get_table_overrides(self) -> dict[str, list[str]]:
         """Return a dictionary of tables being overridden for each scenario."""
         overrides: dict[str, list[str]] = defaultdict(list)
diff --git a/tests/test_project.py b/tests/test_project.py
index 47fc3c7..c8bf446 100644
--- a/tests/test_project.py
+++ b/tests/test_project.py
@@ -49,6 +49,55 @@ def test_show_data_table(default_project: Project) -> None:
     assert result.exit_code == 0
 
 
+def test_show_data_table_filters_by_project_config(default_project: Project) -> None:
+    """Test that data-tables show filters by the project's configuration.
+
+    The test project uses country_1, model years 2025-2050 (step 5), and weather_year 2018.
+    Data should be filtered to only show matching records.
+    """
+    project = default_project
+    runner = CliRunner()
+
+    # Verify project configuration
+    assert project.config.country == "country_1"
+    assert project.config.start_year == 2025
+    assert project.config.end_year == 2050
+    assert project.config.step_year == 5
+    assert project.config.weather_year == 2018
+    model_years = project.config.list_model_years()
+    assert model_years == [2025, 2030, 2035, 2040, 2045, 2050]
+
+    # Test GDP table - should filter by country and model_year
+    result = runner.invoke(cli, ["data-tables", "show", str(project.path), "gdp", "-l", "100"])
+    assert result.exit_code == 0
+    # country_1 should appear (project's country)
+    assert "country_1" in result.stdout
+    # country_2 should NOT appear (filtered out)
+    assert "country_2" not in result.stdout
+    # Only project's model years should appear
+    for year in model_years:
+        assert str(year) in result.stdout
+    # Years outside the range should not appear (e.g., 1990, 2000, 2010)
+    assert "1990" not in result.stdout
+    assert "2000" not in result.stdout
+    assert "2010" not in result.stdout
+
+    # Test weather_bait table - should filter by country and weather_year
+    result = runner.invoke(
+        cli, ["data-tables", "show", str(project.path), "weather_bait", "-l", "100"]
+    )
+    assert result.exit_code == 0
+    # country_1 should appear (project's country)
+    assert "country_1" in result.stdout
+    # country_2 should NOT appear (filtered out)
+    assert "country_2" not in result.stdout
+    # Only project's weather_year (2018) should appear
+    assert "2018" in result.stdout
+    # Other weather years should not appear
+    assert "1980" not in result.stdout
+    assert "2020" not in result.stdout
+
+
 def test_show_calculated_table(default_project: Project) -> None:
     project = default_project
     runner = CliRunner()