Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ tmd/storage/output/cached_files
tmd/storage/output/tax_expenditures
!tmd/storage/input/*.csv
!tmd/areas/targets/*.csv
!tmd/areas/prepare/data/soi_cds/*.csv
!tmd/areas/prepare/recipes/*.csv
!tmd/national_targets/data/*.csv
!tmd/national_targets/data/extracted/**/.gitkeep
**demographics_2015.csv
Expand Down
92 changes: 92 additions & 0 deletions tests/test_prepare_targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,95 @@ def test_prepare_mn_end_to_end(tmp_path):
rows = list(csv.DictReader(f))
assert len(rows) == result["MN"]
assert rows[0]["varname"] == "XTOT"


# ---- CD share and target tests ----


class TestCDShares:
    """Sanity checks on the pre-computed CD shares file."""

    @pytest.fixture(scope="class")
    def cd_shares(self):
        """Read the pre-computed CD shares CSV, skipping when absent."""
        shares_path = _PREPARE / "data" / "cds_shares.csv"
        if not shares_path.exists():
            pytest.skip("CD shares file not found (run prepare_shares first)")
        return pd.read_csv(shares_path)

    def test_no_duplicate_cd_shares(self, cd_shares):
        """Every (area, varname, count, fstatus, agistub) key is unique."""
        key_cols = ["area", "varname", "count", "fstatus", "agistub"]
        group_sizes = cd_shares.groupby(key_cols).size()
        dupes = group_sizes[group_sizes > 1]
        assert len(dupes) == 0, (
            f"Found {len(dupes)} duplicate CD share groups. "
            f"First few: {dupes.head(5).to_dict()}"
        )

    def test_cd_shares_sum_to_one(self, cd_shares):
        """Across the 436 CDs, non-XTOT shares total ~1.0 per var/bin."""
        non_xtot = cd_shares[cd_shares["varname"] != "XTOT"].copy()
        non_xtot = non_xtot.dropna(subset=["soi_share"])
        key_cols = ["varname", "count", "fstatus", "agistub"]
        totals = non_xtot.groupby(key_cols)["soi_share"].sum()
        # Only groups with a meaningful total are expected to hit 1.0.
        nonzero = totals[totals.abs() > 0.01]
        np.testing.assert_allclose(
            nonzero.values,
            1.0,
            atol=0.01,
            err_msg="436-CD shares should sum to ~1.0",
        )

    def test_cd_count_is_436(self, cd_shares):
        """The shares file must cover exactly 436 congressional districts."""
        n_areas = cd_shares["area"].nunique()
        assert n_areas == 436, f"Expected 436 CDs, got {n_areas}"


class TestCDTargetFiles:
    """Structural checks on CD target files (no solving performed)."""

    @pytest.fixture(scope="class")
    def cd_target_dir(self):
        """Directory holding per-CD target CSVs, skipping when absent."""
        target_dir = REPO_ROOT / "tmd" / "areas" / "targets" / "cds"
        has_files = target_dir.exists() and list(
            target_dir.glob("*_targets.csv")
        )
        if not has_files:
            pytest.skip(
                "CD target files not found"
                " (run prepare_targets --scope cds first)"
            )
        return target_dir

    def test_cd_target_count(self, cd_target_dir):
        """There must be one target file for each of the 436 CDs."""
        n_files = len(list(cd_target_dir.glob("*_targets.csv")))
        assert n_files == 436, (
            f"Expected 436 CD target files, got {n_files}"
        )

    def test_cd_target_structure(self, cd_target_dir):
        """Spot-check the AL01 target file for the expected layout."""
        al01_path = cd_target_dir / "al01_targets.csv"
        if not al01_path.exists():
            pytest.skip("al01_targets.csv not found")
        frame = pd.read_csv(al01_path, comment="#")
        expected_cols = {
            "varname",
            "count",
            "scope",
            "agilo",
            "agihi",
            "fstatus",
            "target",
        }
        assert set(frame.columns) == expected_cols
        # First row is the XTOT (total population) target by convention.
        assert frame.iloc[0]["varname"] == "XTOT"
        assert not frame["target"].isna().any()
        assert len(frame) == 107
