Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ tmd_files: tmd/storage/output/tmd.csv.gz \

.PHONY=test
test: tmd_files
pytest . -v -n4 --ignore=tests/national_targets_pipeline
pytest . -v -n4 --ignore=tests/national_targets_pipeline --ignore=tests/test_fingerprint.py

.PHONY=data
data: install tmd_files test
Expand Down
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
np.seterr(all="raise")


def pytest_addoption(parser):
    """Register the --update-fingerprint flag used by the fingerprint tests."""
    flag_config = {
        "action": "store_true",
        "default": False,
        "help": "Save current results as reference fingerprint",
    }
    parser.addoption("--update-fingerprint", **flag_config)


def create_tmd_records(
data_path, weights_path, growfactors_path, exact_calculations=True
):
Expand Down
202 changes: 202 additions & 0 deletions tests/test_fingerprint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
"""
On-demand fingerprint test for area weight results.

NOT run by `make test` (excluded in Makefile).
Run manually after a full solve:

pytest tests/test_fingerprint.py -v
pytest tests/test_fingerprint.py -v --update-fingerprint

The first run saves a reference fingerprint.
Subsequent runs compare against it.

Fingerprint method: for each area, round weights to nearest integer,
sum them, and hash the per-area sums. This is simple, fast, and
catches any meaningful change in results.
"""

import hashlib
import json

import numpy as np
import pandas as pd
import pytest

from tmd.areas import AREAS_FOLDER

# Where reference fingerprints are written and where the per-state
# weight files produced by the solver live.
FINGERPRINT_DIR = AREAS_FOLDER / "fingerprints"
STATE_WEIGHT_DIR = AREAS_FOLDER / "weights" / "states"

# Two-letter codes for the 50 states plus DC (51 areas total).
ALL_STATES = (
    "AL AK AZ AR CA CO CT DC DE FL GA HI ID IL IN IA KS KY LA ME "
    "MD MA MI MN MS MO MT NE NV NH NJ NM NY NC ND OH OK OR PA RI "
    "SC SD TN TX UT VT VA WA WV WI WY"
).split()


def _compute_fingerprint(areas, weight_dir):
"""Compute fingerprint from weight files.

For each area, reads the first WT column, rounds weights to
nearest integer, and records the sum. The collection of integer
sums is hashed for a single comparison value.
"""
per_area = {}
for area in areas:
code = area.lower()
wpath = weight_dir / f"{code}_tmd_weights.csv.gz"
if not wpath.exists():
continue
wdf = pd.read_csv(wpath)
wt_cols = [c for c in wdf.columns if c.startswith("WT")]
wt = wdf[wt_cols[0]].values
int_sum = int(np.round(wt).sum())
per_area[area] = int_sum

# Hash of all per-area integer sums
hash_str = "|".join(f"{a}:{per_area[a]}" for a in sorted(per_area.keys()))
hash_val = hashlib.sha256(hash_str.encode()).hexdigest()[:16]

return {
"n_areas": len(per_area),
"weight_hash": hash_val,
"per_area_int_sums": per_area,
}


def _fingerprint_path(scope):
    """Return the JSON reference-fingerprint file path for *scope*."""
    return FINGERPRINT_DIR / (scope + "_fingerprint.json")


def _save_fingerprint(scope, fp):
    """Write fingerprint dict *fp* as pretty-printed JSON; return its path."""
    FINGERPRINT_DIR.mkdir(parents=True, exist_ok=True)
    path = _fingerprint_path(scope)
    path.write_text(
        json.dumps(fp, indent=2, sort_keys=True), encoding="utf-8"
    )
    return path


def _load_fingerprint(scope):
    """Load the saved fingerprint for *scope*, or None if none exists."""
    path = _fingerprint_path(scope)
    if path.exists():
        return json.loads(path.read_text(encoding="utf-8"))
    return None


@pytest.fixture
def update_mode(request):
    """Return True when pytest was invoked with --update-fingerprint.

    The flag itself is registered in tests/conftest.py via pytest_addoption.
    """
    return request.config.getoption("--update-fingerprint")


def _has_weight_files(weight_dir, areas):
for a in areas:
wpath = weight_dir / f"{a.lower()}_tmd_weights.csv.gz"
if wpath.exists():
return True
return False


