Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/pystatsv1/trackd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

from ._errors import TrackDDataError, TrackDSchemaError # noqa: F401
from ._types import DataFrame, DataFrames, PathLike # noqa: F401
from .csvio import read_csv_required # noqa: F401

# Names re-exported as the trackd package's public API (PEP 8 ``__all__``
# convention); kept in sync with the imports above.
__all__ = [
    "DataFrame",
    "DataFrames",
    "PathLike",
    "TrackDDataError",
    "TrackDSchemaError",
    "read_csv_required",
]
87 changes: 87 additions & 0 deletions src/pystatsv1/trackd/csvio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Track D CSV I/O helpers.

This module centralizes the most common, student-facing CSV loading patterns
used across Track D.

Goals
- Friendly, consistent error messages.
- Small surface area (easy to reuse in chapter runners and BYOD adapters).
- Avoid leaking low-level pandas exceptions to beginners.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any, Mapping, Sequence

import pandas as pd

from ._errors import TrackDDataError, TrackDSchemaError
from ._types import DataFrame, PathLike


def _as_column_list(value: Any) -> list[str] | None:
    """Normalize a column specification to ``None`` or a list of names.

    Students often pass a bare string (``required_cols="id"``) where a
    sequence is expected; iterating that string would silently check each
    *character*, so treat a string as a single column name instead.
    """
    if value is None:
        return None
    if isinstance(value, str):
        return [value]
    return list(value)


def read_csv_required(
    path: PathLike,
    *,
    required_cols: Sequence[str] | str | None = None,
    parse_dates: Sequence[str] | str | None = None,
    dtypes: Mapping[str, Any] | None = None,
    **kwargs: Any,
) -> DataFrame:
    """Read a CSV file and enforce required columns.

    Parameters
    ----------
    path:
        Path to the CSV file.
    required_cols:
        Column name(s) that must be present in the CSV header row.
        A bare string is treated as a single column name.
    parse_dates:
        Column name(s) to parse as dates (passed to pandas).
        A bare string is treated as a single column name.
    dtypes:
        Optional dtype mapping (passed to pandas ``dtype=``).
    **kwargs:
        Any further keyword arguments are forwarded to ``pandas.read_csv``.

    Raises
    ------
    TrackDDataError:
        If the file is missing or can't be read.
    TrackDSchemaError:
        If required columns are missing.

    Returns
    -------
    pandas.DataFrame
    """
    p = Path(path)
    if not p.exists():
        raise TrackDDataError(
            f"Missing CSV file: {p}.\n"
            "Hint: check your export location and filename, then try again."
        )

    read_kwargs: dict[str, Any] = dict(kwargs)
    dates = _as_column_list(parse_dates)
    if dates is not None:
        read_kwargs["parse_dates"] = dates
    if dtypes is not None:
        read_kwargs["dtype"] = dtypes

    # Deliberately broad catch: beginners should see one consistent error
    # type instead of pandas' internal exception zoo (ParserError,
    # EmptyDataError, UnicodeDecodeError, ...). The cause is chained for
    # instructors debugging the underlying issue.
    try:
        df = pd.read_csv(p, **read_kwargs)
    except Exception as e:  # pragma: no cover
        raise TrackDDataError(
            f"Could not read CSV: {p}.\n"
            f"Reason: {type(e).__name__}: {e}"
        ) from e

    required = _as_column_list(required_cols)
    if required:
        missing = [c for c in required if c not in df.columns]
        if missing:
            # ``or`` falls back when the frame has no columns at all
            # (join over an empty sequence yields "", which is falsy).
            found = ", ".join(map(str, df.columns)) or "(no columns)"
            req = ", ".join(missing)
            raise TrackDSchemaError(
                f"Missing required columns in {p.name}: {req}.\n"
                f"Found columns: {found}.\n"
                "Hint: ensure the first row is the header, and re-export as CSV if needed."
            )

    return df
42 changes: 42 additions & 0 deletions tests/test_trackd_csvio_errors_are_friendly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

from pathlib import Path

import pytest

from pystatsv1.trackd import TrackDDataError, TrackDSchemaError
from pystatsv1.trackd.csvio import read_csv_required


def test_read_csv_required_missing_file_is_friendly(tmp_path: Path) -> None:
    """A nonexistent path raises TrackDDataError with a beginner-friendly hint."""
    absent = tmp_path / "nope.csv"

    with pytest.raises(TrackDDataError) as err:
        read_csv_required(absent, required_cols=["a"])

    text = str(err.value)
    for fragment in ("Missing CSV file", "nope.csv", "Hint:"):
        assert fragment in text


def test_read_csv_required_missing_columns_is_friendly(tmp_path: Path) -> None:
    """Absent required columns raise TrackDSchemaError naming what was expected and found."""
    csv_path = tmp_path / "data.csv"
    csv_path.write_text("a,b\n1,2\n", encoding="utf-8")

    with pytest.raises(TrackDSchemaError) as err:
        read_csv_required(csv_path, required_cols=["a", "c"])

    text = str(err.value)
    for fragment in ("Missing required columns", "c", "Found columns", "Hint:"):
        assert fragment in text


def test_read_csv_required_success(tmp_path: Path) -> None:
    """A well-formed CSV containing every required column loads cleanly."""
    csv_path = tmp_path / "ok.csv"
    csv_path.write_text("a,b\n1,2\n", encoding="utf-8")

    frame = read_csv_required(csv_path, required_cols=["a", "b"])

    assert list(frame.columns) == ["a", "b"]
    assert frame.shape == (1, 2)