104 changes: 52 additions & 52 deletions tmd/areas/fingerprints/states_fingerprint.json
Original file line number Diff line number Diff line change
@@ -1,57 +1,57 @@
{
"n_areas": 51,
"per_area_int_sums": {
"AK": 412466,
"AL": 2749392,
"AR": 1622704,
"AZ": 4130134,
"CA": 22167973,
"CO": 3442035,
"CT": 2131586,
"DC": 433116,
"DE": 585680,
"FL": 13026186,
"GA": 6008652,
"HI": 832970,
"IA": 1769383,
"ID": 1031419,
"IL": 7256718,
"IN": 3843732,
"KS": 1620038,
"KY": 2468402,
"LA": 2496679,
"MA": 4216198,
"MD": 3605225,
"ME": 833930,
"MI": 5795022,
"MN": 3279543,
"MO": 3447807,
"MS": 1568159,
"MT": 647039,
"NC": 5920637,
"ND": 435558,
"NE": 1084195,
"NH": 830201,
"NJ": 5351788,
"NM": 1217003,
"NV": 1843143,
"NY": 11692200,
"OH": 6776517,
"OK": 2146996,
"OR": 2459157,
"PA": 7525412,
"RI": 666732,
"SC": 2959393,
"SD": 506402,
"TN": 3891676,
"TX": 16241382,
"UT": 1769984,
"VA": 4920809,
"VT": 393735,
"WA": 4447067,
"WI": 3419741,
"WV": 989190,
"WY": 320883
"AK": 415791,
"AL": 2762087,
"AR": 1631739,
"AZ": 4132505,
"CA": 22174018,
"CO": 3430800,
"CT": 2124185,
"DC": 433577,
"DE": 586366,
"FL": 13044856,
"GA": 6023316,
"HI": 833430,
"IA": 1766300,
"ID": 1029793,
"IL": 7239136,
"IN": 3846951,
"KS": 1617612,
"KY": 2475428,
"LA": 2513256,
"MA": 4201046,
"MD": 3598516,
"ME": 830704,
"MI": 5793917,
"MN": 3264136,
"MO": 3447588,
"MS": 1580717,
"MT": 646366,
"NC": 5922856,
"ND": 433592,
"NE": 1082430,
"NH": 824615,
"NJ": 5323243,
"NM": 1219684,
"NV": 1845416,
"NY": 11690631,
"OH": 6780588,
"OK": 2155817,
"OR": 2453765,
"PA": 7519481,
"RI": 665125,
"SC": 2959988,
"SD": 506399,
"TN": 3902420,
"TX": 16304386,
"UT": 1759690,
"VA": 4908473,
"VT": 392672,
"WA": 4424726,
"WI": 3410721,
"WV": 990364,
"WY": 319797
},
"weight_hash": "8b36ae1c2ee0c384"
"weight_hash": "5ce2e8e78e5faa4d"
}
66 changes: 61 additions & 5 deletions tmd/areas/prepare/constants.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""
Shared constants for state target preparation.
Shared constants for area target preparation.

AGI cut points, SOI file patterns, variable mappings, and area
type definitions used across the preparation pipeline.

Note: CD (Congressional District) constants will be added in a
future PR.
type definitions used across the preparation pipeline for
states and congressional districts.
"""

from enum import Enum
Expand All @@ -19,6 +17,7 @@ class AreaType(Enum):
"""Type of sub-national area."""

STATE = "state"
CD = "cd"


# --- AGI range definitions ---
Expand All @@ -42,6 +41,24 @@ class AreaType(Enum):
# Number of non-total AGI stubs
STATE_NUM_AGI_STUBS = len(STATE_AGI_CUTS) - 1 # 10

# CD AGI stubs: 9 bins (IRS SOI Congressional District data)
# Same as state stubs 1-8, with stubs 9+10 merged into "$500K+"
# agistub 0 = total, agistubs 1-9 = bins
# The -inf/+inf endpoints bound the open-ended bottom and top bins.
CD_AGI_CUTS: List[float] = [
    -np.inf,
    1,
    10_000,
    25_000,
    50_000,
    75_000,
    100_000,
    200_000,
    500_000,
    np.inf,
]

# Number of non-total AGI stubs (bins between consecutive cut points).
CD_NUM_AGI_STUBS = len(CD_AGI_CUTS) - 1  # 9


def build_agi_labels(area_type: AreaType) -> pd.DataFrame:
"""
Expand All @@ -52,6 +69,8 @@ def build_agi_labels(area_type: AreaType) -> pd.DataFrame:
"""
if area_type == AreaType.STATE:
cuts = STATE_AGI_CUTS
elif area_type == AreaType.CD:
cuts = CD_AGI_CUTS
else:
raise ValueError(f"Unsupported area_type: {area_type}")

Expand Down Expand Up @@ -99,6 +118,24 @@ def build_agi_labels(area_type: AreaType) -> pd.DataFrame:
2022: "22in55cmcsv.csv",
}

# CD SOI CSV files by year (IRS SOI congressional-district extracts).
SOI_CD_CSV_PATTERNS: Dict[int, str] = {
    2021: "21incd.csv",
    2022: "22incd.csv",
}

# At-large states: single CD coded as CONG_DISTRICT=0 in SOI data
# NOTE(review): DC appears here although it is a district, not a
# state — presumably it is treated as at-large in the SOI files; confirm.
AT_LARGE_STATES: List[str] = [
    "AK",
    "DC",
    "DE",
    "MT",
    "ND",
    "SD",
    "VT",
    "WY",
]


# --- Variable classifications ---

Expand Down Expand Up @@ -194,6 +231,25 @@ def build_agi_labels(area_type: AreaType) -> pd.DataFrame:
# All valid 2-letter state codes (uppercase)
ALL_STATES: List[str] = sorted(STATE_INFO.keys())


def get_agi_cuts(area_type: AreaType) -> List[float]:
    """Return AGI cut points for the given area type."""
    cuts_by_type = {
        AreaType.STATE: STATE_AGI_CUTS,
        AreaType.CD: CD_AGI_CUTS,
    }
    cuts = cuts_by_type.get(area_type)
    if cuts is None:
        raise ValueError(f"Unsupported area_type: {area_type}")
    return cuts


def get_num_agi_stubs(area_type: AreaType) -> int:
    """Return the number of non-total AGI stubs for the given area type."""
    stub_counts = {
        AreaType.STATE: STATE_NUM_AGI_STUBS,
        AreaType.CD: CD_NUM_AGI_STUBS,
    }
    if area_type not in stub_counts:
        raise ValueError(f"Unsupported area_type: {area_type}")
    return stub_counts[area_type]


# Faux area prefixes used for testing
FAUX_AREA_PREFIXES: List[str] = [
"xx",
Expand Down
Loading
Loading