@pytest.mark.skipif(
    not _has_weight_files(STATE_WEIGHT_DIR, ALL_STATES),
    reason="No state weight files — run solve_weights first",
)
class TestStateFingerprint:
    """Compare current state weight results against a saved fingerprint."""

    # pylint: disable=redefined-outer-name
    def test_state_weights_match_reference(self, update_mode):
        """Hash of per-area integer weight sums matches the reference."""
        fresh = _compute_fingerprint(ALL_STATES, STATE_WEIGHT_DIR)

        if update_mode:
            # Explicit update requested: overwrite the reference and stop.
            path = _save_fingerprint("states", fresh)
            pytest.skip(f"Saved to {path} — re-run to test")

        baseline = _load_fingerprint("states")
        if baseline is None:
            # First-ever run: persist a reference and ask for a re-run.
            path = _save_fingerprint("states", fresh)
            pytest.skip(f"No reference found. Saved to {path} — re-run")

        ref_n = baseline["n_areas"]
        cur_n = fresh["n_areas"]
        assert cur_n == ref_n, f"Area count: {ref_n} -> {cur_n}"

        assert (
            fresh["weight_hash"] == baseline["weight_hash"]
        ), "Weight hash mismatch — results changed"

    def test_per_area_sums_match(self, update_mode):
        """Report exactly which areas drifted from the reference sums."""
        if update_mode:
            pytest.skip("Update mode")

        baseline = _load_fingerprint("states")
        if baseline is None:
            pytest.skip("No reference fingerprint")

        ref_sums = baseline.get("per_area_int_sums", {})
        cur_sums = _compute_fingerprint(
            ALL_STATES, STATE_WEIGHT_DIR
        ).get("per_area_int_sums", {})

        mismatches = []
        for area in sorted(ref_sums):
            if area not in cur_sums:
                mismatches.append(f"{area}: missing")
            elif ref_sums[area] != cur_sums[area]:
                mismatches.append(
                    f"{area}: {ref_sums[area]} -> {cur_sums[area]}"
                )

        assert not mismatches, (
            f"{len(mismatches)} areas changed:\n" + "\n".join(mismatches)
        )
163 changes: 163 additions & 0 deletions tests/test_state_weight_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""
Post-solve validation of state weight files.

These tests verify that the actual state weight outputs are valid.
They are skipped if weight files have not been generated yet
(i.e., solve_weights --scope states has not been run).

