diff --git a/src/pystatsv1/trackd/__init__.py b/src/pystatsv1/trackd/__init__.py
index b0fe20e..4bb3226 100644
--- a/src/pystatsv1/trackd/__init__.py
+++ b/src/pystatsv1/trackd/__init__.py
@@ -9,6 +9,7 @@
 from ._errors import TrackDDataError, TrackDSchemaError  # noqa: F401
 from ._types import DataFrame, DataFrames, PathLike  # noqa: F401
+from .csvio import read_csv_required  # noqa: F401
 
 __all__ = [
     "DataFrame",
     "DataFrames",
@@ -16,4 +17,5 @@
     "PathLike",
     "TrackDDataError",
     "TrackDSchemaError",
+    "read_csv_required",
 ]
diff --git a/src/pystatsv1/trackd/csvio.py b/src/pystatsv1/trackd/csvio.py
new file mode 100644
index 0000000..97717ff
--- /dev/null
+++ b/src/pystatsv1/trackd/csvio.py
@@ -0,0 +1,87 @@
+"""Track D CSV I/O helpers.
+
+This module centralizes the most common, student-facing CSV loading patterns
+used across Track D.
+
+Goals
+- Friendly, consistent error messages.
+- Small surface area (easy to reuse in chapter runners and BYOD adapters).
+- Avoid leaking low-level pandas exceptions to beginners.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any, Mapping, Sequence
+
+import pandas as pd
+
+from ._errors import TrackDDataError, TrackDSchemaError
+from ._types import DataFrame, PathLike
+
+
+def read_csv_required(
+    path: PathLike,
+    *,
+    required_cols: Sequence[str] | None = None,
+    parse_dates: Sequence[str] | None = None,
+    dtypes: Mapping[str, Any] | None = None,
+    **kwargs: Any,
+) -> DataFrame:
+    """Read a CSV file and enforce required columns.
+
+    Parameters
+    ----------
+    path:
+        Path to the CSV file.
+    required_cols:
+        Column names that must be present in the CSV header row.
+    parse_dates:
+        Column names to parse as dates (passed to pandas).
+    dtypes:
+        Optional dtype mapping (passed to pandas).
+
+    Raises
+    ------
+    TrackDDataError:
+        If the file is missing or can't be read.
+    TrackDSchemaError:
+        If required columns are missing.
+
+    Returns
+    -------
+    pandas.DataFrame
+    """
+    p = Path(path)
+    if not p.exists():
+        raise TrackDDataError(
+            f"Missing CSV file: {p}.\n"
+            "Hint: check your export location and filename, then try again."
+        )
+
+    read_kwargs: dict[str, Any] = dict(kwargs)
+    if parse_dates is not None:
+        read_kwargs["parse_dates"] = list(parse_dates)
+    if dtypes is not None:
+        read_kwargs["dtype"] = dtypes
+
+    try:
+        df = pd.read_csv(p, **read_kwargs)
+    except Exception as e:  # pragma: no cover
+        raise TrackDDataError(
+            f"Could not read CSV: {p}.\n"
+            f"Reason: {type(e).__name__}: {e}"
+        ) from e
+
+    if required_cols:
+        missing = [c for c in required_cols if c not in df.columns]
+        if missing:
+            found = ", ".join(map(str, df.columns)) if len(df.columns) else "(no columns)"
+            req = ", ".join(missing)
+            raise TrackDSchemaError(
+                f"Missing required columns in {p.name}: {req}.\n"
+                f"Found columns: {found}.\n"
+                "Hint: ensure the first row is the header, and re-export as CSV if needed."
+            )
+
+    return df
diff --git a/tests/test_trackd_csvio_errors_are_friendly.py b/tests/test_trackd_csvio_errors_are_friendly.py
new file mode 100644
index 0000000..4b4e28b
--- /dev/null
+++ b/tests/test_trackd_csvio_errors_are_friendly.py
@@ -0,0 +1,42 @@
+from __future__ import annotations
+
+from pathlib import Path
+
+import pytest
+
+from pystatsv1.trackd import TrackDDataError, TrackDSchemaError
+from pystatsv1.trackd.csvio import read_csv_required
+
+
+def test_read_csv_required_missing_file_is_friendly(tmp_path: Path) -> None:
+    missing = tmp_path / "nope.csv"
+    with pytest.raises(TrackDDataError) as excinfo:
+        read_csv_required(missing, required_cols=["a"])
+
+    msg = str(excinfo.value)
+    assert "Missing CSV file" in msg
+    assert "nope.csv" in msg
+    assert "Hint:" in msg
+
+
+def test_read_csv_required_missing_columns_is_friendly(tmp_path: Path) -> None:
+    p = tmp_path / "data.csv"
+    p.write_text("a,b\n1,2\n", encoding="utf-8")
+
+    with pytest.raises(TrackDSchemaError) as excinfo:
+        read_csv_required(p, required_cols=["a", "c"])
+
+    msg = str(excinfo.value)
+    assert "Missing required columns" in msg
+    assert "c" in msg
+    assert "Found columns" in msg
+    assert "Hint:" in msg
+
+
+def test_read_csv_required_success(tmp_path: Path) -> None:
+    p = tmp_path / "ok.csv"
+    p.write_text("a,b\n1,2\n", encoding="utf-8")
+
+    df = read_csv_required(p, required_cols=["a", "b"])
+    assert list(df.columns) == ["a", "b"]
+    assert df.shape == (1, 2)