Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
name: ci

on:
workflow_dispatch:
push:
Expand All @@ -7,24 +8,33 @@ on:
branches: [ main ]

jobs:
test:
runs-on: windows-latest # matches your local environment
windows-py310:
runs-on: windows-latest
timeout-minutes: 20
env:
PYTHONIOENCODING: utf-8
defaults:
run:
shell: pwsh

steps:
- uses: actions/checkout@v4
- name: Checkout
uses: actions/checkout@v4

- uses: actions/setup-python@v5
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
cache: "pip"
python-version: '3.10'
cache: pip

- name: Install make (Chocolatey)
run: choco install make -y

- name: Install deps
shell: bash
- name: Install dependencies
run: |
python -m pip install --upgrade pip
if [ -f requirements.txt ]; then pip install -r requirements.txt; else pip install numpy pandas statsmodels matplotlib scipy; fi
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
if (Test-Path requirements.txt) { pip install -r requirements.txt } else { pip install numpy pandas statsmodels matplotlib scipy }
if (Test-Path requirements-dev.txt) { pip install -r requirements-dev.txt } else { pip install ruff pytest }

- name: Lint
run: make lint
Expand All @@ -35,10 +45,12 @@ jobs:
- name: Tiny Chapter 13 smoke
run: make ch13-ci