Run after:
python -m tmd.areas.prepare_targets --scope states
python -m tmd.areas.solve_weights --scope states --workers 8
"""

import io
from pathlib import Path

import numpy as np
import pandas as pd
import pytest

from tmd.areas.create_area_weights import (
AREA_CONSTRAINT_TOL,
STATE_TARGET_DIR,
STATE_WEIGHT_DIR,
_build_constraint_matrix,
_drop_impossible_targets,
_load_taxcalc_data,
)
from tmd.areas.prepare.constants import ALL_STATES
from tmd.imputation_assumptions import TAXYEAR

# Skip entire module if weight files haven't been generated
_WEIGHT_FILES = list(STATE_WEIGHT_DIR.glob("*_tmd_weights.csv.gz"))
pytestmark = pytest.mark.skipif(
    len(_WEIGHT_FILES) < 51,  # expect one file per area: 50 states + DC
    reason="State weight files not generated yet",
)

# Also need cached data files for target accuracy checks
# (TestStateTargetAccuracy reads TMD microdata plus cached AGI values).
_CACHED = Path(__file__).parent.parent / "tmd" / "storage" / "output"
_HAS_CACHED = (_CACHED / "tmd.csv.gz").exists() and (
    _CACHED / "cached_c00100.npy"
).exists()


class TestStateWeightFiles:
    """Basic validity checks on all 51 state weight files."""

    def test_all_states_have_weight_files(self):
        """Every state has a weight file."""
        for code in ALL_STATES:
            path = STATE_WEIGHT_DIR / f"{code.lower()}_tmd_weights.csv.gz"
            assert path.exists(), f"Missing weight file for {code}"

    def test_all_states_have_log_files(self):
        """Every state has a solver log."""
        for code in ALL_STATES:
            path = STATE_WEIGHT_DIR / f"{code.lower()}.log"
            assert path.exists(), f"Missing log file for {code}"

    def test_weight_columns(self):
        """Weight files have expected year columns."""
        frame = pd.read_csv(STATE_WEIGHT_DIR / "mn_tmd_weights.csv.gz")
        # One WT<year> column per projection year, TAXYEAR through 2034.
        assert list(frame.columns) == [
            f"WT{yr}" for yr in range(TAXYEAR, 2035)
        ]

    def test_weight_row_count(self):
        """Weight files have one row per TMD record."""
        frame = pd.read_csv(STATE_WEIGHT_DIR / "mn_tmd_weights.csv.gz")
        # Should match TMD record count (215,494 for 2022)
        assert len(frame) > 200_000

    @pytest.mark.parametrize(
        "state",
        [s.lower() for s in ALL_STATES],
    )
    def test_weights_nonnegative(self, state):
        """All weights are non-negative."""
        frame = pd.read_csv(
            STATE_WEIGHT_DIR / f"{state}_tmd_weights.csv.gz"
        )
        assert (frame >= 0).all().all(), f"{state}: negative weights found"

    @pytest.mark.parametrize(
        "state",
        [s.lower() for s in ALL_STATES],
    )
    def test_weights_no_nan(self, state):
        """No NaN or inf values in weights."""
        frame = pd.read_csv(
            STATE_WEIGHT_DIR / f"{state}_tmd_weights.csv.gz"
        )
        assert not frame.isna().any().any(), f"{state}: NaN values found"
        assert np.isfinite(frame.values).all(), f"{state}: inf values found"

    @pytest.mark.parametrize(
        "state",
        [s.lower() for s in ALL_STATES],
    )
    def test_solver_status_solved(self, state):
        """Solver log reports Solved status."""
        log_text = (STATE_WEIGHT_DIR / f"{state}.log").read_text()
        assert (
            "Solver status: Solved" in log_text
        ), f"{state}: solver did not report Solved"


@pytest.mark.skipif(
    not _HAS_CACHED,
    reason="Cached TMD data files not available",
)
class TestStateTargetAccuracy:
    """Verify weighted sums hit targets within tolerance.

    Re-builds each state's constraint matrix exactly as the solver does,
    then checks that the solved weight file reproduces the targets to
    within AREA_CONSTRAINT_TOL (plus a small floating-point margin).
    """

    @pytest.fixture(scope="class")
    def vdf(self):
        """Load TMD data once for all accuracy tests."""
        # NOTE(review): assumes _load_taxcalc_data returns a DataFrame
        # with an "s006" national-weight column — confirm in
        # tmd.areas.create_area_weights.
        return _load_taxcalc_data()

    @pytest.mark.parametrize(
        "state",
        ["al", "ca", "mn", "ny", "tx"],  # spot-check a spread of states
    )
    def test_targets_hit(self, vdf, state):
        """Weighted sums match targets within constraint tolerance."""
        # The helpers write progress text; capture it in an in-memory sink.
        out = io.StringIO()
        B_csc, targets, labels, pop_share = _build_constraint_matrix(
            state,
            vdf,
            out,
            target_dir=STATE_TARGET_DIR,
        )
        # Drop targets the solver itself would have discarded, so the
        # comparison set matches what was actually solved for.
        B_csc, targets, labels = _drop_impossible_targets(
            B_csc,
            targets,
            labels,
            out,
        )

        # Load weights and compute multipliers
        wpath = STATE_WEIGHT_DIR / f"{state}_tmd_weights.csv.gz"
        wdf = pd.read_csv(wpath)
        area_weights = wdf[f"WT{TAXYEAR}"].values
        # w0 = population-share-scaled national weights; x is the ratio of
        # solved area weights to w0 (the solver's decision variable).
        w0 = pop_share * vdf["s006"].values
        # Avoid division by zero for zero-weight records
        safe_w0 = np.where(w0 > 0, w0, 1.0)
        x = area_weights / safe_w0
        x = np.where(w0 > 0, x, 0.0)

        # Check target accuracy
        achieved = np.asarray(B_csc @ x).ravel()
        # Relative error with a floor of 1.0 in the denominator so
        # near-zero targets don't blow up the ratio.
        rel_errors = np.abs(achieved - targets) / np.maximum(
            np.abs(targets), 1.0
        )
        # Allow small margin above solver tolerance for floating-point
        # differences between solver internals and weight-file roundtrip
        eps = 1e-4
        n_violated = int((rel_errors > AREA_CONSTRAINT_TOL + eps).sum())
        max_err = rel_errors.max()
        assert n_violated == 0, (
            f"{state}: {n_violated} targets violated, "
            f"max error = {max_err * 100:.3f}%"
        )
Loading
Loading