From 2216f85651dfd48829896ffaa4356922d571a24d Mon Sep 17 00:00:00 2001 From: Nicholas Karlson Date: Mon, 19 Jan 2026 12:18:42 -0800 Subject: [PATCH] Refactor: move Track D reporting style + mpl compat into pystatsv1.trackd --- ...business_ch09_reporting_style_contract.rst | 6 +- scripts/_mpl_compat.py | 33 +- scripts/_reporting_style.py | 506 +----------------- src/pystatsv1/assets/workbook_track_d.zip | Bin 166929 -> 162838 bytes src/pystatsv1/trackd/mpl_compat.py | 34 ++ src/pystatsv1/trackd/reporting_style.py | 506 ++++++++++++++++++ tests/test_trackd_reporting_style_smoke.py | 28 + .../track_d_template/scripts/_mpl_compat.py | 34 +- .../scripts/_reporting_style.py | 506 +----------------- 9 files changed, 594 insertions(+), 1059 deletions(-) create mode 100644 src/pystatsv1/trackd/mpl_compat.py create mode 100644 src/pystatsv1/trackd/reporting_style.py create mode 100644 tests/test_trackd_reporting_style_smoke.py diff --git a/docs/source/business_ch09_reporting_style_contract.rst b/docs/source/business_ch09_reporting_style_contract.rst index 1580d26..a91f57f 100644 --- a/docs/source/business_ch09_reporting_style_contract.rst +++ b/docs/source/business_ch09_reporting_style_contract.rst @@ -30,7 +30,7 @@ This chapter implements two things: 1) A reusable style/guardrails module: -- ``scripts/_reporting_style.py`` +- ``pystatsv1.trackd.reporting_style`` (source: ``src/pystatsv1/trackd/reporting_style.py``) 2) A Chapter 9 driver that produces a small, compliant “chart pack” + manifest: @@ -39,7 +39,9 @@ This chapter implements two things: The style contract (rules) -------------------------- -The style contract lives in ``scripts/_reporting_style.py`` as ``STYLE_CONTRACT``. +The style contract lives in ``pystatsv1.trackd.reporting_style`` as ``STYLE_CONTRACT``. + +Note: ``scripts/_reporting_style.py`` is kept as a small shim for backward compatibility with older Track D scripts and workbook templates. It is intentionally conservative so later chapters can reuse it. Allowed chart types diff --git a/scripts/_mpl_compat.py b/scripts/_mpl_compat.py index 48687ff..f0542e0 100644 --- a/scripts/_mpl_compat.py +++ b/scripts/_mpl_compat.py @@ -1,34 +1,9 @@ -"""Matplotlib compatibility helpers for workbook scripts. +"""Backward-compatible shim for Track D matplotlib helpers. -Matplotlib 3.9 renamed the Axes.boxplot keyword argument "labels" to -"tick_labels". The old name is deprecated and scheduled for removal. - -These helpers keep our educational scripts working on Matplotlib 3.8+ -while avoiding deprecation warnings on newer versions. +Historically, Track D chapter scripts imported :mod:`scripts._mpl_compat`. +The implementation now lives in :mod:`pystatsv1.trackd.mpl_compat`. """ from __future__ import annotations -from typing import Any, Sequence - - -def ax_boxplot( - ax: Any, - *args: Any, - tick_labels: Sequence[str] | None = None, - **kwargs: Any, -): - """Call ``ax.boxplot`` with a 3.8/3.9+ compatible keyword. - - Prefer ``tick_labels`` (Matplotlib >= 3.9). If that keyword is not - supported (Matplotlib <= 3.8), fall back to the legacy ``labels``. - """ - - if tick_labels is None: - return ax.boxplot(*args, **kwargs) - - try: - return ax.boxplot(*args, tick_labels=tick_labels, **kwargs) - except TypeError: - # Older Matplotlib: the new keyword doesn't exist. - return ax.boxplot(*args, labels=tick_labels, **kwargs) +from pystatsv1.trackd.mpl_compat import * # noqa: F401,F403 diff --git a/scripts/_reporting_style.py b/scripts/_reporting_style.py index 2338076..a229f35 100644 --- a/scripts/_reporting_style.py +++ b/scripts/_reporting_style.py @@ -1,506 +1,12 @@ -# SPDX-License-Identifier: MIT -"""Shared plotting/reporting helpers. +"""Backward-compatible shim for Track D reporting-style helpers. -Track D Chapter 9 introduces a *style contract* for figures and small reports. -This module centralizes the rules so later chapters can reuse them. +Historically, Track D chapter scripts imported :mod:`scripts._reporting_style`. +The implementation now lives in :mod:`pystatsv1.trackd.reporting_style`. -Design goals ------------- -- Matplotlib-only (no seaborn) -- Deterministic output filenames and metadata -- Guardrails against misleading axes (especially for bar charts) -- Simple defaults suitable for ReadTheDocs screenshots and printing - -The "style contract" is intentionally conservative; it favors clarity over -flash. Downstream chapters can extend it, but should keep the core rules. +Keeping this shim prevents template drift and avoids breaking older chapter +scripts that import from ``scripts/``. """ from __future__ import annotations -import json -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Any, Iterable -from contextlib import contextmanager -import matplotlib as mpl -import numpy as np - -# Matplotlib is an optional dependency for some repo users. -# Track D chapters require it, so we import lazily in functions where possible. - -STYLE_CONTRACT: dict[str, Any] = { - "version": "1.0", - "allowed_chart_types": [ - "line", - "bar", - "histogram", - "ecdf", - "box", - "scatter", - "waterfall_bridge", - ], - "labeling_rules": { - "title_required": True, - "axis_labels_required": True, - "units_in_labels": True, - "use_month_tick_labels": "YYYY-MM", - "legend_only_if_multiple_series": True, - "caption_required_in_manifest": True, - }, - "anti_misleading_axes": { - "bar_charts_start_at_zero": True, - "explicit_note_if_y_truncated": True, - "show_zero_line_for_ratios": True, - "avoid_dual_axes": True, - }, - "distribution_guidance": { - "for_skewed_distributions": [ - "histogram + vertical lines for mean and median", - "ECDF (or quantile plot) to reveal tails", - "report key quantiles (p50, p75, p90, p95 if available)", - ] - }, - "file_format": {"type": "png", "dpi": 150}, - "figure_sizes": { - "time_series": [10.0, 4.0], - "distribution": [7.5, 4.5], - }, -} - - -# Minimal matplotlib rcParams for a consistent, non-misleading reporting look. -# NOTE: We intentionally avoid specifying colors so matplotlib defaults apply. -_REPORTING_RC: dict[str, object] = { - "figure.dpi": 120, - "savefig.dpi": 150, - "savefig.bbox": "tight", - "axes.grid": True, - "axes.titleweight": "bold", - "axes.titlesize": 12, - "axes.labelsize": 10, - "xtick.labelsize": 9, - "ytick.labelsize": 9, - "legend.fontsize": 9, -} - - -@contextmanager -def style_context(): - """Context manager to apply the reporting style contract to matplotlib figures.""" - with mpl.rc_context(rc=_REPORTING_RC): - yield - - - -@dataclass(frozen=True) -class FigureSpec: - """Minimal spec used when saving figures (validation + metadata).""" - - chart_type: str - title: str - caption: str = "" - x_label: str = "" - y_label: str = "" - data_source: str = "" - notes: str = "" - - -@dataclass(frozen=True) -class FigureManifestRow: - """One row in the Chapter 9 figure manifest CSV.""" - - filename: str - chart_type: str - title: str - x_label: str - y_label: str - guardrail_note: str - data_source: str - - - -def write_style_contract_json(outpath: Path) -> None: - """Write the global style contract to a JSON file.""" - - outpath.write_text(json.dumps(STYLE_CONTRACT, indent=2), encoding="utf-8") - - -def write_contract_json(outpath: Path) -> None: - """Write the global style contract to a JSON file.""" - outpath.write_text(json.dumps(STYLE_CONTRACT, indent=2), encoding="utf-8") - - - -def _mpl(): - """Import matplotlib with a non-interactive backend.""" - - import matplotlib - - # Ensure headless operation for CI / tests. - matplotlib.use("Agg", force=True) - - import matplotlib.pyplot as plt - - return matplotlib, plt - - -def mpl_context(): - """Context manager that applies a lightweight, consistent style.""" - - matplotlib, plt = _mpl() - - # A minimal rcParams set: keep things readable without over-styling. - rc = { - "figure.dpi": int(STYLE_CONTRACT["file_format"]["dpi"]), - "savefig.dpi": int(STYLE_CONTRACT["file_format"]["dpi"]), - "font.size": 10, - "axes.titlesize": 12, - "axes.labelsize": 10, - "legend.fontsize": 9, - "xtick.labelsize": 9, - "ytick.labelsize": 9, - "axes.grid": True, - "grid.alpha": 0.25, - "axes.spines.top": False, - "axes.spines.right": False, - } - - return matplotlib.rc_context(rc) - - -def save_figure(fig, outpath: Path, spec: FigureSpec | None = None) -> None: - """Save and close a Matplotlib figure deterministically. - - If spec is provided, enforce allowed chart types. - """ - if spec is not None: - ensure_allowed_chart_type(spec.chart_type) - - outpath.parent.mkdir(parents=True, exist_ok=True) - fig.tight_layout() - fig.savefig(outpath, bbox_inches="tight") - - # Avoid memory leaks in test runs. - _, plt = _mpl() - plt.close(fig) - - - -def _format_month_ticks(ax, months: list[str]) -> None: - """Format x-axis ticks for YYYY-MM month labels.""" - - # Show at most ~8 ticks; for longer series, reduce tick density. - n = len(months) - if n <= 8: - step = 1 - elif n <= 18: - step = 2 - else: - step = 3 - - ticks = list(range(0, n, step)) - ax.set_xticks(ticks) - ax.set_xticklabels([months[i] for i in ticks], rotation=45, ha="right") - - -def _enforce_bar_zero_baseline(ax) -> None: - """Enforce y-axis baseline at zero for bar charts.""" - - y0, y1 = ax.get_ylim() - if y0 > 0: - ax.set_ylim(0.0, y1) - elif y1 < 0: - ax.set_ylim(y0, 0.0) - - -def plot_time_series( - df, - x: str, - series: dict[str, str], - title: str, - x_label: str, - y_label: str, - figsize: tuple[float, float] | None = None, - show_zero_line: bool = False, -): - """Create a standard time-series line chart. - - Parameters - ---------- - df: - Dataframe with columns including x and all series columns. - x: - Column name for x-axis (typically month). - series: - Mapping of legend label -> column name. - show_zero_line: - If True, draw a horizontal line at y=0 (useful for ratios/growth). - """ - - _, plt = _mpl() - - if figsize is None: - w, h = STYLE_CONTRACT["figure_sizes"]["time_series"] - figsize = (float(w), float(h)) - - fig, ax = plt.subplots(figsize=figsize) - - months = [str(m) for m in df[x].tolist()] - x_idx = np.arange(len(months)) - - for label, col in series.items(): - ax.plot(x_idx, df[col].astype(float).to_numpy(), marker="o", linewidth=1.5, label=label) - - if show_zero_line: - ax.axhline(0.0, linewidth=1.0) - - ax.set_title(title) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - - _format_month_ticks(ax, months) - - if len(series) > 1: - ax.legend(loc="best") - - return fig - - -def plot_bar( - df, - x: str, - y: str, - title: str, - x_label: str, - y_label: str, - figsize: tuple[float, float] | None = None, -): - """Create a standard bar chart with a zero baseline.""" - - _, plt = _mpl() - - if figsize is None: - w, h = STYLE_CONTRACT["figure_sizes"]["time_series"] - figsize = (float(w), float(h)) - - fig, ax = plt.subplots(figsize=figsize) - - months = [str(m) for m in df[x].tolist()] - x_idx = np.arange(len(months)) - - ax.bar(x_idx, df[y].astype(float).to_numpy()) - - ax.set_title(title) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - - _format_month_ticks(ax, months) - _enforce_bar_zero_baseline(ax) - - return fig - - -def _ecdf(values: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - v = np.asarray(values, dtype=float) - v = v[~np.isnan(v)] - if v.size == 0: - return np.array([]), np.array([]) - v = np.sort(v) - y = np.arange(1, v.size + 1, dtype=float) / float(v.size) - return v, y - - -def plot_histogram_with_markers( - values: Iterable[float], - title: str, - x_label: str, - y_label: str, - markers: dict[str, float] | None = None, - figsize: tuple[float, float] | None = None, -): - """Histogram with optional vertical markers (e.g., mean/median).""" - - _, plt = _mpl() - - if figsize is None: - w, h = STYLE_CONTRACT["figure_sizes"]["distribution"] - figsize = (float(w), float(h)) - - v = np.asarray(list(values), dtype=float) - v = v[~np.isnan(v)] - - fig, ax = plt.subplots(figsize=figsize) - - if v.size > 0: - ax.hist(v, bins="auto") - - if markers: - for label, x0 in markers.items(): - if np.isfinite(x0): - ax.axvline(float(x0), linestyle="--", linewidth=1.2, label=label) - - ax.set_title(title) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - - if markers and len(markers) > 0: - ax.legend(loc="best") - - return fig - - -def plot_ecdf( - values: Iterable[float], - title: str, - x_label: str, - y_label: str, - markers: dict[str, float] | None = None, - figsize: tuple[float, float] | None = None, -): - """ECDF plot with optional vertical markers.""" - - _, plt = _mpl() - - if figsize is None: - w, h = STYLE_CONTRACT["figure_sizes"]["distribution"] - figsize = (float(w), float(h)) - - v = np.asarray(list(values), dtype=float) - x, y = _ecdf(v) - - fig, ax = plt.subplots(figsize=figsize) - - if x.size > 0: - ax.plot(x, y, marker=".", linestyle="none") - - if markers: - for label, x0 in markers.items(): - if np.isfinite(x0): - ax.axvline(float(x0), linestyle="--", linewidth=1.2, label=label) - - ax.set_title(title) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - ax.set_ylim(0.0, 1.0) - - if markers and len(markers) > 0: - ax.legend(loc="best") - - return fig - - -def plot_waterfall_bridge( - start_label: str, - end_label: str, - start_value: float, - end_value: float, - components: list[tuple[str, float]], - title: str, - y_label: str, - x_label: str = "Component", - figsize: tuple[float, float] | None = None, -): - """Create a variance waterfall / bridge chart (start -> end via additive components). - - Guardrails - --------- - - Deterministic structure: explicit start and end totals plus additive components. - - Printer-safe encoding: hatch patterns distinguish positive vs negative deltas. - - Zero line included; y-limits padded to reduce truncation temptation. - - Notes - ----- - The caller is responsible for choosing defensible components. Any residual - can be included as an "Other / rounding" component to reconcile exactly. - """ - - _, plt = _mpl() - - if figsize is None: - w, h = STYLE_CONTRACT["figure_sizes"]["time_series"] - figsize = (float(w), float(h)) - - labels = [start_label] + [name for name, _ in components] + [end_label] - - # Running totals after each component (for connectors and y-range). - running = float(start_value) - totals = [running] - for _, delta in components: - running += float(delta) - totals.append(running) - totals.append(float(end_value)) - - fig, ax = plt.subplots(figsize=figsize) - - # Start total - ax.bar(0, float(start_value), edgecolor="black", linewidth=0.8) - - # Component deltas - running = float(start_value) - for i, (_, delta) in enumerate(components, start=1): - d = float(delta) - new_total = running + d - - if d >= 0: - bottom = running - height = d - hatch = "//" - else: - bottom = new_total - height = -d - hatch = "\\" - - ax.bar(i, height, bottom=bottom, hatch=hatch, edgecolor="black", linewidth=0.8) - running = new_total - - # End total - ax.bar(len(labels) - 1, float(end_value), edgecolor="black", linewidth=0.8) - - # Connectors between bars (running totals) - running = float(start_value) - for i, (_, delta) in enumerate(components, start=1): - ax.plot([i - 0.4, i + 0.4], [running, running], linewidth=1.0) - running += float(delta) - - ax.set_title(title) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - ax.set_xticks(list(range(len(labels)))) - ax.set_xticklabels(labels, rotation=0) - ax.axhline(0.0, linewidth=1.0) - - def _fmt(v: float) -> str: - return f"{v:,.0f}" - - # Annotate start/end totals - ax.text(0, float(start_value), _fmt(float(start_value)), ha="center", va="bottom") - ax.text(len(labels) - 1, float(end_value), _fmt(float(end_value)), ha="center", va="bottom") - - # Annotate component deltas - running = float(start_value) - for i, (_, delta) in enumerate(components, start=1): - d = float(delta) - y = (running + d) if d >= 0 else running - ax.text(i, y, f"{d:+,.0f}", ha="center", va="bottom") - running += d - - # Pad y-limits (anti-truncation guardrail) - lo = min([0.0] + totals) - hi = max([0.0] + totals) - span = hi - lo - pad = 0.10 * span if span > 0 else 1.0 - ax.set_ylim(lo - pad, hi + pad) - - return fig - - -def figure_manifest_to_frame(specs: list[FigureSpec]): - import pandas as pd - - return pd.DataFrame([asdict(s) for s in specs]) - - -def ensure_allowed_chart_type(chart_type: str) -> None: - allowed = set(STYLE_CONTRACT["allowed_chart_types"]) - if chart_type not in allowed: - raise ValueError(f"chart_type must be one of {sorted(allowed)}; got {chart_type!r}") +from pystatsv1.trackd.reporting_style import * # noqa: F401,F403 diff --git a/src/pystatsv1/assets/workbook_track_d.zip b/src/pystatsv1/assets/workbook_track_d.zip index fc1a5ea3c49d159956c460e29ae814145ed13c6e..1580ad5a5b4a7650df4f04f8cc4e90c6b13e4eb9 100644 GIT binary patch delta 2544 zcmY+F2~bm46o%isNMw_SB>}@?h#&}pkRqZ51}sRGDnSuz#}d*I1ObuFVKoMg$SNVq zb-|5^j!OmiswfIVHwBAQw~DA0tW}T-4vT1e^PcftGLy;7`TqZ$bN}3zsCF1t4I7@U zU#hP%tv_wbin2G_Zg@d#O%lW{QoY8}R-P~N55MNg?i_GiED4;)j%e&KbKr6vUGEH5 zuH3aWY2V4Z zbNgZhR@1u=9rdxuwyyd8<%K*mSISL-TaEFz5$|ueamf6?W&vHP#s;itt*P{?$T!Kjm z5QXVYup|h=0KFYV_T})AHn#idZ9~BIxxh}YtndoI7HFazddG>fu3!g8dS(W5jE&7X zi8+g^fBLOEuevj8WN0{kdV#WAbNh7H(B5N5bKdl34PK~E%w8CsV8f2xZG^@FUtfM%g`2tIXFuo?7w3R}!h0I-xHp zxkJ9_Veh1{_J(0;sH^oT&*<1&mDFqOrP+vq!qnw?W@tpmc~;(@uB5LXSyjw9+~ts^ zJlI&Z+dE8UTJmjJQEh{FQUh1$<{cMmhfGF;Wd#^ntb?n5I@N;i~zwZ9XmwKhVI$ zTMOXaJYa{^-{ev1)$>qM2v6i`!-VJP;`1d`3rnd$cix~BJkoh$il?MF=|(8>scvDZ z)QHRZU^>cUid;{bw&hT!iuJ$;nQrXIO|2D_X>`9GVaMgO!*+x zwYfQo#FzN}J=Jb(63N#oV(iB$B&P6y0&rouJ*m4QFjql!H+ImXS!j>Z`QtqAd}xnL zDCrIZ3#j~CWh4*A6i}yYTSoJDGSXW&wp>20#nu#Ba!F2l>LeZ;sj<<;TEv0HyGVB( zs-}^jhDPycNLL+d?-U6(7{Mvze2l}ZlZSUe@OG=n}8cS+5Jt_Oy5Cm)dHF; zfi0 zU-S$8Vp`)mEXQWLPyA$@GdWfjQEP1D;TVh2kLH(caE`^8Me`n8oMSQW_5eKCRs;;t zU8%OmMOc>obQ2LD=UA3cyvCjNsRO1HVwuJlQ-@hK73Wx{uf5SS4aL+l{f@W@3-&SH zBzGFlu|RE0s3+mMGgOrTE4ZQrn4!lZk*@yng|@2wC(B(2@E@qWIM`VNJXk650DWcu E1M=M_cmMzZ delta 6491 zcmY*d1yob<8s9D{Eh#Ysq@+iqbR*IYB1m^P=;%g}nxr7opmaAPEhPfdNH|FmkPvv2 zJ)gXtv$L~be)qfoot<-2k8iwqiIWt7L6u~J@vf3U4x^d?QX273ZZ|1Ic%INyUjm@Xm_2dub^sL5e?UkS=ah|1 zI&SX{`#UqMOsI7kS+*oHX%|#?4QjF>ScG${6bYzgcy8G#n;W{2DQz03GTz)#IU>id z{Z@eqllGw+v&^UL7ON44HepKjw#jVM>%d;iu4(QuE47w=2@vU93hR`gBr1};S2RYf z+r{^8RDeR_$vpc4hDvCb2yTz&t@XWCQ~#Cj}m=3!c5>1U!~v zqW3d>_geJr`Dihv(nlTz+R`K}4W%51hbT_)D^c-6R(&VVvUx_Nygnybk9|2#>*9nt z|F%Foal3p|2npzZCH5UBiLtm63|&|D$Uygra+PS$5c-w*I_5|t;Ea9Ddjo;>J9tF? z4$b;VaGP^Pq-pajwV94;b)0DFSo^7H0284`C(Zw^v7K;gc_}tE)omYqnNqPV8_hu6 zeXNqadnWVpg(K!|YFZ*!YMdP?3;zVHVIllsLnCzGSW?`{dlWq?zsQg_!?%x=R!+cm zrq-;ksj)9nB~!JCG`K;ZASS^-e|g=PFQU1*pr&=tC;bAAH07S`CDf?ntA6g`9D z@jej<g)+A`ar0?YlV$LU>RihxgdcH}kJ$9Gb)^E~@gN6DINz~xxhCi5TdE+Rf{R5KPgp9HeD}uJ18#eUg^-J!x znrbn(Ag=kQ1v(nk^aD6{9~gN+{Cr#XzN|U22T4*fnXp26xbCFxXN2xkCe1P>C){yk z=F*yl7_-unPKG2uvgpDdPF>BrjXRtA+a1u(CCQ;&GVJPpV%I$F{`IJsfiJ9#897eMn!{P417n>RVXOpnK0 z<`7dtMFG~zt#n>@uGB{O9+xBdzz$ncO+S`zx~|NU3$7mOXk6b0P@;28daP!6HWt>k z(BRPui_1kv6hx1+(uQ;9@}~5DIBwEsKk@`S=kNK_r61c07#FuYXxb!x!ZT^tiHl7O zfOjR@a*=0;+@+*)%igubnMg zu^6M)+M0WHFI9kiM{r&qEy))LaH zrQ<5v_$9@Fs?am^mJZ{NcnNz0HHTY7HV6!>A7W+bja*vMkbCj8w*(|)OUy)_Ca@3T zW(+#hzFQHN+CKLtYRG)juAMaBhQcq!!wS5}SLk zsp#oU&%{O-+PsP2o5z9dat++y0dSL!sSe53Z2~@Dyq68mSPU4+@x4K;u)|xI0ZNOo zwg8UtummF)I`w@BS9P{D>yHFoumqSn}_Y(gz{Q*nQA9C7Av%J zmJEj~lLX3ZZJcm$xqshZLB}nJHN;)38RQ?&uWjnJj62PzI@&N>apxnt&@FS0PMP#TS>vj9r+tLzIPJme&iCd%#f(wN zwATl#ah{EE-5aE(e+<$0an)vXS6BY9+Tjb}+;&u{E|blLl>6^x!u*a+=v6l`$W;alpyWjy_`oYx!C?V&Vjtxv&(ojv%l zl!P>elbj}A93QG zcb-pK#l9C4&IP#miQDyG!rsKhxf+no@@w1(!tE`rI8#ZIN8c6dwo)ajowr%VRA4>T ziX|EQQ81YYee!E|x7+6~d#Gtco6`6Ip9H@Q2732V#q((D7#jZQxr`XGH|1jbKE~CD z7nn_GQa$u__CFpVipzodO3=o8? zJQQhG-|jc2I~Ix##ly`V@b+>T-VoK?lCPm87fpUKxnqov`-CAUqzG$;xLGVHXcOSQ z+{WeGXQxdk6A5}XV~Xygb??#~5XO{zw~%qTrcY&uJ)(#Z68zXRrS(J3+Q8J}#(=w1 z$KyrK%-cz#Rg8uZ&vTfY>EKPn#bmM-9bz}X>C;tk(uq8JBQCgb>|@c3OR*guoF8xF zwx$rVdf8QEw(s#a>#&R_Mk?J6`w*$6sa0LYJvfhmF_O@Ov>XbUXMUxWWsd0Xbzfw_M?B-#?(5FKNcE&1B z*wE>y)y6Kw=;8YoR$e0g22Z#9nszEKh4z8BqqU_5sVRS)aCClKkcO~W1xpP}WSLv% zMXJWAY0Pj%PLslvr)Qi#i!4iJI{Ls{%xm(KT<@nb-6-cWQl*&VDq^#SRY+}*QL=Vd zL|I9AkS#uc79|~4z)s%v<{`=9YLsiQaEr1nCn2b%Q#E-n5UY}{Q^>hmeumKZ zVf{#+5gW#DEo-g3C22TjxGNUg@xi(L#a*(4i4IO_{MBAqScE%y@Ibd1IIT8>`mNS- zWxG#(+aZnf+Sza&7Nz!%q>Or*ko&&J_DrtUnQavQy}(qOS);S{iYQ42gGakoeHVDL z7ejvD=|o>!@bWvEu^<_U8z%i6+LBi9`^y+)<=Q`IFXE~a51YRQt3-mU)$LTHTi9*` zvH=0xPjc;H2J=OB_c+v{NtI3$Zyu|RAm$>hGxyuG@KsMh%&E;}Xp}OMgZ2&)MypH5 zN5_le@AtK{@qay@TJYGuqnCd6(}aZ=G4X8>YbWcMg`oJZrJN|_Bf?{HSI2;-4A(L6 z%S1w*cN@$OyO}WOOTef@x3Cs}@txnW_8ox532O}B>R-FTJFIT8|nF5|jDp`2;J&4=Dr`ihecBGfW_ z797dI0^=REiFqmSJuL7#Y6O;JRyj9M{LQ0rf1+Z!e)}Bjj+8&s%tDwxWoTiS<*mK2 zzO+WX9};4!2Rl^HXw+#M?qD@R6E0t z+J4l>JD0wA@d6X}h}??fEFS+so9?Q4?^(!xJ1gi#C-yr>;?ukRw59X&kbr_OJ?O@F zsR{fSUdaeueg)5L+6kT*8($!#5BDe|ZW#RpE%O|G3w!@0BiCXi{nqU*n|PzIX2xC9 zzVkG&XK5WCUB~GBG-#d6j`MQee#h0Mn??rt2USuvIMP{W+7|ZS@o?2O%I*I8Vp3iu z%{F;NDq_|=V&|se<#gxodp={~WR7Nj@d7u8wFvX91;vMTDW;nyvijr~1?+8R+KP|e ztlz9j*TUsUDib|hUKbjC8?2>EmO5jeWcSO5S-fG-7@Tl^@xxCcYHcSY%6=u5?&mQt z#9&eq4$_PZ_ZsRXbUOhiL{=Pu~WWhPpFk+=GENR z@UiXv{VW@4|ZoJNkZfr0St2 z1E5?!%Ebe>Vd3AeEoXz%<2vKD*l+t{b=iA-OClya>hX{(ZBfk%3Q9pVCNYq>O6MhT zuV4h{OtvXva)JT6n~Mit@drZO-<$NkL;SMX+dSp0$DOT0fs$8Q__NP1Fr?!>*=Ofr zLkcVnmk70P)c|{5Rwv7M!T4vzC)h(3cQ4PN%%TU{dJpUutnrhZ#bx^EL-~#_Y7E2= zMsKf4IQo3ld0MKi8lg6=s@n?sb$R32q;~MsU0neJ2|`(%U!8D+(Qw#kAJ6@J*`lA; zpUlCU{DvM2i|A=R5Bq(et8fdFeB!c`0sFv-Ak6MgAM6%*;AJHK14{HgI4>lvXYbsL zIP_~z#1YP3ly7r+?9Go&ez|Gc+^^LnNk;tUmNQCu{kGLc0yStm4~4nHpT%!0_`Kr_ zLu>w};?huYyogP1ySOUi^UYhhFVYj;KdwM?fw=Qo|7tFDi4bwzahW8~?QAiwV+x~f zwSkCO>Ycd_^7P|i$7)VAhyT(z^`hRIRc@8hvmFUuoIC{NTTJ3?Q!sZ{cXhW z>LZ*6`!zkXg#PjSM08rL$42{m);c=c3ufP94V^tyYPlUfqQiXDBC6Sba8O)!>zUnX z+ikwtUL@Wzc0myGlR4riGr0aIGY}*N1|E$5h2@Y{9(HhLZS1T$%wm=rEy z2jC(rlub5O|I~fPoyVz#d>g&blKAUv>bOlg|`I^oR@SD7nje_Ue7 zlG~0~fS5BX#>f$1L&m7NQo(Z^uL|UYQ1wq;8UNNl@w$E+^}|tucrW(9Ee1|kq4Ciu z=7S&fuXQZpy0tF{CAdt0{I#+>Us<V4|2d@6yP$H^rA)VZ_MwB>oGsS!f(0& zqDZl-%awcQ94RJ;N4o%mNU#$H4&uXI-2igA|4?dws!>2R45dAH0|byLpZk@}ApH+Y z29J;amynC1glbTf;~OyB-*7ySE7$NR6s6z+a3iz4QRwCt3a#=0*pcW062+(k^PvN9 z49`Dz6wc#$h1B8!a6M0e7KwPFkQQP9z5oFKyUbQ3OjQRqqX0ne=5Dqw9`4*0JWwtd z-~axmPLMqER94)A-}VB?kU0WgR|X=g>ju7m23HjCN`0N@|Htp(0aOhC&g>e^b)CRJ zojn0?eg9^(a6<2^0hqk62I#nb-6-_W_yon1(O>6<|M6KUK7`>qFY=GyB>BJMl{@fD zRIw--Kn3Ua0kDz#;$ynzhF;s3hR@Y>$;{Vz-hb(FDE^G)I)80_4wV1bZhjZO>vMJ4 z7`|7*@od+P*Y+ia;v3kn^Me11pQ89>j_drjwdK(MUoo5uzUX^Ze26kC^ISJx+ZWXD zYOs0eHJ+E}U$G;KH{-j`U)xwQl1HAfO@8=@-&G-u8M)_6zpFia30`+y+Y`P2m1{)^ vuHX-_BG07@3K&&}r}|(0lLT4;RQUf6nzt7KPqP9bsQ~~HrkxM+?*sH7UP2La diff --git a/src/pystatsv1/trackd/mpl_compat.py b/src/pystatsv1/trackd/mpl_compat.py new file mode 100644 index 0000000..48687ff --- /dev/null +++ b/src/pystatsv1/trackd/mpl_compat.py @@ -0,0 +1,34 @@ +"""Matplotlib compatibility helpers for workbook scripts. + +Matplotlib 3.9 renamed the Axes.boxplot keyword argument "labels" to +"tick_labels". The old name is deprecated and scheduled for removal. + +These helpers keep our educational scripts working on Matplotlib 3.8+ +while avoiding deprecation warnings on newer versions. +""" + +from __future__ import annotations + +from typing import Any, Sequence + + +def ax_boxplot( + ax: Any, + *args: Any, + tick_labels: Sequence[str] | None = None, + **kwargs: Any, +): + """Call ``ax.boxplot`` with a 3.8/3.9+ compatible keyword. + + Prefer ``tick_labels`` (Matplotlib >= 3.9). If that keyword is not + supported (Matplotlib <= 3.8), fall back to the legacy ``labels``. + """ + + if tick_labels is None: + return ax.boxplot(*args, **kwargs) + + try: + return ax.boxplot(*args, tick_labels=tick_labels, **kwargs) + except TypeError: + # Older Matplotlib: the new keyword doesn't exist. + return ax.boxplot(*args, labels=tick_labels, **kwargs) diff --git a/src/pystatsv1/trackd/reporting_style.py b/src/pystatsv1/trackd/reporting_style.py new file mode 100644 index 0000000..2338076 --- /dev/null +++ b/src/pystatsv1/trackd/reporting_style.py @@ -0,0 +1,506 @@ +# SPDX-License-Identifier: MIT +"""Shared plotting/reporting helpers. + +Track D Chapter 9 introduces a *style contract* for figures and small reports. +This module centralizes the rules so later chapters can reuse them. + +Design goals +------------ +- Matplotlib-only (no seaborn) +- Deterministic output filenames and metadata +- Guardrails against misleading axes (especially for bar charts) +- Simple defaults suitable for ReadTheDocs screenshots and printing + +The "style contract" is intentionally conservative; it favors clarity over +flash. Downstream chapters can extend it, but should keep the core rules. +""" + +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Iterable +from contextlib import contextmanager +import matplotlib as mpl +import numpy as np + +# Matplotlib is an optional dependency for some repo users. +# Track D chapters require it, so we import lazily in functions where possible. + +STYLE_CONTRACT: dict[str, Any] = { + "version": "1.0", + "allowed_chart_types": [ + "line", + "bar", + "histogram", + "ecdf", + "box", + "scatter", + "waterfall_bridge", + ], + "labeling_rules": { + "title_required": True, + "axis_labels_required": True, + "units_in_labels": True, + "use_month_tick_labels": "YYYY-MM", + "legend_only_if_multiple_series": True, + "caption_required_in_manifest": True, + }, + "anti_misleading_axes": { + "bar_charts_start_at_zero": True, + "explicit_note_if_y_truncated": True, + "show_zero_line_for_ratios": True, + "avoid_dual_axes": True, + }, + "distribution_guidance": { + "for_skewed_distributions": [ + "histogram + vertical lines for mean and median", + "ECDF (or quantile plot) to reveal tails", + "report key quantiles (p50, p75, p90, p95 if available)", + ] + }, + "file_format": {"type": "png", "dpi": 150}, + "figure_sizes": { + "time_series": [10.0, 4.0], + "distribution": [7.5, 4.5], + }, +} + + +# Minimal matplotlib rcParams for a consistent, non-misleading reporting look. +# NOTE: We intentionally avoid specifying colors so matplotlib defaults apply. +_REPORTING_RC: dict[str, object] = { + "figure.dpi": 120, + "savefig.dpi": 150, + "savefig.bbox": "tight", + "axes.grid": True, + "axes.titleweight": "bold", + "axes.titlesize": 12, + "axes.labelsize": 10, + "xtick.labelsize": 9, + "ytick.labelsize": 9, + "legend.fontsize": 9, +} + + +@contextmanager +def style_context(): + """Context manager to apply the reporting style contract to matplotlib figures.""" + with mpl.rc_context(rc=_REPORTING_RC): + yield + + + +@dataclass(frozen=True) +class FigureSpec: + """Minimal spec used when saving figures (validation + metadata).""" + + chart_type: str + title: str + caption: str = "" + x_label: str = "" + y_label: str = "" + data_source: str = "" + notes: str = "" + + +@dataclass(frozen=True) +class FigureManifestRow: + """One row in the Chapter 9 figure manifest CSV.""" + + filename: str + chart_type: str + title: str + x_label: str + y_label: str + guardrail_note: str + data_source: str + + + +def write_style_contract_json(outpath: Path) -> None: + """Write the global style contract to a JSON file.""" + + outpath.write_text(json.dumps(STYLE_CONTRACT, indent=2), encoding="utf-8") + + +def write_contract_json(outpath: Path) -> None: + """Write the global style contract to a JSON file.""" + outpath.write_text(json.dumps(STYLE_CONTRACT, indent=2), encoding="utf-8") + + + +def _mpl(): + """Import matplotlib with a non-interactive backend.""" + + import matplotlib + + # Ensure headless operation for CI / tests. + matplotlib.use("Agg", force=True) + + import matplotlib.pyplot as plt + + return matplotlib, plt + + +def mpl_context(): + """Context manager that applies a lightweight, consistent style.""" + + matplotlib, plt = _mpl() + + # A minimal rcParams set: keep things readable without over-styling. + rc = { + "figure.dpi": int(STYLE_CONTRACT["file_format"]["dpi"]), + "savefig.dpi": int(STYLE_CONTRACT["file_format"]["dpi"]), + "font.size": 10, + "axes.titlesize": 12, + "axes.labelsize": 10, + "legend.fontsize": 9, + "xtick.labelsize": 9, + "ytick.labelsize": 9, + "axes.grid": True, + "grid.alpha": 0.25, + "axes.spines.top": False, + "axes.spines.right": False, + } + + return matplotlib.rc_context(rc) + + +def save_figure(fig, outpath: Path, spec: FigureSpec | None = None) -> None: + """Save and close a Matplotlib figure deterministically. + + If spec is provided, enforce allowed chart types. + """ + if spec is not None: + ensure_allowed_chart_type(spec.chart_type) + + outpath.parent.mkdir(parents=True, exist_ok=True) + fig.tight_layout() + fig.savefig(outpath, bbox_inches="tight") + + # Avoid memory leaks in test runs. + _, plt = _mpl() + plt.close(fig) + + + +def _format_month_ticks(ax, months: list[str]) -> None: + """Format x-axis ticks for YYYY-MM month labels.""" + + # Show at most ~8 ticks; for longer series, reduce tick density. + n = len(months) + if n <= 8: + step = 1 + elif n <= 18: + step = 2 + else: + step = 3 + + ticks = list(range(0, n, step)) + ax.set_xticks(ticks) + ax.set_xticklabels([months[i] for i in ticks], rotation=45, ha="right") + + +def _enforce_bar_zero_baseline(ax) -> None: + """Enforce y-axis baseline at zero for bar charts.""" + + y0, y1 = ax.get_ylim() + if y0 > 0: + ax.set_ylim(0.0, y1) + elif y1 < 0: + ax.set_ylim(y0, 0.0) + + +def plot_time_series( + df, + x: str, + series: dict[str, str], + title: str, + x_label: str, + y_label: str, + figsize: tuple[float, float] | None = None, + show_zero_line: bool = False, +): + """Create a standard time-series line chart. + + Parameters + ---------- + df: + Dataframe with columns including x and all series columns. + x: + Column name for x-axis (typically month). + series: + Mapping of legend label -> column name. + show_zero_line: + If True, draw a horizontal line at y=0 (useful for ratios/growth). + """ + + _, plt = _mpl() + + if figsize is None: + w, h = STYLE_CONTRACT["figure_sizes"]["time_series"] + figsize = (float(w), float(h)) + + fig, ax = plt.subplots(figsize=figsize) + + months = [str(m) for m in df[x].tolist()] + x_idx = np.arange(len(months)) + + for label, col in series.items(): + ax.plot(x_idx, df[col].astype(float).to_numpy(), marker="o", linewidth=1.5, label=label) + + if show_zero_line: + ax.axhline(0.0, linewidth=1.0) + + ax.set_title(title) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + + _format_month_ticks(ax, months) + + if len(series) > 1: + ax.legend(loc="best") + + return fig + + +def plot_bar( + df, + x: str, + y: str, + title: str, + x_label: str, + y_label: str, + figsize: tuple[float, float] | None = None, +): + """Create a standard bar chart with a zero baseline.""" + + _, plt = _mpl() + + if figsize is None: + w, h = STYLE_CONTRACT["figure_sizes"]["time_series"] + figsize = (float(w), float(h)) + + fig, ax = plt.subplots(figsize=figsize) + + months = [str(m) for m in df[x].tolist()] + x_idx = np.arange(len(months)) + + ax.bar(x_idx, df[y].astype(float).to_numpy()) + + ax.set_title(title) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + + _format_month_ticks(ax, months) + _enforce_bar_zero_baseline(ax) + + return fig + + +def _ecdf(values: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + v = np.asarray(values, dtype=float) + v = v[~np.isnan(v)] + if v.size == 0: + return np.array([]), np.array([]) + v = np.sort(v) + y = np.arange(1, v.size + 1, dtype=float) / float(v.size) + return v, y + + +def plot_histogram_with_markers( + values: Iterable[float], + title: str, + x_label: str, + y_label: str, + markers: dict[str, float] | None = None, + figsize: tuple[float, float] | None = None, +): + """Histogram with optional vertical markers (e.g., mean/median).""" + + _, plt = _mpl() + + if figsize is None: + w, h = STYLE_CONTRACT["figure_sizes"]["distribution"] + figsize = (float(w), float(h)) + + v = np.asarray(list(values), dtype=float) + v = v[~np.isnan(v)] + + fig, ax = plt.subplots(figsize=figsize) + + if v.size > 0: + ax.hist(v, bins="auto") + + if markers: + for label, x0 in markers.items(): + if np.isfinite(x0): + ax.axvline(float(x0), linestyle="--", linewidth=1.2, label=label) + + ax.set_title(title) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + + if markers and len(markers) > 0: + ax.legend(loc="best") + + return fig + + +def plot_ecdf( + values: Iterable[float], + title: str, + x_label: str, + y_label: str, + markers: dict[str, float] | None = None, + figsize: tuple[float, float] | None = None, +): + """ECDF plot with optional vertical markers.""" + + _, plt = _mpl() + + if figsize is None: + w, h = STYLE_CONTRACT["figure_sizes"]["distribution"] + figsize = (float(w), float(h)) + + v = np.asarray(list(values), dtype=float) + x, y = _ecdf(v) + + fig, ax = plt.subplots(figsize=figsize) + + if x.size > 0: + ax.plot(x, y, marker=".", linestyle="none") + + if markers: + for label, x0 in markers.items(): + if np.isfinite(x0): + ax.axvline(float(x0), linestyle="--", linewidth=1.2, label=label) + + ax.set_title(title) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + ax.set_ylim(0.0, 1.0) + + if markers and len(markers) > 0: + ax.legend(loc="best") + + return fig + + +def plot_waterfall_bridge( + start_label: str, + end_label: str, + start_value: float, + end_value: float, + components: list[tuple[str, float]], + title: str, + y_label: str, + x_label: str = "Component", + figsize: tuple[float, float] | None = None, +): + """Create a variance waterfall / bridge chart (start -> end via additive components). + + Guardrails + --------- + - Deterministic structure: explicit start and end totals plus additive components. + - Printer-safe encoding: hatch patterns distinguish positive vs negative deltas. + - Zero line included; y-limits padded to reduce truncation temptation. + + Notes + ----- + The caller is responsible for choosing defensible components. Any residual + can be included as an "Other / rounding" component to reconcile exactly. + """ + + _, plt = _mpl() + + if figsize is None: + w, h = STYLE_CONTRACT["figure_sizes"]["time_series"] + figsize = (float(w), float(h)) + + labels = [start_label] + [name for name, _ in components] + [end_label] + + # Running totals after each component (for connectors and y-range). + running = float(start_value) + totals = [running] + for _, delta in components: + running += float(delta) + totals.append(running) + totals.append(float(end_value)) + + fig, ax = plt.subplots(figsize=figsize) + + # Start total + ax.bar(0, float(start_value), edgecolor="black", linewidth=0.8) + + # Component deltas + running = float(start_value) + for i, (_, delta) in enumerate(components, start=1): + d = float(delta) + new_total = running + d + + if d >= 0: + bottom = running + height = d + hatch = "//" + else: + bottom = new_total + height = -d + hatch = "\\" + + ax.bar(i, height, bottom=bottom, hatch=hatch, edgecolor="black", linewidth=0.8) + running = new_total + + # End total + ax.bar(len(labels) - 1, float(end_value), edgecolor="black", linewidth=0.8) + + # Connectors between bars (running totals) + running = float(start_value) + for i, (_, delta) in enumerate(components, start=1): + ax.plot([i - 0.4, i + 0.4], [running, running], linewidth=1.0) + running += float(delta) + + ax.set_title(title) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + ax.set_xticks(list(range(len(labels)))) + ax.set_xticklabels(labels, rotation=0) + ax.axhline(0.0, linewidth=1.0) + + def _fmt(v: float) -> str: + return f"{v:,.0f}" + + # Annotate start/end totals + ax.text(0, float(start_value), _fmt(float(start_value)), ha="center", va="bottom") + ax.text(len(labels) - 1, float(end_value), _fmt(float(end_value)), ha="center", va="bottom") + + # Annotate component deltas + running = float(start_value) + for i, (_, delta) in enumerate(components, start=1): + d = float(delta) + y = (running + d) if d >= 0 else running + ax.text(i, y, f"{d:+,.0f}", ha="center", va="bottom") + running += d + + # Pad y-limits (anti-truncation guardrail) + lo = min([0.0] + totals) + hi = max([0.0] + totals) + span = hi - lo + pad = 0.10 * span if span > 0 else 1.0 + ax.set_ylim(lo - pad, hi + pad) + + return fig + + +def figure_manifest_to_frame(specs: list[FigureSpec]): + import pandas as pd + + return pd.DataFrame([asdict(s) for s in specs]) + + +def ensure_allowed_chart_type(chart_type: str) -> None: + allowed = set(STYLE_CONTRACT["allowed_chart_types"]) + if chart_type not in allowed: + raise ValueError(f"chart_type must be one of {sorted(allowed)}; got {chart_type!r}") diff --git a/tests/test_trackd_reporting_style_smoke.py b/tests/test_trackd_reporting_style_smoke.py new file mode 100644 index 0000000..f12b3bc --- /dev/null +++ b/tests/test_trackd_reporting_style_smoke.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import json + +import matplotlib as mpl +import matplotlib.pyplot as plt + +from pystatsv1.trackd import reporting_style +from pystatsv1.trackd.mpl_compat import ax_boxplot + + +def test_reporting_style_contract_writer(tmp_path) -> None: + out = tmp_path / "style_contract.json" + reporting_style.write_contract_json(out) + + loaded = json.loads(out.read_text(encoding="utf-8")) + assert loaded == reporting_style.STYLE_CONTRACT + + +def test_ax_boxplot_smoke() -> None: + # Ensure a non-interactive backend (CI / headless). + mpl.use("Agg", force=True) + + fig, ax = plt.subplots() + try: + ax_boxplot(ax, [1, 2, 3, 4, 5], tick_labels=["x"]) + finally: + plt.close(fig) diff --git a/workbooks/track_d_template/scripts/_mpl_compat.py b/workbooks/track_d_template/scripts/_mpl_compat.py index 48687ff..d135899 100644 --- a/workbooks/track_d_template/scripts/_mpl_compat.py +++ b/workbooks/track_d_template/scripts/_mpl_compat.py @@ -1,34 +1,12 @@ -"""Matplotlib compatibility helpers for workbook scripts. +"""Backward-compatible shim for Track D matplotlib helpers. -Matplotlib 3.9 renamed the Axes.boxplot keyword argument "labels" to -"tick_labels". The old name is deprecated and scheduled for removal. +The Track D workbook template historically shipped an implementation of this +module in the workbook itself. The canonical implementation now lives in the +installed PyStatsV1 package at :mod:`pystatsv1.trackd.mpl_compat`. -These helpers keep our educational scripts working on Matplotlib 3.8+ -while avoiding deprecation warnings on newer versions. +Keeping this shim avoids breaking imports inside the workbook's chapter scripts. """ from __future__ import annotations -from typing import Any, Sequence - - -def ax_boxplot( - ax: Any, - *args: Any, - tick_labels: Sequence[str] | None = None, - **kwargs: Any, -): - """Call ``ax.boxplot`` with a 3.8/3.9+ compatible keyword. - - Prefer ``tick_labels`` (Matplotlib >= 3.9). If that keyword is not - supported (Matplotlib <= 3.8), fall back to the legacy ``labels``. - """ - - if tick_labels is None: - return ax.boxplot(*args, **kwargs) - - try: - return ax.boxplot(*args, tick_labels=tick_labels, **kwargs) - except TypeError: - # Older Matplotlib: the new keyword doesn't exist. - return ax.boxplot(*args, labels=tick_labels, **kwargs) +from pystatsv1.trackd.mpl_compat import * # noqa: F401,F403 diff --git a/workbooks/track_d_template/scripts/_reporting_style.py b/workbooks/track_d_template/scripts/_reporting_style.py index 2338076..d58d60e 100644 --- a/workbooks/track_d_template/scripts/_reporting_style.py +++ b/workbooks/track_d_template/scripts/_reporting_style.py @@ -1,506 +1,12 @@ -# SPDX-License-Identifier: MIT -"""Shared plotting/reporting helpers. +"""Backward-compatible shim for Track D reporting style. -Track D Chapter 9 introduces a *style contract* for figures and small reports. -This module centralizes the rules so later chapters can reuse them. +The Track D workbook template historically shipped an implementation of this +module in the workbook itself. The canonical implementation now lives in the +installed PyStatsV1 package at :mod:`pystatsv1.trackd.reporting_style`. -Design goals ------------- -- Matplotlib-only (no seaborn) -- Deterministic output filenames and metadata -- Guardrails against misleading axes (especially for bar charts) -- Simple defaults suitable for ReadTheDocs screenshots and printing - -The "style contract" is intentionally conservative; it favors clarity over -flash. Downstream chapters can extend it, but should keep the core rules. +Keeping this shim avoids breaking imports inside the workbook's chapter scripts. """ from __future__ import annotations -import json -from dataclasses import asdict, dataclass -from pathlib import Path -from typing import Any, Iterable -from contextlib import contextmanager -import matplotlib as mpl -import numpy as np - -# Matplotlib is an optional dependency for some repo users. -# Track D chapters require it, so we import lazily in functions where possible. - -STYLE_CONTRACT: dict[str, Any] = { - "version": "1.0", - "allowed_chart_types": [ - "line", - "bar", - "histogram", - "ecdf", - "box", - "scatter", - "waterfall_bridge", - ], - "labeling_rules": { - "title_required": True, - "axis_labels_required": True, - "units_in_labels": True, - "use_month_tick_labels": "YYYY-MM", - "legend_only_if_multiple_series": True, - "caption_required_in_manifest": True, - }, - "anti_misleading_axes": { - "bar_charts_start_at_zero": True, - "explicit_note_if_y_truncated": True, - "show_zero_line_for_ratios": True, - "avoid_dual_axes": True, - }, - "distribution_guidance": { - "for_skewed_distributions": [ - "histogram + vertical lines for mean and median", - "ECDF (or quantile plot) to reveal tails", - "report key quantiles (p50, p75, p90, p95 if available)", - ] - }, - "file_format": {"type": "png", "dpi": 150}, - "figure_sizes": { - "time_series": [10.0, 4.0], - "distribution": [7.5, 4.5], - }, -} - - -# Minimal matplotlib rcParams for a consistent, non-misleading reporting look. -# NOTE: We intentionally avoid specifying colors so matplotlib defaults apply. -_REPORTING_RC: dict[str, object] = { - "figure.dpi": 120, - "savefig.dpi": 150, - "savefig.bbox": "tight", - "axes.grid": True, - "axes.titleweight": "bold", - "axes.titlesize": 12, - "axes.labelsize": 10, - "xtick.labelsize": 9, - "ytick.labelsize": 9, - "legend.fontsize": 9, -} - - -@contextmanager -def style_context(): - """Context manager to apply the reporting style contract to matplotlib figures.""" - with mpl.rc_context(rc=_REPORTING_RC): - yield - - - -@dataclass(frozen=True) -class FigureSpec: - """Minimal spec used when saving figures (validation + metadata).""" - - chart_type: str - title: str - caption: str = "" - x_label: str = "" - y_label: str = "" - data_source: str = "" - notes: str = "" - - -@dataclass(frozen=True) -class FigureManifestRow: - """One row in the Chapter 9 figure manifest CSV.""" - - filename: str - chart_type: str - title: str - x_label: str - y_label: str - guardrail_note: str - data_source: str - - - -def write_style_contract_json(outpath: Path) -> None: - """Write the global style contract to a JSON file.""" - - outpath.write_text(json.dumps(STYLE_CONTRACT, indent=2), encoding="utf-8") - - -def write_contract_json(outpath: Path) -> None: - """Write the global style contract to a JSON file.""" - outpath.write_text(json.dumps(STYLE_CONTRACT, indent=2), encoding="utf-8") - - - -def _mpl(): - """Import matplotlib with a non-interactive backend.""" - - import matplotlib - - # Ensure headless operation for CI / tests. - matplotlib.use("Agg", force=True) - - import matplotlib.pyplot as plt - - return matplotlib, plt - - -def mpl_context(): - """Context manager that applies a lightweight, consistent style.""" - - matplotlib, plt = _mpl() - - # A minimal rcParams set: keep things readable without over-styling. - rc = { - "figure.dpi": int(STYLE_CONTRACT["file_format"]["dpi"]), - "savefig.dpi": int(STYLE_CONTRACT["file_format"]["dpi"]), - "font.size": 10, - "axes.titlesize": 12, - "axes.labelsize": 10, - "legend.fontsize": 9, - "xtick.labelsize": 9, - "ytick.labelsize": 9, - "axes.grid": True, - "grid.alpha": 0.25, - "axes.spines.top": False, - "axes.spines.right": False, - } - - return matplotlib.rc_context(rc) - - -def save_figure(fig, outpath: Path, spec: FigureSpec | None = None) -> None: - """Save and close a Matplotlib figure deterministically. - - If spec is provided, enforce allowed chart types. - """ - if spec is not None: - ensure_allowed_chart_type(spec.chart_type) - - outpath.parent.mkdir(parents=True, exist_ok=True) - fig.tight_layout() - fig.savefig(outpath, bbox_inches="tight") - - # Avoid memory leaks in test runs. - _, plt = _mpl() - plt.close(fig) - - - -def _format_month_ticks(ax, months: list[str]) -> None: - """Format x-axis ticks for YYYY-MM month labels.""" - - # Show at most ~8 ticks; for longer series, reduce tick density. - n = len(months) - if n <= 8: - step = 1 - elif n <= 18: - step = 2 - else: - step = 3 - - ticks = list(range(0, n, step)) - ax.set_xticks(ticks) - ax.set_xticklabels([months[i] for i in ticks], rotation=45, ha="right") - - -def _enforce_bar_zero_baseline(ax) -> None: - """Enforce y-axis baseline at zero for bar charts.""" - - y0, y1 = ax.get_ylim() - if y0 > 0: - ax.set_ylim(0.0, y1) - elif y1 < 0: - ax.set_ylim(y0, 0.0) - - -def plot_time_series( - df, - x: str, - series: dict[str, str], - title: str, - x_label: str, - y_label: str, - figsize: tuple[float, float] | None = None, - show_zero_line: bool = False, -): - """Create a standard time-series line chart. - - Parameters - ---------- - df: - Dataframe with columns including x and all series columns. - x: - Column name for x-axis (typically month). - series: - Mapping of legend label -> column name. - show_zero_line: - If True, draw a horizontal line at y=0 (useful for ratios/growth). - """ - - _, plt = _mpl() - - if figsize is None: - w, h = STYLE_CONTRACT["figure_sizes"]["time_series"] - figsize = (float(w), float(h)) - - fig, ax = plt.subplots(figsize=figsize) - - months = [str(m) for m in df[x].tolist()] - x_idx = np.arange(len(months)) - - for label, col in series.items(): - ax.plot(x_idx, df[col].astype(float).to_numpy(), marker="o", linewidth=1.5, label=label) - - if show_zero_line: - ax.axhline(0.0, linewidth=1.0) - - ax.set_title(title) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - - _format_month_ticks(ax, months) - - if len(series) > 1: - ax.legend(loc="best") - - return fig - - -def plot_bar( - df, - x: str, - y: str, - title: str, - x_label: str, - y_label: str, - figsize: tuple[float, float] | None = None, -): - """Create a standard bar chart with a zero baseline.""" - - _, plt = _mpl() - - if figsize is None: - w, h = STYLE_CONTRACT["figure_sizes"]["time_series"] - figsize = (float(w), float(h)) - - fig, ax = plt.subplots(figsize=figsize) - - months = [str(m) for m in df[x].tolist()] - x_idx = np.arange(len(months)) - - ax.bar(x_idx, df[y].astype(float).to_numpy()) - - ax.set_title(title) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - - _format_month_ticks(ax, months) - _enforce_bar_zero_baseline(ax) - - return fig - - -def _ecdf(values: np.ndarray) -> tuple[np.ndarray, np.ndarray]: - v = np.asarray(values, dtype=float) - v = v[~np.isnan(v)] - if v.size == 0: - return np.array([]), np.array([]) - v = np.sort(v) - y = np.arange(1, v.size + 1, dtype=float) / float(v.size) - return v, y - - -def plot_histogram_with_markers( - values: Iterable[float], - title: str, - x_label: str, - y_label: str, - markers: dict[str, float] | None = None, - figsize: tuple[float, float] | None = None, -): - """Histogram with optional vertical markers (e.g., mean/median).""" - - _, plt = _mpl() - - if figsize is None: - w, h = STYLE_CONTRACT["figure_sizes"]["distribution"] - figsize = (float(w), float(h)) - - v = np.asarray(list(values), dtype=float) - v = v[~np.isnan(v)] - - fig, ax = plt.subplots(figsize=figsize) - - if v.size > 0: - ax.hist(v, bins="auto") - - if markers: - for label, x0 in markers.items(): - if np.isfinite(x0): - ax.axvline(float(x0), linestyle="--", linewidth=1.2, label=label) - - ax.set_title(title) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - - if markers and len(markers) > 0: - ax.legend(loc="best") - - return fig - - -def plot_ecdf( - values: Iterable[float], - title: str, - x_label: str, - y_label: str, - markers: dict[str, float] | None = None, - figsize: tuple[float, float] | None = None, -): - """ECDF plot with optional vertical markers.""" - - _, plt = _mpl() - - if figsize is None: - w, h = STYLE_CONTRACT["figure_sizes"]["distribution"] - figsize = (float(w), float(h)) - - v = np.asarray(list(values), dtype=float) - x, y = _ecdf(v) - - fig, ax = plt.subplots(figsize=figsize) - - if x.size > 0: - ax.plot(x, y, marker=".", linestyle="none") - - if markers: - for label, x0 in markers.items(): - if np.isfinite(x0): - ax.axvline(float(x0), linestyle="--", linewidth=1.2, label=label) - - ax.set_title(title) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - ax.set_ylim(0.0, 1.0) - - if markers and len(markers) > 0: - ax.legend(loc="best") - - return fig - - -def plot_waterfall_bridge( - start_label: str, - end_label: str, - start_value: float, - end_value: float, - components: list[tuple[str, float]], - title: str, - y_label: str, - x_label: str = "Component", - figsize: tuple[float, float] | None = None, -): - """Create a variance waterfall / bridge chart (start -> end via additive components). - - Guardrails - --------- - - Deterministic structure: explicit start and end totals plus additive components. - - Printer-safe encoding: hatch patterns distinguish positive vs negative deltas. - - Zero line included; y-limits padded to reduce truncation temptation. - - Notes - ----- - The caller is responsible for choosing defensible components. Any residual - can be included as an "Other / rounding" component to reconcile exactly. - """ - - _, plt = _mpl() - - if figsize is None: - w, h = STYLE_CONTRACT["figure_sizes"]["time_series"] - figsize = (float(w), float(h)) - - labels = [start_label] + [name for name, _ in components] + [end_label] - - # Running totals after each component (for connectors and y-range). - running = float(start_value) - totals = [running] - for _, delta in components: - running += float(delta) - totals.append(running) - totals.append(float(end_value)) - - fig, ax = plt.subplots(figsize=figsize) - - # Start total - ax.bar(0, float(start_value), edgecolor="black", linewidth=0.8) - - # Component deltas - running = float(start_value) - for i, (_, delta) in enumerate(components, start=1): - d = float(delta) - new_total = running + d - - if d >= 0: - bottom = running - height = d - hatch = "//" - else: - bottom = new_total - height = -d - hatch = "\\" - - ax.bar(i, height, bottom=bottom, hatch=hatch, edgecolor="black", linewidth=0.8) - running = new_total - - # End total - ax.bar(len(labels) - 1, float(end_value), edgecolor="black", linewidth=0.8) - - # Connectors between bars (running totals) - running = float(start_value) - for i, (_, delta) in enumerate(components, start=1): - ax.plot([i - 0.4, i + 0.4], [running, running], linewidth=1.0) - running += float(delta) - - ax.set_title(title) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - ax.set_xticks(list(range(len(labels)))) - ax.set_xticklabels(labels, rotation=0) - ax.axhline(0.0, linewidth=1.0) - - def _fmt(v: float) -> str: - return f"{v:,.0f}" - - # Annotate start/end totals - ax.text(0, float(start_value), _fmt(float(start_value)), ha="center", va="bottom") - ax.text(len(labels) - 1, float(end_value), _fmt(float(end_value)), ha="center", va="bottom") - - # Annotate component deltas - running = float(start_value) - for i, (_, delta) in enumerate(components, start=1): - d = float(delta) - y = (running + d) if d >= 0 else running - ax.text(i, y, f"{d:+,.0f}", ha="center", va="bottom") - running += d - - # Pad y-limits (anti-truncation guardrail) - lo = min([0.0] + totals) - hi = max([0.0] + totals) - span = hi - lo - pad = 0.10 * span if span > 0 else 1.0 - ax.set_ylim(lo - pad, hi + pad) - - return fig - - -def figure_manifest_to_frame(specs: list[FigureSpec]): - import pandas as pd - - return pd.DataFrame([asdict(s) for s in specs]) - - -def ensure_allowed_chart_type(chart_type: str) -> None: - allowed = set(STYLE_CONTRACT["allowed_chart_types"]) - if chart_type not in allowed: - raise ValueError(f"chart_type must be one of {sorted(allowed)}; got {chart_type!r}") +from pystatsv1.trackd.reporting_style import * # noqa: F401,F403