From 184bd1ad9f4e9ea8315622f614d9236934dfc025 Mon Sep 17 00:00:00 2001
From: Nicholas Karlson
Date: Tue, 20 Jan 2026 11:49:13 -0800
Subject: [PATCH] Track D: add dataset loaders (friendly errors)

---
 src/pystatsv1/trackd/loaders.py               | 183 ++++++++++++++++++
 ...test_trackd_loaders_errors_are_friendly.py |  53 +++++
 2 files changed, 236 insertions(+)
 create mode 100644 src/pystatsv1/trackd/loaders.py
 create mode 100644 tests/test_trackd_loaders_errors_are_friendly.py

diff --git a/src/pystatsv1/trackd/loaders.py b/src/pystatsv1/trackd/loaders.py
new file mode 100644
index 0000000..a0714ed
--- /dev/null
+++ b/src/pystatsv1/trackd/loaders.py
@@ -0,0 +1,183 @@
+"""Track D dataset loaders.
+
+These helpers centralize the repetitive "find datadir + read CSV + friendly errors"
+logic used by Track D chapter runner scripts and (later) BYOD adapters.
+
+This module is intentionally small and stable.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Mapping, Sequence
+
+from ._errors import TrackDDataError
+from ._types import DataFrame, DataFrames, PathLike
+from .csvio import read_csv_required
+
+
+def resolve_datadir(datadir: PathLike | None) -> Path:
+    """Resolve and validate a Track D data directory.
+
+    Parameters
+    ----------
+    datadir:
+        A path to the directory containing Track D input CSV tables.
+
+    Returns
+    -------
+    pathlib.Path
+        The validated data directory path.
+
+    Raises
+    ------
+    TrackDDataError
+        If the directory is missing or not a folder.
+    """
+    if datadir is None:
+        raise TrackDDataError(
+            "Data directory is required.\n"
+            "Hint: pass --datadir to the chapter runner, or set DATADIR in the "
+            "workbook Makefile."
+        )
+
+    p = Path(datadir).expanduser()
+
+    if not p.exists():
+        raise TrackDDataError(
+            f"Data directory not found: {p}.\n"
+            "Hint: confirm the path exists, then try again."
+        )
+    if not p.is_dir():
+        raise TrackDDataError(
+            f"Data directory is not a folder: {p}.\n"
+            "Hint: pass a folder path containing your exported CSV tables."
+        )
+    return p
+
+
+def load_table(
+    datadir: PathLike | None,
+    filename: str,
+    *,
+    required_cols: Sequence[str] | None = None,
+    parse_dates: Sequence[str] | None = None,
+    dtypes: Mapping[str, Any] | None = None,
+    **kwargs: Any,
+) -> DataFrame:
+    """Load a single CSV table from a Track D data directory.
+
+    This is a thin wrapper around :func:`pystatsv1.trackd.csvio.read_csv_required`
+    that resolves the data directory first.
+    """
+    root = resolve_datadir(datadir)
+    path = root / filename
+    return read_csv_required(
+        path,
+        required_cols=required_cols,
+        parse_dates=parse_dates,
+        dtypes=dtypes,
+        **kwargs,
+    )
+
+
+def load_tables(
+    datadir: PathLike | None,
+    spec: Mapping[str, Sequence[str] | Mapping[str, Any]],
+) -> DataFrames:
+    """Load multiple tables using a small spec mapping.
+
+    Parameters
+    ----------
+    datadir:
+        Folder containing the CSV tables.
+    spec:
+        Mapping from *key* to either:
+
+        - a sequence of required column names (filename defaults to the key), or
+        - a dict with optional fields:
+            - filename: override CSV filename (defaults to key)
+            - required_cols: list/tuple/set of required columns
+            - parse_dates: list/tuple/set of date columns to parse
+            - dtypes: dict of dtypes to pass to pandas
+            - kwargs: dict of additional pandas.read_csv kwargs
+
+    Returns
+    -------
+    dict[str, pandas.DataFrame]
+        Loaded tables keyed by the spec keys.
+
+    Raises
+    ------
+    TrackDDataError
+        If the data directory is invalid or a spec entry is malformed.
+    """
+    out: DataFrames = {}
+    for key, cfg in spec.items():
+        if isinstance(cfg, (list, tuple, set)):
+            out[key] = load_table(datadir, key, required_cols=list(cfg))
+            continue
+
+        # Anything else must be dict-like; fail with a friendly error rather
+        # than an AttributeError from ``cfg.get`` (e.g. a bare string value).
+        if not isinstance(cfg, Mapping):
+            raise TrackDDataError(
+                f"Invalid load_tables spec for {key}: expected a list/tuple/set of "
+                "required columns or a dict of options, got "
+                f"{type(cfg).__name__}."
+            )
+
+        filename = str(cfg.get("filename", key))
+
+        required_cols_raw = cfg.get("required_cols")
+        if required_cols_raw is None:
+            required_cols = None
+        elif isinstance(required_cols_raw, (list, tuple, set)):
+            required_cols = list(required_cols_raw)
+        else:
+            raise TrackDDataError(
+                f"Invalid load_tables spec for {key}: 'required_cols' must be a "
+                "list/tuple/set."
+            )
+
+        parse_dates_raw = cfg.get("parse_dates")
+        if parse_dates_raw is None:
+            parse_dates = None
+        elif isinstance(parse_dates_raw, (list, tuple, set)):
+            parse_dates = list(parse_dates_raw)
+        else:
+            raise TrackDDataError(
+                f"Invalid load_tables spec for {key}: 'parse_dates' must be a "
+                "list/tuple/set."
+            )
+
+        dtypes_raw = cfg.get("dtypes")
+        if dtypes_raw is None:
+            dtypes = None
+        elif isinstance(dtypes_raw, dict):
+            dtypes = dtypes_raw
+        else:
+            raise TrackDDataError(
+                f"Invalid load_tables spec for {key}: 'dtypes' must be a dict."
+            )
+
+        # Treat an explicit ``kwargs: None`` like the other optional fields.
+        extra_kwargs_raw = cfg.get("kwargs")
+        if extra_kwargs_raw is None:
+            extra_kwargs = {}
+        elif isinstance(extra_kwargs_raw, dict):
+            extra_kwargs = extra_kwargs_raw
+        else:
+            raise TrackDDataError(
+                f"Invalid load_tables spec for {key}: 'kwargs' must be a dict."
+            )
+
+        out[key] = load_table(
+            datadir,
+            filename,
+            required_cols=required_cols,
+            parse_dates=parse_dates,
+            dtypes=dtypes,
+            **extra_kwargs,
+        )
+    return out
diff --git a/tests/test_trackd_loaders_errors_are_friendly.py b/tests/test_trackd_loaders_errors_are_friendly.py
new file mode 100644
index 0000000..dc8d9db
--- /dev/null
+++ b/tests/test_trackd_loaders_errors_are_friendly.py
@@ -0,0 +1,53 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import pytest
+
+from pystatsv1.trackd import TrackDDataError, TrackDSchemaError
+from pystatsv1.trackd.loaders import load_table, load_tables, resolve_datadir
+
+
+def test_resolve_datadir_missing_is_friendly(tmp_path: Path) -> None:
+    missing = tmp_path / "nope"
+    with pytest.raises(TrackDDataError) as excinfo:
+        resolve_datadir(missing)
+
+    msg = str(excinfo.value)
+    assert "Data directory not found" in msg
+    assert "Hint:" in msg
+
+
+def test_load_table_missing_csv_is_friendly(tmp_path: Path) -> None:
+    # datadir exists, but the CSV does not
+    with pytest.raises(TrackDDataError) as excinfo:
+        load_table(tmp_path, "missing.csv", required_cols=["a"])
+
+    msg = str(excinfo.value)
+    assert "Missing CSV file" in msg
+    assert "missing.csv" in msg
+    assert "Hint:" in msg
+
+
+def test_load_table_missing_required_columns_is_friendly(tmp_path: Path) -> None:
+    p = tmp_path / "data.csv"
+    p.write_text("a,b\n1,2\n", encoding="utf-8")
+
+    with pytest.raises(TrackDSchemaError) as excinfo:
+        load_table(tmp_path, "data.csv", required_cols=["a", "c"])
+
+    msg = str(excinfo.value)
+    assert "Missing required columns" in msg
+    assert "c" in msg
+    assert "Found columns" in msg
+    assert "Hint:" in msg
+
+
+def test_load_tables_invalid_spec_entry_is_friendly(tmp_path: Path) -> None:
+    # a bare string is neither a column list nor an options dict
+    with pytest.raises(TrackDDataError) as excinfo:
+        load_tables(tmp_path, {"data.csv": "a"})
+
+    msg = str(excinfo.value)
+    assert "Invalid load_tables spec" in msg
+    assert "data.csv" in msg