Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
221 changes: 221 additions & 0 deletions tests/test_prepare_targets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
"""
Tests for the state target preparation pipeline.

Three test levels:
1. Unit: SOI share computation and rescaling
2. Integration: recipe expansion and target file structure
3. End-to-end: full pipeline on a single state
"""

import csv
from pathlib import Path

import numpy as np
import pandas as pd
import pytest

from tmd.areas.prepare.constants import (
ALL_SHARING_MAPPINGS,
AreaType,
)
from tmd.areas.prepare.census_population import get_state_population
from tmd.areas.prepare.extended_targets import build_extended_targets
from tmd.areas.prepare.soi_state_data import (
create_soilong,
create_state_base_targets,
)
from tmd.areas.prepare.target_file_writer import write_area_target_files
from tmd.areas.prepare.target_sharing import (
compute_soi_shares,
prepare_area_targets,
)

# --- Paths ---

REPO_ROOT = Path(__file__).parent.parent
_PREPARE = REPO_ROOT / "tmd" / "areas" / "prepare"
SOI_RAW_DIR = _PREPARE / "data" / "soi_states"
CACHED_ALLVARS = (
REPO_ROOT / "tmd" / "storage" / "output" / "cached_allvars.csv"
)
RECIPE_PATH = _PREPARE / "recipes" / "states.json"
VARMAP_PATH = _PREPARE / "recipes" / "state_variable_mapping.csv"

_EXCLUDE = {"US", "OA", "PR"}


# ---- Unit tests: SOI share computation ----


class TestSOIShares:
    """Unit tests: SOI state shares are computed correctly."""

    @pytest.fixture(scope="class")
    def shares_data(self):
        """Compute 2022 SOI shares from the raw state files (once per class)."""
        long_df = create_soilong(SOI_RAW_DIR, years=[2022])
        population = get_state_population(2022)
        base = create_state_base_targets(long_df, population, 2022)
        return compute_soi_shares(base, ALL_SHARING_MAPPINGS)

    def test_shares_sum_to_one(self, shares_data):
        """Within each target cell, nonzero 51-state shares sum to 1."""
        fifty_one = shares_data.loc[~shares_data["stabbr"].isin(_EXCLUDE)]
        keys = [
            "basesoivname",
            "count",
            "scope",
            "fstatus",
            "agistub",
        ]
        cell_sums = fifty_one.groupby(keys)["soi_share"].sum()
        np.testing.assert_allclose(
            cell_sums[cell_sums > 0].values,
            1.0,
            atol=1e-10,
            err_msg="51-state shares should sum to 1.0",
        )

    def test_negative_shares_only_for_loss_variables(self, shares_data):
        """Only loss-capable variables may carry negative shares."""
        allowed = {"26270", "00900"}
        negative = shares_data.loc[shares_data["soi_share"] < 0]
        observed = set(negative["basesoivname"].unique())
        assert observed.issubset(
            allowed
        ), f"Unexpected negatives: {observed - allowed}"

    def test_mn_agi_share_reasonable(self, shares_data):
        """MN's share of total AGI falls in a plausible 1-3% band."""
        mask = (
            (shares_data["stabbr"] == "MN")
            & (shares_data["basesoivname"] == "00100")
            & (shares_data["count"] == 0)
            & (shares_data["agistub"] == 0)
        )
        selected = shares_data.loc[mask]
        assert len(selected) == 1
        assert 0.01 < selected["soi_share"].values[0] < 0.03

    def test_xtot_equals_us_population(self):
        """Summed 51-state XTOT targets reproduce the US Census total."""
        population = get_state_population(2022)
        us_total = population.loc[
            population["stabbr"] == "US", "population"
        ].values[0]
        long_df = create_soilong(SOI_RAW_DIR, years=[2022])
        base = create_state_base_targets(long_df, population, 2022)
        xtot = base.loc[base["basesoivname"] == "XTOT"]
        fifty_one = xtot.loc[~xtot["stabbr"].isin(_EXCLUDE)]
        # Exact equality is deliberate here — NOTE(review): this assumes
        # XTOT targets remain integer-valued upstream; confirm no float
        # allocation is introduced into create_state_base_targets.
        assert fifty_one["target"].sum() == us_total


# ---- Integration tests: target file structure ----


