Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ module = [
"plotly.*",
"dash.*",
"dash_bootstrap_components.*",
"pyarrow",
]

[[tool.mypy.overrides]]
Expand Down
29 changes: 15 additions & 14 deletions src/stride/dsgrid_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
from pathlib import Path
from typing import Any

from stride.models import Scenario

import duckdb
import pyarrow as pa
from chronify.exceptions import InvalidParameter
from dsgrid.dimension.base_models import DatasetDimensionRequirements
from dsgrid.config.mapping_tables import MappingTableModel
from dsgrid.config.registration_models import DimensionType
from dsgrid.query.models import DimensionReferenceModel, make_dataset_query
from dsgrid.utils.files import load_json_file
import duckdb
from dsgrid.query.query_submitter import (
DatasetQuerySubmitter,
)
Expand All @@ -21,6 +20,8 @@
from dsgrid.registry.registry_manager import RegistryManager
from loguru import logger

from stride.models import Scenario


def deploy_to_dsgrid_registry(
registry_path: Path,
Expand Down Expand Up @@ -332,17 +333,17 @@ def _query_and_create_table(
)
# Use Arrow transfer instead of toPandas() for better performance
arrow_table = df.relation.arrow() # noqa: F841
# Convert model_year from string to integer if needed
if "model_year" in arrow_table.schema.names:
import pyarrow as pa # type: ignore[import-untyped]

field_index = arrow_table.schema.get_field_index("model_year")
if pa.types.is_string(arrow_table.schema.field(field_index).type):
arrow_table = arrow_table.set_column(
field_index,
"model_year",
arrow_table.column("model_year").cast(pa.int64()),
)

# Convert year columns from string to integer if needed
for year_col in ("model_year", "weather_year"):
if year_col in arrow_table.schema.names:
field_index = arrow_table.schema.get_field_index(year_col)
if pa.types.is_string(arrow_table.schema.field(field_index).type):
arrow_table = arrow_table.set_column(
field_index,
year_col,
arrow_table.column(year_col).cast(pa.int64()),
)
con.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM arrow_table")
logger.info("Created table {} from mapped dataset.", table_name)

Expand Down
49 changes: 44 additions & 5 deletions src/stride/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,16 +651,55 @@ def get_energy_projection(self, scenario: str | None = None) -> DuckDBPyRelation
)

def show_data_table(self, scenario: str, data_table_id: str, limit: int = 20) -> None:
"""Print a limited number of rows of the data table to the console."""
table = make_dsgrid_data_table_name(scenario, data_table_id)
self._show_table(table, limit=limit)
"""Print a limited number of rows of the data table to the console.

def _show_table(self, table: str, limit: int = 20) -> None:
rel = self._con.sql(f"SELECT * FROM {table} LIMIT ?", params=(limit,))
Data is filtered by the project's configuration:
- geography column filtered by project's country
- model_year column filtered by project's model years
- weather_year column filtered by project's weather year
"""
table = make_dsgrid_data_table_name(scenario, data_table_id)
self._show_table(table, limit=limit, filter_by_project=True)

def _show_table(self, table: str, limit: int = 20, filter_by_project: bool = False) -> None:
if filter_by_project:
columns = self._get_table_columns(table)
conditions = []
params: list[Any] = []

if "geography" in columns:
conditions.append("geography = ?")
params.append(self._config.country)

if "model_year" in columns:
model_years = self._config.list_model_years()
placeholders = ", ".join("?" for _ in model_years)
conditions.append(f"model_year IN ({placeholders})")
params.extend(model_years)

if "weather_year" in columns:
conditions.append("weather_year = ?")
params.append(self._config.weather_year)

if conditions:
where_clause = " AND ".join(conditions)
params.append(limit)
rel = self._con.sql(
f"SELECT * FROM {table} WHERE {where_clause} LIMIT ?",
params=params,
)
else:
rel = self._con.sql(f"SELECT * FROM {table} LIMIT ?", params=(limit,))
else:
rel = self._con.sql(f"SELECT * FROM {table} LIMIT ?", params=(limit,))
# DuckDB doesn't seem to provide a way to change the number of rows displayed.
# If this is an issue, we could redirect to Pandas and customize the output.
print(rel)

def _get_table_columns(self, table: str) -> list[str]:
"""Get the list of column names for a table."""
return [x[0] for x in self._con.sql(f"DESCRIBE {table}").fetchall()]

def get_table_overrides(self) -> dict[str, list[str]]:
"""Return a dictionary of tables being overridden for each scenario."""
overrides: dict[str, list[str]] = defaultdict(list)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,55 @@ def test_show_data_table(default_project: Project) -> None:
assert result.exit_code == 0


def test_show_data_table_filters_by_project_config(default_project: Project) -> None:
"""Test that data-tables show filters by the project's configuration.

The test project uses country_1, model years 2025-2050 (step 5), and weather_year 2018.
Data should be filtered to only show matching records.
"""
project = default_project
runner = CliRunner()

# Verify project configuration
assert project.config.country == "country_1"
assert project.config.start_year == 2025
assert project.config.end_year == 2050
assert project.config.step_year == 5
assert project.config.weather_year == 2018
model_years = project.config.list_model_years()
assert model_years == [2025, 2030, 2035, 2040, 2045, 2050]

# Test GDP table - should filter by country and model_year
result = runner.invoke(cli, ["data-tables", "show", str(project.path), "gdp", "-l", "100"])
assert result.exit_code == 0
# country_1 should appear (project's country)
assert "country_1" in result.stdout
# country_2 should NOT appear (filtered out)
assert "country_2" not in result.stdout
# Only project's model years should appear
for year in model_years:
assert str(year) in result.stdout
# Years outside the range should not appear (e.g., 1990, 2000, 2010)
assert "1990" not in result.stdout
assert "2000" not in result.stdout
assert "2010" not in result.stdout

# Test weather_bait table - should filter by country and weather_year
result = runner.invoke(
cli, ["data-tables", "show", str(project.path), "weather_bait", "-l", "100"]
)
assert result.exit_code == 0
# country_1 should appear (project's country)
assert "country_1" in result.stdout
# country_2 should NOT appear (filtered out)
assert "country_2" not in result.stdout
# Only project's weather_year (2018) should appear
assert "2018" in result.stdout
# Other weather years should not appear
assert "1980" not in result.stdout
assert "2020" not in result.stdout


def test_show_calculated_table(default_project: Project) -> None:
project = default_project
runner = CliRunner()
Expand Down