- name: Upload plots (artifact)
- name: Upload artifacts (plots & data)
if: always()
uses: actions/upload-artifact@v4
with:
name: ch13-plots
path: outputs/
if-no-files-found: ignore
name: ch13-artifacts
if-no-files-found: ignore
path: |
data/synthetic/**
outputs/**
54 changes: 39 additions & 15 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,29 +1,53 @@
# Default target
.DEFAULT_GOAL := help

# Config
PYTHON := python
SEED ?= 123
OUT_SYN := data/synthetic
OUT_CH13 := outputs/ch13

.PHONY: help
help:
@echo "Available targets:"
@echo " ch13 - full Chapter 13 run (plots saved)"
@echo " ch13-ci - tiny smoke (fast) for CI"
@echo " ch13 - full Chapter 13 run (sim + analysis + plots)"
@echo " ch13-ci - tiny smoke run for CI (fast)"
@echo " lint - run ruff checks"
@echo " lint-fix - auto-fix with ruff"
@echo " test - run pytest"
@echo " clean - remove generated outputs"

# ---- Fast CI smoke (small n, deterministic) ----
.PHONY: ch13-ci
ch13-ci:
python scripts/sim_stroop.py --n-subjects 6 --n-trials 10 --seed 2025
python scripts/ch13_stroop_within.py --save-plots
python scripts/sim_fitness_2x2.py --n-per-group 40 --seed 2025
python scripts/ch13_fitness_mixed.py --save-plots
$(PYTHON) -m scripts.sim_stroop --n-subjects 6 --n-trials 10 --seed $(SEED) --outdir $(OUT_SYN)
$(PYTHON) -m scripts.ch13_stroop_within --data $(OUT_SYN)/psych_stroop_trials.csv --outdir $(OUT_CH13) --save-plots --seed $(SEED)
$(PYTHON) -m scripts.sim_fitness_2x2 --n-per-group 10 --seed $(SEED) --outdir $(OUT_SYN)
$(PYTHON) -m scripts.ch13_fitness_mixed --data $(OUT_SYN)/fitness_long.csv --outdir $(OUT_CH13) --save-plots --seed $(SEED)

# ---- Full Chapter 13 demo (default sizes) ----
.PHONY: ch13
ch13:
python scripts/sim_stroop.py
python scripts/ch13_stroop_within.py --save-plots
python scripts/sim_fitness_2x2.py
python scripts/ch13_fitness_mixed.py --save-plots
$(PYTHON) -m scripts.sim_stroop --seed $(SEED) --outdir $(OUT_SYN)
$(PYTHON) -m scripts.ch13_stroop_within --data $(OUT_SYN)/psych_stroop_trials.csv --outdir $(OUT_CH13) --save-plots --seed $(SEED)
$(PYTHON) -m scripts.sim_fitness_2x2 --seed $(SEED) --outdir $(OUT_SYN)
$(PYTHON) -m scripts.ch13_fitness_mixed --data $(OUT_SYN)/fitness_long.csv --outdir $(OUT_CH13) --save-plots --seed $(SEED)

.PHONY: lint format test
# ---- Quality gates ----
.PHONY: lint
lint:
ruff check tests/
format:
black .
ruff check .

.PHONY: lint-fix
lint-fix:
ruff check . --fix

.PHONY: test
test:
pytest -q || true
pytest -q

# ---- Utilities ----
.PHONY: clean
clean:
@echo "Removing generated outputs in $(OUT_SYN) and $(OUT_CH13)"
-@rm -rf $(OUT_SYN) $(OUT_CH13)
Empty file added scripts/__init__.py
Empty file.
31 changes: 31 additions & 0 deletions scripts/_cli.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

import argparse
import pathlib
import random

import numpy as np


def base_parser(description: str) -> argparse.ArgumentParser:
p = argparse.ArgumentParser(description=description)
p.add_argument(
"--outdir",
type=pathlib.Path,
default=pathlib.Path("outputs"),
help="Where to write outputs (plots, csv). Default: ./outputs",
)
p.add_argument("--seed", type=int, default=None, help="Random seed for reproducibility")
return p


def apply_seed(seed: int | None) -> None:
if seed is None:
return
np.random.seed(seed)
random.seed(seed)
try: # optional: seed torch if present
import torch # pragma: no cover
torch.manual_seed(seed)
except Exception:
pass
128 changes: 86 additions & 42 deletions scripts/ch13_fitness_mixed.py
Original file line number Diff line number Diff line change
@@ -1,86 +1,130 @@
# SPDX-License-Identifier: MIT
import argparse, os, math
import numpy as np, pandas as pd
import statsmodels.formula.api as smf
import matplotlib; matplotlib.use("Agg")
"""
Chapter 13 — Mixed fitness study analysis:
- LMM with random intercepts
- Within-group paired tests
- Between-group Welch t + Hedges g

Usage:
python -m scripts.ch13_fitness_mixed \
--data data/synthetic/fitness_long.csv \
--outdir outputs/ch13 --save-plots --seed 123
"""

import argparse
from pathlib import Path

import matplotlib
matplotlib.use("Agg") # safe for CI/headless
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import statsmodels.api as sm
from scipy import stats


# --- make Windows console UTF-8 friendly ---
import sys
if os.name == "nt":
if sys.platform.startswith("win"):
try:
sys.stdout.reconfigure(encoding="utf-8")
except Exception:
pass

def hedges_g(a, b):

def hedges_g(a: pd.Series, b: pd.Series) -> float:
na, nb = len(a), len(b)
sa, sb = a.var(ddof=1), b.var(ddof=1)
s_p = np.sqrt(((na-1)*sa + (nb-1)*sb)/(na+nb-2))
d = (a.mean()-b.mean())/s_p
J = 1 - 3/(4*(na+nb)-9)
return d*J

def main():
ap = argparse.ArgumentParser()
ap.add_argument("--data", default="data/synthetic/fitness_long.csv")
ap.add_argument("--save-plots", action="store_true")
s_p = np.sqrt(((na - 1) * sa + (nb - 1) * sb) / (na + nb - 2))
d = (a.mean() - b.mean()) / s_p
J = 1 - 3 / (4 * (na + nb) - 9)
return float(d * J)


def main() -> None:
ap = argparse.ArgumentParser(description="Analyze mixed-design fitness data")
ap.add_argument(
"--data",
type=Path,
default=Path("data/synthetic/fitness_long.csv"),
help="Path to long-format CSV.",
)
ap.add_argument("--save-plots", action="store_true", help="Save summary plot")
ap.add_argument(
"--outdir",
type=Path,
default=Path("outputs"),
help="Where to write plots (if --save-plots).",
)
ap.add_argument("--seed", type=int, default=None, help="Optional RNG seed")
args = ap.parse_args()

if args.seed is not None:
np.random.seed(args.seed) # kept for CLI consistency; not required by current code

df = pd.read_csv(args.data)
df["time"] = pd.Categorical(df["time"], categories=["pre","post"], ordered=True)
df["time"] = pd.Categorical(df["time"], categories=["pre", "post"], ordered=True)
df["group"] = pd.Categorical(df["group"])

# Mixed model with random intercept per subject
import statsmodels.api as sm
md = sm.MixedLM.from_formula("strength ~ time*group + age + sex + bmi",
groups="id", re_formula="1", data=df)
md = sm.MixedLM.from_formula(
"strength ~ time*group + age + sex + bmi",
groups="id",
re_formula="1",
data=df,
)
m = md.fit(method="lbfgs")

print("=== LMM: strength ~ time*group + age + sex + bmi + (1|id) ===")
print(m.summary())

# Within-group pre→post paired tests
from scipy import stats
print("\n=== Within-group pre→post (paired) ===")
for g in df["group"].cat.categories:
wide = (df[df.group==g]
.pivot_table(index="id", columns="time", values="strength", observed=False))
for gname in df["group"].cat.categories:
wide = (
df[df.group == gname]
.pivot_table(index="id", columns="time", values="strength", observed=False)
)
pre, post = wide["pre"], wide["post"]
diff = (post - pre).dropna()
mean = diff.mean(); sd = diff.std(ddof=1); n = diff.shape[0]
t = mean / (sd/np.sqrt(n))
p = 2*stats.t.sf(abs(t), df=n-1)
mean = diff.mean()
sd = diff.std(ddof=1)
n = diff.shape[0]
t = mean / (sd / np.sqrt(n))
p = 2 * stats.t.sf(abs(t), df=n - 1)
d_paired = mean / sd
print(f"{g:8s} n={n:2d} Δ={mean:6.2f} t={t:6.2f} p={p:.3g} d_paired={d_paired:.2f}")
print(f"{gname:8s} n={n:2d} Δ={mean:6.2f} t={t:6.2f} p={p:.3g} d_paired={d_paired:.2f}")

# Between-group at post (Welch t) + Hedges g
post = df[df.time=="post"]
post = df[df.time == "post"]
cats = post.group.cat.categories
gA = post[post.group==cats[0]]["strength"]
gB = post[post.group==cats[1]]["strength"]
gA = post[post.group == cats[0]]["strength"]
gB = post[post.group == cats[1]]["strength"]
t2, p2 = stats.ttest_ind(gA, gB, equal_var=False)
g = hedges_g(gA, gB)
print("\n=== Between groups at post (Welch t) ===")
print(f"t={t2:.2f} p={p2:.3g} Hedges g={g:.2f}")

if args.save_plots:
os.makedirs("outputs", exist_ok=True)
fig, ax = plt.subplots(figsize=(6,4))
# spaghetti
for (grp, sid), sub in df.sort_values("time").groupby(["group","id"], observed=False):
ax.plot([0,1], sub["strength"].values, alpha=0.2)
args.outdir.mkdir(parents=True, exist_ok=True)
fig, ax = plt.subplots(figsize=(6, 4))
# spaghetti per subject
for (_, sid), sub in df.sort_values("time").groupby(["group", "id"], observed=False):
ax.plot([0, 1], sub["strength"].values, alpha=0.2)
# thick means per group
mean_pre = df[df.time=="pre" ].groupby("group", observed=False)["strength"].mean()
mean_post = df[df.time=="post"].groupby("group", observed=False)["strength"].mean()
mean_pre = df[df.time == "pre"].groupby("group", observed=False)["strength"].mean()
mean_post = df[df.time == "post"].groupby("group", observed=False)["strength"].mean()
for gname in df["group"].cat.categories:
ax.plot([0,1], [mean_pre[gname], mean_post[gname]], linewidth=3, label=gname)
ax.set_xticks([0,1]); ax.set_xticklabels(["pre","post"])
ax.plot([0, 1], [mean_pre[gname], mean_post[gname]], linewidth=3, label=gname)
ax.set_xticks([0, 1])
ax.set_xticklabels(["pre", "post"])
ax.set_ylabel("Strength")
ax.legend()
fig.tight_layout()
fp = "outputs/ch13_fitness_spaghetti.png"; fig.savefig(fp, dpi=150)
print("Saved plot ->", fp)
out_path = args.outdir / "ch13_fitness_spaghetti.png"
fig.savefig(out_path, dpi=150)
print(f"Saved plot -> {out_path}")


if __name__ == "__main__":
main()
main()
Loading
Loading