class TestTargetFileWriter:
    """Integration tests: recipe expansion yields well-formed target files."""

    @pytest.fixture(scope="class")
    def mn_targets(self, tmp_path_factory):
        """Run the full preparation pipeline for MN (once per class)."""
        enhanced = prepare_area_targets(
            area_type=AreaType.STATE,
            area_data_year=2022,
        )
        keep = (enhanced["area"] == "MN") & ~enhanced["area"].isin(_EXCLUDE)
        enhanced = enhanced.loc[keep]
        extra = build_extended_targets(
            cached_allvars_path=CACHED_ALLVARS,
            soi_year=2022,
            areas=["MN"],
        )
        out_dir = tmp_path_factory.mktemp("targets")
        result = write_area_target_files(
            recipe_path=RECIPE_PATH,
            enhanced_targets=enhanced,
            variable_mapping_path=VARMAP_PATH,
            output_dir=out_dir,
            extra_targets=extra,
        )
        return pd.read_csv(out_dir / "mn_targets.csv"), result

    def test_mn_target_count(self, mn_targets):
        """MN ends up with roughly 178 targets (base + extended)."""
        rows, _ = mn_targets
        assert 170 <= len(rows) <= 185

    def test_required_columns(self, mn_targets):
        """The target file carries exactly the expected column set."""
        rows, _ = mn_targets
        assert set(rows.columns) == {
            "varname",
            "count",
            "scope",
            "agilo",
            "agihi",
            "fstatus",
            "target",
        }

    def test_no_nan_targets(self, mn_targets):
        """Every target value is populated (no NaN)."""
        rows, _ = mn_targets
        assert rows["target"].notna().all()

    def test_xtot_is_first(self, mn_targets):
        """The population (XTOT) target leads the file."""
        rows, _ = mn_targets
        first = rows.iloc[0]
        assert first["varname"] == "XTOT"
        assert first["scope"] == 0

    def test_extended_variables_present(self, mn_targets):
        """All extended-target variables made it into the file."""
        rows, _ = mn_targets
        present = set(rows["varname"])
        expected = [
            "capgains_net",
            "e00600",
            "c19200",
            "e18400",
            "e18500",
            "eitc",
            "ctc_total",
        ]
        for var in expected:
            assert var in present, f"{var} missing"


# ---- End-to-end test ----


def test_prepare_mn_end_to_end(tmp_path):
    """Full pipeline for MN produces a valid, self-consistent target file.

    Checks that the writer reports a row count for MN, that the CSV
    exists, and that the file's contents agree with the reported count
    and lead with the XTOT (population) target.
    """
    enhanced = prepare_area_targets(
        area_type=AreaType.STATE,
        area_data_year=2022,
    )
    # Selecting area == "MN" already excludes the US/OA/PR aggregates
    # ("MN" is never in _EXCLUDE), so no extra exclusion filter is needed.
    enhanced = enhanced[enhanced["area"] == "MN"]
    extra = build_extended_targets(
        cached_allvars_path=CACHED_ALLVARS,
        soi_year=2022,
        areas=["MN"],
    )
    result = write_area_target_files(
        recipe_path=RECIPE_PATH,
        enhanced_targets=enhanced,
        variable_mapping_path=VARMAP_PATH,
        output_dir=tmp_path,
        extra_targets=extra,
    )
    assert "MN" in result
    assert result["MN"] > 100
    fpath = tmp_path / "mn_targets.csv"
    assert fpath.exists()
    # The reported count must match what was actually written to disk.
    with open(fpath, encoding="utf-8") as f:
        rows = list(csv.DictReader(f))
    assert len(rows) == result["MN"]
    assert rows[0]["varname"] == "XTOT"
68 changes: 63 additions & 5 deletions tmd/areas/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,64 @@
# areas
# Area Weighting

Contains code and data used to generate sub-national area weights
files that can be used with the national input data and national
growfactors files, and code to examine the quality of the area
weights.
Generates sub-national area weight files from national PUF-based
microdata. Area weights let Tax-Calculator produce state-level
estimates using national input data and growfactors.

## Preparing State Targets

```bash
# All 51 states (50 states + DC), ~4 seconds:
python -m tmd.areas.prepare_targets --scope states

# Specific states:
python -m tmd.areas.prepare_targets --scope MN,CA,TX

# Use 2021 SOI shares:
python -m tmd.areas.prepare_targets --scope states --year 2021
```

**Prerequisite**: TMD national data must exist (`make tmd_files`).

Output: one CSV per state in `tmd/areas/targets/states/`.

## How Targets Work

Each target constrains a weighted sum to match a state-level value.
Targets combine two data sources:

- **TMD national totals** provide the level (weighted sums from the
national PUF microdata).
- **IRS SOI state data** provides the geographic distribution (each
state's share of the US total, by AGI bin and filing status).

Formula: `state_target = TMD_national × (state_SOI / US_SOI)`

Extended targets also use **Census state/local finance data** for
SALT distribution and **SOI credit data** for EITC/CTC.

## Target Composition (~178 per state)

**Base targets** (recipe-driven, all 10 AGI bins):
AGI amounts, total return counts, return counts by filing status,
wages (amount + nonzero count), taxable interest, pensions, Social
Security, SALT deduction, partnership/S-corp income.

**Extended targets** (SOI/Census-shared, high-income bins only):
Taxable pensions, taxable Social Security, IRA distributions, net
capital gains, dividends, business income, mortgage interest,
charitable contributions, SALT by source (Census), EITC, CTC.

Filing-status count targets are excluded from the $1M+ AGI bin to
avoid excessive weight distortion on small cells.

## Pipeline Modules

| Module | Purpose |
|--------|---------|
| `prepare_targets.py` | CLI entry point |
| `prepare/soi_state_data.py` | SOI state CSV ingestion |
| `prepare/target_sharing.py` | TMD × SOI share computation |
| `prepare/target_file_writer.py` | Recipe expansion, CSV output |
| `prepare/extended_targets.py` | Census/SOI extended targets |
| `prepare/constants.py` | AGI bins, mappings, metadata |
| `prepare/census_population.py` | State population data |
10 changes: 10 additions & 0 deletions tmd/areas/prepare/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""
Area data preparation package.

Converts IRS SOI data into area-specific targets for states and
Congressional Districts, replacing the R/Quarto pipeline.
"""

from pathlib import Path

PREPARE_FOLDER = Path(__file__).parent
86 changes: 86 additions & 0 deletions tmd/areas/prepare/census_population.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""
Census state population data for area target preparation.

Provides state population estimates from the Census Bureau Population
Estimates Program (PEP), stored in a JSON file under ``data/``.

A user-supplied CSV can override the defaults.

Note: CD (Congressional District) population support will be added
in a future PR.
"""

import json
from pathlib import Path
from typing import Dict, Optional

import pandas as pd

_DATA_DIR = Path(__file__).parent / "data"


def _load_population_json(filename: str) -> Dict[str, Dict[str, int]]:
    """Return ``{year_str: {area: pop}}`` parsed from a population JSON file.

    Top-level keys beginning with an underscore are metadata and are
    dropped from the result.
    """
    with open(_DATA_DIR / filename, encoding="utf-8") as fh:
        raw = json.load(fh)
    return {
        year: areas for year, areas in raw.items() if not year.startswith("_")
    }


def get_state_population(
    year: int,
    csv_path: Optional[Path] = None,
) -> pd.DataFrame:
    """
    Return state population as a DataFrame with columns (stabbr, population).

    Parameters
    ----------
    year : int
        Calendar year for the population estimate.
    csv_path : Path, optional
        Path to a CSV with columns ``stabbr`` and a population column
        (named ``pop{year}`` or ``population``). If provided, this
        overrides the default data.

    Returns
    -------
    pd.DataFrame
        Columns: stabbr (str), population (int).
        Includes 50 states, DC, PR, and US.

    Raises
    ------
    ValueError
        If no bundled data exists for *year* and no ``csv_path`` is given.
    """
    # A user-supplied CSV takes precedence over the bundled JSON data.
    if csv_path is not None:
        return _read_population_csv(csv_path, year)
    by_year = _load_population_json("state_populations.json")
    key = str(year)
    if key not in by_year:
        raise ValueError(
            f"No state population data for {year}. "
            f"Available years: {sorted(by_year.keys())}. "
            f"Supply a csv_path to use custom data."
        )
    records = list(by_year[key].items())
    result = pd.DataFrame(records, columns=["stabbr", "population"])
    result["population"] = result["population"].astype(int)
    return result.sort_values("stabbr").reset_index(drop=True)


def _read_population_csv(csv_path: Path, year: int) -> pd.DataFrame:
"""Read state population CSV, normalise column names."""
df = pd.read_csv(csv_path)
pop_col = f"pop{year}"
if pop_col in df.columns:
df = df.rename(columns={pop_col: "population"})
elif "population" not in df.columns:
pop_cols = [c for c in df.columns if c.startswith("pop")]
if len(pop_cols) == 1:
df = df.rename(columns={pop_cols[0]: "population"})
else:
raise ValueError(
f"Cannot find population column in {csv_path}. "
f"Expected 'pop{year}' or 'population'."
)
df = df[["stabbr", "population"]].copy()
df["population"] = df["population"].astype(int)
return df.sort_values("stabbr").reset_index(drop=True)
Loading
Loading