diff --git a/tests/expected_area_wgt_2021_data.yaml b/tests/expected_area_wgt_2021_data.yaml deleted file mode 100644 index 6a1edbc1..00000000 --- a/tests/expected_area_wgt_2021_data.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# variable statistics that are targeted in the area weights optimization: -popall: 33 # # million -e00300: 20 # $ billion -e00900: 30 # $ billion -e00200: 1000 # $ billion -e02000: 30 # $ billion -e02400: 60 # $ billion -c00100: 1200 # $ billion -agihic: 10 # # thousand -e00400: 2 # $ billion -e00600: 8 # $ billion -e00650: 7 # $ billion -e01700: 12 # $ billion -e02300: 10 # $ billion -e17500: 5 # $ billion -e18400: 10 # $ billion -e18500: 10 # $ billion -# variable statistics that are not targeted in the area weights optimization: -# ... none available in the faux xx area ... diff --git a/tests/expected_area_wgt_2022_data.yaml b/tests/expected_area_wgt_2022_data.yaml deleted file mode 100644 index 6a1edbc1..00000000 --- a/tests/expected_area_wgt_2022_data.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# variable statistics that are targeted in the area weights optimization: -popall: 33 # # million -e00300: 20 # $ billion -e00900: 30 # $ billion -e00200: 1000 # $ billion -e02000: 30 # $ billion -e02400: 60 # $ billion -c00100: 1200 # $ billion -agihic: 10 # # thousand -e00400: 2 # $ billion -e00600: 8 # $ billion -e00650: 7 # $ billion -e01700: 12 # $ billion -e02300: 10 # $ billion -e17500: 5 # $ billion -e18400: 10 # $ billion -e18500: 10 # $ billion -# variable statistics that are not targeted in the area weights optimization: -# ... none available in the faux xx area ... diff --git a/tests/test_area_weights.py b/tests/test_area_weights.py deleted file mode 100644 index 58cba41d..00000000 --- a/tests/test_area_weights.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -Tests of tmd/areas/create_area_weights.py script. 
-""" - -import yaml -import taxcalc as tc -from tmd.storage import STORAGE_FOLDER -from tmd.imputation_assumptions import TAXYEAR, CREDIT_CLAIMING -from tmd.areas import AREAS_FOLDER -from tmd.areas.create_area_weights import create_area_weights_file -from tests.conftest import create_tmd_records - -YEAR = TAXYEAR - - -def test_area_xx(tests_folder): - """ - Optimize national weights for faux xx area using the faux xx area targets - and compare actual Tax-Calculator results with expected results when - using area weights along with national input data and growfactors. - """ - rc = create_area_weights_file("xx", write_log=False, write_file=True) - assert rc == 0, "create_areas_weights_file has non-zero return code" - # compare actual vs expected results for faux area xx - # ... instantiate Tax-Calculator object for area - pol = tc.Policy() - pol.implement_reform(CREDIT_CLAIMING) - rec = create_tmd_records( - data_path=STORAGE_FOLDER / "output" / "tmd.csv.gz", - weights_path=AREAS_FOLDER / "weights" / "xx_tmd_weights.csv.gz", - growfactors_path=STORAGE_FOLDER / "output" / "tmd_growfactors.csv", - ) - sim = tc.Calculator(policy=pol, records=rec) - # ... calculate tax variables for YEAR - sim.advance_to_year(YEAR) - sim.calc_all() - vdf = sim.dataframe([], all_vars=True) - # ... 
calculate actual results and store in act dictionary - wght = vdf.s006 * (vdf.data_source == 1) # PUF weights - act = { - "popall": (vdf.s006 * vdf.XTOT).sum() * 1e-6, - "e00300": (wght * vdf.e00300).sum() * 1e-9, - "e00900": (wght * vdf.e00900).sum() * 1e-9, - "e00200": (wght * vdf.e00200).sum() * 1e-9, - "e02000": (wght * vdf.e02000).sum() * 1e-9, - "e02400": (wght * vdf.e02400).sum() * 1e-9, - "c00100": (wght * vdf.c00100).sum() * 1e-9, - "agihic": (wght * (vdf.c00100 >= 1e6)).sum() * 1e-3, - "e00400": (wght * vdf.e00400).sum() * 1e-9, - "e00600": (wght * vdf.e00600).sum() * 1e-9, - "e00650": (wght * vdf.e00650).sum() * 1e-9, - "e01700": (wght * vdf.e01700).sum() * 1e-9, - "e02300": (wght * vdf.e02300).sum() * 1e-9, - "e17500": (wght * vdf.e17500).sum() * 1e-9, - "e18400": (wght * vdf.e18400).sum() * 1e-9, - "e18500": (wght * vdf.e18500).sum() * 1e-9, - } - # ... read expected results into exp dictionary - exp_path = tests_folder / f"expected_area_wgt_{TAXYEAR}_data.yaml" - with open(exp_path, "r", encoding="utf-8") as efile: - exp = yaml.safe_load(efile.read()) - # compare actual with expected results - default_rtol = 0.005 - rtol = { - "c00100": 0.008, - "e00200": 0.008, - } - if set(act.keys()) != set(exp.keys()): - print("sorted(act.keys())=", sorted(act.keys())) - print("sorted(exp.keys())=", sorted(exp.keys())) - raise ValueError("act.keys() != exp.keys()") - emsg = "" - for res in exp.keys(): - reldiff = act[res] / exp[res] - 1 - reltol = rtol.get(res, default_rtol) - ok = abs(reldiff) < reltol - if not ok: - emsg += ( - f"FAIL:res,act,exp,rdiff,rtol= {res} {act[res]:.5f}" - f" {exp[res]:.5f} {reldiff:.4f} {reltol:.4f}\n" - ) - if emsg: - print(emsg) - raise ValueError("ACT vs EXP diffs in test_area_weights") diff --git a/tests/test_solve_weights.py b/tests/test_solve_weights.py new file mode 100644 index 00000000..6ec1ca27 --- /dev/null +++ b/tests/test_solve_weights.py @@ -0,0 +1,255 @@ +""" +Tests for the state weight solver pipeline. 
+ +Tests: + - Solver on faux xx area + - Quality report log parser + - CLI scope parsing + - Batch area filtering +""" + +import io +import tempfile +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from tmd.areas import AREAS_FOLDER +from tmd.areas.batch_weights import _filter_areas +from tmd.areas.create_area_weights import ( + AREA_CONSTRAINT_TOL, + _build_constraint_matrix, + _drop_impossible_targets, + _load_taxcalc_data, + _solve_area_qp, + create_area_weights_file, +) +from tmd.areas.quality_report import ( + _humanize_desc, + parse_log, +) +from tmd.areas.solve_weights import ( + _parse_scope, +) +from tmd.imputation_assumptions import TAXYEAR + +# --- Solver test on faux xx area --- + + +def test_solver_xx(): + """ + Solve faux xx area and verify targets are hit. + """ + # xx targets are in the flat targets/ directory + target_dir = AREAS_FOLDER / "targets" + + with tempfile.TemporaryDirectory() as tmpdir: + weight_dir = Path(tmpdir) + rc = create_area_weights_file( + "xx", + write_log=True, + write_file=True, + target_dir=target_dir, + weight_dir=weight_dir, + ) + assert rc == 0, "create_area_weights_file returned non-zero" + + # Verify weights file was created + wpath = weight_dir / "xx_tmd_weights.csv.gz" + assert wpath.exists(), "Weight file not created" + + # Verify weights file has expected columns + wdf = pd.read_csv(wpath) + assert f"WT{TAXYEAR}" in wdf.columns + assert f"WT{TAXYEAR + 1}" in wdf.columns + + # Verify all weights are non-negative + assert (wdf[f"WT{TAXYEAR}"] >= 0).all(), "Negative weights found" + + # Verify log file was created + logpath = weight_dir / "xx.log" + assert logpath.exists(), "Log file not created" + + # Verify log contains solver status + log_text = logpath.read_text() + assert "Solver status:" in log_text + assert "TARGET ACCURACY" in log_text + + +def test_solver_xx_targets_hit(): + """ + Verify the solver actually hits the xx area targets + within the constraint tolerance. 
+ """ + target_dir = AREAS_FOLDER / "targets" + out = io.StringIO() + + vdf = _load_taxcalc_data() + B_csc, targets, labels, _pop_share = _build_constraint_matrix( + "xx", vdf, out, target_dir=target_dir + ) + B_csc, targets, labels = _drop_impossible_targets( + B_csc, targets, labels, out + ) + + n_records = B_csc.shape[1] + x_opt, _s_lo, _s_hi, info = _solve_area_qp( + B_csc, + targets, + labels, + n_records, + out=out, + ) + + # Check that solver succeeded + status = info["status"] + assert "Solved" in status, f"Solver status: {status}" + + # Check all targets hit within tolerance + achieved = np.asarray(B_csc @ x_opt).ravel() + rel_errors = np.abs(achieved - targets) / np.maximum(np.abs(targets), 1.0) + eps = 1e-9 + n_violated = int((rel_errors > AREA_CONSTRAINT_TOL + eps).sum()) + assert n_violated == 0, ( + f"{n_violated} targets violated; " + f"max error = {rel_errors.max() * 100:.3f}%" + ) + + +# --- Quality report log parser --- + + +_V = "c00100/cnt=1/scope=0" +_V1 = ( # noqa: E501 + f" 0.500% | target= 12345 | achieved= 12407" + f" | {_V}/agi=[-9e+99,1.0)/fs=0" +) +_V2 = ( + f" 0.489% | target= 567 | achieved= 570" + f" | {_V}/agi=[1e+06,9e+99)/fs=1" +) +_V3 = ( + f" 0.478% | target= 1234 | achieved= 1240" + f" | {_V}/agi=[1e+06,9e+99)/fs=2" +) + + +def test_parse_log_solved(): + """Test log parser with a synthetic solved log.""" + _hit = " targets hit: 175/178 (tolerance: +/-0.5% + eps)" + _mult = ( + " min=0.000000, p5=0.450000," + + " median=0.980000," + + " p95=1.650000, max=45.000000" + ) + _d1 = " [ 0.0000, 0.0000): 2662 ( 1.24%)" + _d2 = " [ 0.0000, 0.1000): 5000 ( 2.33%)" + log_content = "\n".join( + [ + "Solver status: Solved", + "Iterations: 42", + "Solve time: 3.14s", + "TARGET ACCURACY (178 targets):", + " mean |relative error|: 0.001234", + " max |relative error|: 0.004567", + _hit, + " VIOLATED: 3 targets", + _V1, + _V2, + _V3, + "MULTIPLIER DISTRIBUTION:", + _mult, + " RMSE from 1.0: 0.550000", + " distribution (n=214377):", + _d1, + _d2, 
+ "ALL CONSTRAINTS SATISFIED WITHOUT SLACK", + "", + ] + ) + with tempfile.NamedTemporaryFile( + mode="w", suffix=".log", delete=False + ) as f: + f.write(log_content) + f.flush() + result = parse_log(Path(f.name)) + + assert result["status"] == "Solved" + assert result["solve_time"] == pytest.approx(3.14) + assert result["mean_err"] == pytest.approx(0.001234) + assert result["max_err"] == pytest.approx(0.004567) + assert result["targets_hit"] == 175 + assert result["targets_total"] == 178 + assert result["n_violated"] == 3 + assert result["w_min"] == pytest.approx(0.0) + assert result["w_median"] == pytest.approx(0.98) + assert result["w_rmse"] == pytest.approx(0.55) + assert result["n_records"] == 214377 + assert len(result["violated_details"]) == 3 + + +def test_parse_log_missing(): + """Test log parser with non-existent file.""" + result = parse_log(Path("/nonexistent/path.log")) + assert result["status"] == "NO LOG" + + +# --- Scope parsing --- + + +def test_solve_weights_parse_scope_states(): + """Test scope parsing for 'states'.""" + assert _parse_scope("states") is None + assert _parse_scope("all") is None + + +def test_solve_weights_parse_scope_specific(): + """Test scope parsing for specific states.""" + result = _parse_scope("MN,CA,TX") + assert result == ["MN", "CA", "TX"] + + +def test_solve_weights_parse_scope_excludes(): + """Test scope parsing excludes PR, US, OA.""" + result = _parse_scope("MN,PR,US,OA,CA") + assert result == ["MN", "CA"] + + +# --- Batch area filtering --- + + +def test_filter_areas_states(): + """Test batch filter for states.""" + areas = ["al", "ca", "mn", "mn01", "xx"] + assert _filter_areas(areas, "states") == [ + "al", + "ca", + "mn", + ] + + +def test_filter_areas_cds(): + """Test batch filter for CDs.""" + areas = ["al", "ca", "mn01", "mn02"] + assert _filter_areas(areas, "cds") == ["mn01", "mn02"] + + +def test_filter_areas_specific(): + """Test batch filter for specific areas.""" + areas = ["al", "ca", "mn", "tx"] + 
assert _filter_areas(areas, "mn,tx") == ["mn", "tx"] + + +# --- Humanize description --- + + +def test_humanize_desc(): + """Test human-readable label generation.""" + desc = "c00100/cnt=1/scope=1" + "/agi=[500000.0,1000000.0)/fs=4" + result = _humanize_desc(desc) + assert "c00100" in result + assert "returns" in result + assert "HoH" in result + assert "$500K" in result diff --git a/tmd/areas/AREA_WEIGHTING_LESSONS.md b/tmd/areas/AREA_WEIGHTING_LESSONS.md new file mode 100644 index 00000000..e0d891dc --- /dev/null +++ b/tmd/areas/AREA_WEIGHTING_LESSONS.md @@ -0,0 +1,213 @@ +# Area Weighting: Lessons Learned + +Practical guidance from developing the state weight optimization +pipeline. Intended for future maintainers and anyone extending +this to new geographies (e.g., Congressional districts). + +## The optimization problem + +For each sub-national area, we find weight multipliers `x_i` for +every record such that weighted sums match area targets within a +tolerance band, while keeping multipliers close to 1.0 +(population-proportional). + +``` +minimize sum((x_i - 1)^2) [stay close to proportional] +subject to target*(1-tol) <= Bx <= target*(1+tol) [hit targets] + 0 <= x_i <= multiplier_max [bounds] +``` + +Solved independently per area using Clarabel (constrained QP with +elastic slack for infeasibility). + +## Key parameters and what they do + +| Parameter | Default | Effect | +|-----------|---------|--------| +| `AREA_CONSTRAINT_TOL` | 0.005 (0.5%) | Target tolerance band. Matches national reweighting. | +| `AREA_MULTIPLIER_MAX` | 25.0 | Per-record upper bound on weight multiplier. Most important lever for controlling exhaustion. | +| `AREA_SLACK_PENALTY` | 1e6 | Penalty on constraint slack. Very high = hard constraints. | +| `weight_penalty` | 1.0 | Penalty on `(x-1)^2`. Higher values keep multipliers closer to 1.0. | + +## Weight exhaustion + +**Definition**: For each record, exhaustion = (sum of all area +weights) / (national weight). 
A value of 1.0 means the record's +national weight is fully allocated across areas. Values above 1.0 +mean the record is "oversubscribed" — used more in total than its +national weight warrants. + +**Why it matters**: High exhaustion means a small number of records +are doing heavy lifting across many states, creating fragile +solutions where one record's characteristics drive multiple states' +estimates. + +**What drives it**: Rare high-income PUF records with small national +weights (s006 ~ 10-50) get pulled by many states to hit their +high-AGI-bin targets. These are typically wealthy MFJ households +with large investment income (interest, dividends, capital gains, +partnership income). + +### Exhaustion statistics (2022 SOI, 178 targets/state) + +| Percentile | Exhaustion | +|------------|-----------| +| Median | 1.007 | +| p99 | 2.0 | +| p99.9 | 4.4 | +| Max | 25.2 (at mult_max=100) | +| Max | 16.6 (at mult_max=25) | + +Only ~150 records exceed 5x. The problem is concentrated in +the extreme tail. + +## Parameter sweep results (2023 tax year, 51 states) + +Tested `multiplier_max` x `weight_penalty` grid. Key finding: +**`weight_penalty` has no effect** on exhaustion, wRMSE, or %zero +— it only increases target violations. The solver reaches the +same weight structure regardless; higher penalty just makes it +fail to meet more targets. + +**`multiplier_max` is the only effective lever:** + +| mult_max | Violations | MaxViol% | wRMSE | %zero | MaxExh | >10x | +|----------|-----------|---------|-------|-------|--------|------| +| 100 | 33 | 0.50% | 0.594 | 7.3% | 25.2 | 16 | +| **25** | **35** | **0.50%** | **0.609** | **7.8%** | **16.6** | **15** | +| 15 | 210 | 100% | 0.601 | 9.9% | 11.9 | 3 | +| 10 | 402 | 100% | 0.625 | 12.2% | 8.8 | 0 | + +**`mult_max=25` is the sweet spot**: only 2 extra violations +vs baseline, max violation still at the 0.50% tolerance boundary, +but max exhaustion drops 34% (25.2 to 16.6). 
Below 25, targets +become infeasible (100% violations appear). + +## Two-pass exhaustion limiting (not recommended for production) + +We tested an iterative approach: solve unconstrained, compute +exhaustion, set per-record caps, re-solve. With `max_exhaustion=5`: + +- Pass 1: 33 violations (normal) +- Pass 2: **8,979 violations**, max **100%** — catastrophic + +The proportional cap scaling was too aggressive. Records at 25x +got scaled to 20% of their multiplier, making targets infeasible. +The approach is fragile and hard to tune. + +**Recommendation**: Use `multiplier_max` (single pass) rather +than iterative exhaustion capping. Simpler, more robust, and +no slower. + +## Dual variable analysis (constraint cost) + +Clarabel's dual variables reveal which constraints are expensive +to satisfy — a target can be perfectly hit but still cause massive +weight distortion. This is invisible from violations alone. + +**Finding**: The $1M+ AGI bin filing-status count targets (single, +MFJ, HoH) are essentially the **only expensive constraints**. Their +dual costs are 6-8 orders of magnitude larger than all other targets. +Extended targets (capital gains, pensions, charitable, etc.) have +near-zero dual cost — they're "free" to add. + +**Action taken**: Excluded filing-status counts from the $1M+ bin. +Kept total return count as an anchor. This eliminated virtually all +constraint cost while maintaining full target coverage. + +## SALT targeting + +Use **actual SALT** (`c18300`, the post-$10K-cap amount) rather than +uncapped potential SALT (`e18400`/`e18500`). This is apples-to-apples: +Tax-Calculator's actual SALT under current law, shared using observed +SOI actual SALT by state. + +| Metric | e18400/e18500 | c18300 | +|--------|--------------|--------| +| Violations | 85 | 75 | +| wRMSE avg | 0.479 | 0.289 | +| SALT aggregation error | -6.23% | -0.18% | + +## Bystander variables (untargeted collateral distortion) + +Area weighting optimizes ~178 targets per state. 
Variables not +in the target set can be distorted as a side effect — "innocent +bystanders" pulled by weight adjustments aimed at other variables. + +The quality report includes a bystander check showing cross-state +aggregation error (sum-of-states vs national) for untargeted +variables. Variables with >2% distortion are flagged. + +### Observed bystanders (2023, mult_max=25) + +| Variable | Diff% | Why | +|----------|-------|-----| +| Student loan int (e19200) | -10.5% | Concentrated on middle-income records that get under-weighted | +| AMT (c09600) | -10.3% | Very small base ($1.1B); volatile under any reweighting | +| Tax-exempt int (e00400) | +7.4% | On same wealthy records that get over-weighted for AGI targets | +| Qualified dividends (e00650) | -4.1% | Total dividends (e00600) targeted but not qualified portion | +| Unemployment comp (e02300) | -4.0% | Low-income variable squeezed when high-income records upweighted | +| Medical expenses (e17500) | -2.9% | Itemized deduction bystander | +| Total credits (c07100) | -2.2% | EITC/CTC targeted but not total credits | + +### Well-behaved bystanders (<1%) + +Income tax, payroll tax, standard deduction, itemized total, +Sch C, IRA distributions, taxable pensions, Sch E, total persons, +children under 17. These are either targeted directly or highly +correlated with targeted variables. + +### When to worry + +Bystander distortion matters when the variable is **policy-relevant +at the state level** and **not closely correlated** with any targeted +variable. Student loan interest at -10.5% would matter for a state- +level student debt analysis; AMT at -10.3% on a $1.1B base is less +consequential. + +If a bystander becomes important for a specific analysis, the fix is +to add it (or a correlated variable) as a target. 
The dual variable +analysis showed that extended targets are essentially "free" — near- +zero constraint cost — so adding targets is low-risk as long as they +don't conflict with existing targets in small AGI bins. + +## OA (Other Areas) share rescaling + +SOI's "Other Areas" category (~0.5% of returns, covers territories +and overseas filers) is excluded from state targeting. Raw SOI +shares (state/US) are rescaled so that the 51 state shares sum +to 1.0 for each variable-AGI-bin combination. + +## Guidance for Congressional Districts + +CDs have not been implemented yet. Expected differences from states: + +- **435 areas** (vs 51) — grid search infeasible; use state-derived + parameter settings +- **9 AGI bins** (vs 10) — no $1M+ separate bin in SOI CD data +- **Smaller populations** — more CDs will hit multiplier ceilings; + may need different `multiplier_max` +- **Crosswalk complexity** — SOI uses 117th Congress boundaries for + both 2021 and 2022 data; need geocorr crosswalk to map to 118th + Congress boundaries +- **Exhaustion will be worse** — 435 areas competing for the same + records means much higher potential exhaustion; `multiplier_max` + may need to be lower + +## Running the parameter sweep + +The sweep utility tests combinations of solver parameters on all +states and reports target accuracy, weight distortion, and exhaustion: + +```bash +python -m tmd.areas.sweep_params +``` + +Edit `GRID_MULTIPLIER_MAX` and `GRID_WEIGHT_PENALTY` in +`tmd/areas/sweep_params.py` to change the search grid. Each +combination solves all 51 states (~3-4 minutes with 8 workers). + +**Should be rerun when**: +- The tax year changes (different TMD data, different SOI shares) +- The target recipe changes materially +- Extending to a new area type (CDs) diff --git a/tmd/areas/README.md b/tmd/areas/README.md index 3b49816c..7a489a27 100644 --- a/tmd/areas/README.md +++ b/tmd/areas/README.md @@ -51,14 +51,59 @@ charitable contributions, SALT by source (Census), EITC, CTC. 
Filing-status count targets are excluded from the $1M+ AGI bin to avoid excessive weight distortion on small cells. +## Solving for State Weights + +```bash +# All 51 states, 8 parallel workers: +python -m tmd.areas.solve_weights --scope states --workers 8 + +# Specific states: +python -m tmd.areas.solve_weights --scope MN,CA,TX --workers 4 +``` + +Uses the Clarabel constrained QP solver to find per-record weight +multipliers that hit each state's targets within 0.5% tolerance. + +Output: weight files in `tmd/areas/weights/states/` and solver +logs alongside them. + +## Quality Report + +```bash +python -m tmd.areas.quality_report +python -m tmd.areas.quality_report --scope CA,NY +``` + +Cross-state summary: solve status, target accuracy, weight +distortion, weight exhaustion, and national aggregation checks. + ## Pipeline Modules +**Target preparation** (PR 1): + | Module | Purpose | |--------|---------| | `prepare_targets.py` | CLI entry point | | `prepare/soi_state_data.py` | SOI state CSV ingestion | -| `prepare/target_sharing.py` | TMD × SOI share computation | +| `prepare/target_sharing.py` | TMD x SOI share computation | | `prepare/target_file_writer.py` | Recipe expansion, CSV output | | `prepare/extended_targets.py` | Census/SOI extended targets | | `prepare/constants.py` | AGI bins, mappings, metadata | | `prepare/census_population.py` | State population data | + +**Weight solving** (PR 2): + +| Module | Purpose | +|--------|---------| +| `solve_weights.py` | CLI entry point | +| `create_area_weights_clarabel.py` | Clarabel QP solver | +| `batch_weights.py` | Parallel batch runner | +| `quality_report.py` | Cross-state diagnostics | +| `sweep_params.py` | Parameter grid search utility | + +## Lessons Learned + +See [AREA_WEIGHTING_LESSONS.md](AREA_WEIGHTING_LESSONS.md) for +practical guidance on parameter tuning, weight exhaustion, SALT +targeting, dual variable analysis, and recommendations for +extending to Congressional districts. 
diff --git a/tmd/areas/batch_weights.py b/tmd/areas/batch_weights.py new file mode 100644 index 00000000..9a5c1f45 --- /dev/null +++ b/tmd/areas/batch_weights.py @@ -0,0 +1,371 @@ +# pylint: disable=import-outside-toplevel,global-statement +""" +Batch area weight optimization — parallel processing for all areas. + +TMD data loaded once per worker (not once per area). +Progress reporting with ETA. +Uses concurrent.futures for clean parallel execution. + +Usage: + python -m tmd.areas.batch_weights --scope states --workers 8 + python -m tmd.areas.batch_weights --scope MN,CA,TX --workers 4 +""" + +import argparse +import io +import re +import sys +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path + +import numpy as np +import pandas as pd +import yaml + +# Module-level cache for TMD data (one per worker process) +_WORKER_VDF = None +_WORKER_POP = None +_WORKER_TARGET_DIR = None +_WORKER_WEIGHT_DIR = None + + +def _init_worker(target_dir=None, weight_dir=None): + """Load TMD data once per worker process.""" + global _WORKER_VDF, _WORKER_POP + global _WORKER_TARGET_DIR, _WORKER_WEIGHT_DIR + if target_dir is not None: + _WORKER_TARGET_DIR = Path(target_dir) + if weight_dir is not None: + _WORKER_WEIGHT_DIR = Path(weight_dir) + if _WORKER_VDF is not None: + return + from tmd.areas.create_area_weights import ( + POPFILE_PATH, + _load_taxcalc_data, + ) + + _WORKER_VDF = _load_taxcalc_data() + with open(POPFILE_PATH, "r", encoding="utf-8") as pf: + _WORKER_POP = yaml.safe_load(pf.read()) + + +def _solve_one_area(area): + """ + Solve area weights for one area using cached worker data. + + Returns (area, elapsed, n_targets, n_violated, status, max_viol_pct). 
+ """ + _init_worker() + from tmd.areas.create_area_weights import ( + AREA_CONSTRAINT_TOL, + AREA_MAX_ITER, + AREA_MULTIPLIER_MAX, + AREA_MULTIPLIER_MIN, + AREA_SLACK_PENALTY, + FIRST_YEAR, + LAST_YEAR, + _build_constraint_matrix, + _drop_impossible_targets, + _print_multiplier_diagnostics, + _print_slack_diagnostics, + _print_target_diagnostics, + _read_params, + _solve_area_qp, + ) + + t0 = time.time() + out = io.StringIO() + + # Build and solve + vdf = _WORKER_VDF + tgt_dir = _WORKER_TARGET_DIR + wgt_dir = _WORKER_WEIGHT_DIR + params = _read_params(area, out, target_dir=tgt_dir) + constraint_tol = params.get( + "constraint_tol", + params.get("target_ratio_tolerance", AREA_CONSTRAINT_TOL), + ) + slack_penalty = params.get("slack_penalty", AREA_SLACK_PENALTY) + max_iter = params.get("max_iter", AREA_MAX_ITER) + multiplier_min = params.get("multiplier_min", AREA_MULTIPLIER_MIN) + multiplier_max = params.get("multiplier_max", AREA_MULTIPLIER_MAX) + + B_csc, targets, labels, pop_share = _build_constraint_matrix( + area, vdf, out, target_dir=tgt_dir + ) + B_csc, targets, labels = _drop_impossible_targets( + B_csc, targets, labels, out + ) + + n_records = B_csc.shape[1] + + # Check for per-record multiplier caps (from exhaustion limiting) + caps_path = wgt_dir / f"{area}_record_caps.npy" + if caps_path.exists(): + record_caps = np.load(caps_path) + multiplier_max = np.minimum(multiplier_max, record_caps) + n_capped = int((record_caps < AREA_MULTIPLIER_MAX).sum()) + out.write( + f"USING PER-RECORD MULTIPLIER CAPS" + f" ({n_capped} records capped)\n" + ) + + x_opt, s_lo, s_hi, info = _solve_area_qp( + B_csc, + targets, + labels, + n_records, + constraint_tol=constraint_tol, + slack_penalty=slack_penalty, + max_iter=max_iter, + multiplier_min=multiplier_min, + multiplier_max=multiplier_max, + out=out, + ) + + # Diagnostics + n_violated = _print_target_diagnostics( + x_opt, B_csc, targets, labels, constraint_tol, out + ) + _print_multiplier_diagnostics(x_opt, out) + 
_print_slack_diagnostics(s_lo, s_hi, targets, labels, out) + + # Compute max violation percentage for summary + achieved = np.asarray(B_csc @ x_opt).ravel() + rel_errors = np.abs(achieved - targets) / np.maximum(np.abs(targets), 1.0) + eps = 1e-9 + viol_mask = rel_errors > constraint_tol + eps + max_viol_pct = ( + float(rel_errors[viol_mask].max() * 100) if viol_mask.any() else 0.0 + ) + + # Write log + logpath = wgt_dir / f"{area}.log" + logpath.parent.mkdir(parents=True, exist_ok=True) + with open(logpath, "w", encoding="utf-8") as f: + f.write(out.getvalue()) + + # Write weights file + w0 = pop_share * vdf.s006.values + wght_area = x_opt * w0 + + wdict = {f"WT{FIRST_YEAR}": wght_area} + cum_pop_growth = 1.0 + pop = _WORKER_POP + for year in range(FIRST_YEAR + 1, LAST_YEAR + 1): + annual_pop_growth = pop[year] / pop[year - 1] + cum_pop_growth *= annual_pop_growth + wdict[f"WT{year}"] = wght_area * cum_pop_growth + + wdf = pd.DataFrame.from_dict(wdict) + awpath = wgt_dir / f"{area}_tmd_weights.csv.gz" + wdf.to_csv( + awpath, + index=False, + float_format="%.5f", + compression="gzip", + ) + + elapsed = time.time() - t0 + return ( + area, + elapsed, + len(targets), + n_violated, + info["status"], + max_viol_pct, + ) + + +def _list_target_areas(target_dir=None): + """Return sorted list of area codes with target files.""" + from tmd.areas.create_area_weights import STATE_TARGET_DIR + + tfolder = target_dir or STATE_TARGET_DIR + tpaths = sorted(tfolder.glob("*_targets.csv")) + areas = [tpath.name.split("_")[0] for tpath in tpaths] + return areas + + +def _filter_areas(areas, area_filter): + """Filter areas by type: 'states', 'cds', or 'all'.""" + if area_filter == "all": + return areas + if area_filter == "states": + return [a for a in areas if len(a) == 2 and not re.match(r"[x-z]", a)] + if area_filter == "cds": + return [a for a in areas if len(a) > 2] + # Treat as comma-separated list + requested = [a.strip() for a in area_filter.split(",")] + return [a for a in 
areas if a in requested] + + +def run_batch( + num_workers=1, + area_filter="all", + force=False, + target_dir=None, + weight_dir=None, +): + """ + Run area weight optimization for multiple areas in parallel. + + Parameters + ---------- + num_workers : int + Number of parallel worker processes. + area_filter : str + 'states', 'cds', 'all', or comma-separated area codes. + force : bool + If True, recompute all areas even if up-to-date. + target_dir : Path, optional + Directory containing target CSVs. + weight_dir : Path, optional + Directory for weight output. + """ + from tmd.areas.create_area_weights import ( + STATE_TARGET_DIR, + STATE_WEIGHT_DIR, + ) + + if target_dir is None: + target_dir = STATE_TARGET_DIR + if weight_dir is None: + weight_dir = STATE_WEIGHT_DIR + all_areas = _list_target_areas(target_dir=target_dir) + areas = _filter_areas(all_areas, area_filter) + + if not areas: + print("No areas to process.") + return + + # Filter to out-of-date areas unless force=True + if not force: + from tmd.areas.make_all import time_of_newest_other_dependency + + newest_dep = time_of_newest_other_dependency() + todo = [] + for area in areas: + wpath = weight_dir / f"{area}_tmd_weights.csv.gz" + tpath = target_dir / f"{area}_targets.csv" + if wpath.exists(): + wtime = wpath.stat().st_mtime + ttime = tpath.stat().st_mtime + if wtime > max(newest_dep, ttime): + continue + todo.append(area) + areas = todo + + if not areas: + print("All areas up-to-date. Use --force to recompute.") + return + + # Count targets from first area's CSV + first_tpath = target_dir / f"{areas[0]}_targets.csv" + n_targets = ( + sum( + 1 + for line in first_tpath.read_text().splitlines() + if line.strip() and not line.startswith("#") + ) + - 1 + ) # subtract header + n = len(areas) + print( + f"Processing {n} areas" + f" (up to {n_targets} targets each)" + f" with {num_workers} workers..." 
+ ) + print( + "(Areas shown in completion order, which varies with" + " parallel workers.)" + ) + + # Ensure weights directory exists + weight_dir.mkdir(parents=True, exist_ok=True) + + t_start = time.time() + completed = 0 + violated_areas = [] + worst_viol_pct = 0.0 + max_id_width = max(len(a) for a in areas) + + with ProcessPoolExecutor( + max_workers=num_workers, + initializer=_init_worker, + initargs=(str(target_dir), str(weight_dir)), + ) as executor: + futures = { + executor.submit(_solve_one_area, area): area for area in areas + } + for future in as_completed(futures): + area = futures[future] + try: + area_code, _, _, n_viol, _, mv_pct = future.result() + completed += 1 + + if n_viol > 0: + violated_areas.append((area_code, n_viol)) + worst_viol_pct = max(worst_viol_pct, mv_pct) + + # Start new line every 10 areas with count prefix + if (completed - 1) % 10 == 0: + sys.stdout.write(f"\n{completed:4d} ") + sys.stdout.write(f" {area_code.ljust(max_id_width)}") + sys.stdout.flush() + + # After every 10th area (or the last), print elapsed + if completed % 10 == 0 or completed == n: + elapsed_total = time.time() - t_start + sys.stdout.write(f" [{elapsed_total:.0f}s elapsed]") + sys.stdout.flush() + + except Exception: + completed += 1 + sys.stdout.write(f" {area}:FAIL") + sys.stdout.flush() + sys.stdout.write("\n") + + total = time.time() - t_start + print(f"\nCompleted {completed}/{n} areas in {total:.1f}s") + if violated_areas: + total_viol = sum(v for _, v in violated_areas) + print( + f"{len(violated_areas)} areas had violated targets" + f" ({total_viol} targets total)." + f" Largest violation: {worst_viol_pct:.2f}%." 
+ ) + cmd = "python -m tmd.areas.quality_report" + print(f"Run: {cmd} for full details.") + else: + print("All targets hit within tolerance.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Batch area weight optimization" + ) + parser.add_argument( + "--scope", + type=str, + default="states", + help="'states', 'cds', 'all', or comma-separated codes", + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="Number of parallel workers (default: 1)", + ) + parser.add_argument( + "--force", + action="store_true", + help="Recompute all areas even if up-to-date", + ) + args = parser.parse_args() + run_batch( + num_workers=args.workers, + area_filter=args.scope, + force=args.force, + ) diff --git a/tmd/areas/create_area_weights.py b/tmd/areas/create_area_weights.py index eb046bc6..f50cb35e 100644 --- a/tmd/areas/create_area_weights.py +++ b/tmd/areas/create_area_weights.py @@ -1,705 +1,606 @@ """ -Construct AREA_tmd_weights.csv.gz, a Tax-Calculator-style weights file -for FIRST_YEAR through LAST_YEAR for the specified sub-national AREA. - -AREA prefix for state areas are the two lower-case character postal codes. -AREA prefix for congressional districts are the state prefix followed by -two digits (with a leading zero) identifying the district. States with -only one congressional district have 00 as the two digits to align names -with IRS data. +Clarabel-based constrained QP area weight optimization. + +Finds weight multipliers (x) for each record such that area-specific +weighted sums match area targets within a tolerance band, while +minimizing deviation from population-proportional weights. 
+ +Formulation: + minimize sum((x_i - 1)^2) + subject to t_j*(1-eps) <= (Bx)_j <= t_j*(1+eps) [target bounds] + x_min <= x_i <= x_max [multiplier bounds] + +where: + x_i = area_weight_i / (pop_share * national_weight_i) + pop_share = area_population / national_population + B[j,i] = (pop_share * national_weight_i) * A[j,i] + + x_i = 1 means record i gets its population-proportional share. + The optimizer adjusts x_i to hit area-specific targets. + +Elastic slack variables handle infeasibility gracefully: + minimize sum((x_i - 1)^2) + M * sum(s^2) + subject to lb_j <= (Bx)_j + s_lo - s_hi <= ub_j, s >= 0 + +Follows the same QP construction as tmd/utils/reweight.py. """ import sys -import re import time -import yaml + +import clarabel import numpy as np import pandas as pd -from scipy.sparse import csr_matrix -from scipy.optimize import minimize, Bounds -import jax -import jax.numpy as jnp -from jax.experimental.sparse import BCOO -from tmd.storage import STORAGE_FOLDER -from tmd.imputation_assumptions import TAXYEAR, POPULATION_FILE +import yaml +from scipy.sparse import ( + csc_matrix, + diags as spdiags, + eye as speye, + hstack, + vstack, +) + from tmd.areas import AREAS_FOLDER +from tmd.imputation_assumptions import POPULATION_FILE, TAXYEAR +from tmd.storage import STORAGE_FOLDER FIRST_YEAR = TAXYEAR LAST_YEAR = 2034 INFILE_PATH = STORAGE_FOLDER / "output" / "tmd.csv.gz" POPFILE_PATH = STORAGE_FOLDER / "input" / POPULATION_FILE - -# Tax-Calcultor calculated variable cache files: TAXCALC_AGI_CACHE = STORAGE_FOLDER / "output" / "cached_c00100.npy" - -PARAMS = {} - -# default target parameters: -TARGET_RATIO_TOLERANCE = 0.0040 # what is considered hitting the target -DUMP_ALL_TARGET_DEVIATIONS = False # set to True only for diagnostic work - -# default regularization parameters: -DELTA_INIT_VALUE = 1.0e-9 -DELTA_MAX_LOOPS = 1 - -# default optimization parameters: -OPTIMIZE_FTOL = 1e-9 -OPTIMIZE_GTOL = 1e-9 -OPTIMIZE_MAXITER = 5000 -OPTIMIZE_IPRINT = 0 # 20 is a good 
diagnostic value; set to 0 for no output -OPTIMIZE_RESULTS = False # set to True to see complete optimization results - - -def valid_area(area: str): - """ - Check validity of area string returning a boolean value. - """ - # Census on which Congressional districts are based: - # : cd_census_year = 2010 implies districts are for the 117th Congress - # : cd_census_year = 2020 implies districts are for the 118th Congress - cd_census_year = 2020 - # data in the state_info dictionary is taken from the following document: - # 2020 Census Apportionment Results, April 26, 2021, - # Table C1. Number of Seats in - # U.S. House of Representatives by State: 1910 to 2020 - # https://www.census.gov/data/tables/2020/dec/2020-apportionment-data.html - state_info = { - # number of Congressional districts per state indexed by cd_census_year - "AL": {2020: 7, 2010: 7}, - "AK": {2020: 1, 2010: 1}, - "AZ": {2020: 9, 2010: 9}, - "AR": {2020: 4, 2010: 4}, - "CA": {2020: 52, 2010: 53}, - "CO": {2020: 8, 2010: 7}, - "CT": {2020: 5, 2010: 5}, - "DE": {2020: 1, 2010: 1}, - "DC": {2020: 0, 2010: 0}, - "FL": {2020: 28, 2010: 27}, - "GA": {2020: 14, 2010: 14}, - "HI": {2020: 2, 2010: 2}, - "ID": {2020: 2, 2010: 2}, - "IL": {2020: 17, 2010: 18}, - "IN": {2020: 9, 2010: 9}, - "IA": {2020: 4, 2010: 4}, - "KS": {2020: 4, 2010: 4}, - "KY": {2020: 6, 2010: 6}, - "LA": {2020: 6, 2010: 6}, - "ME": {2020: 2, 2010: 2}, - "MD": {2020: 8, 2010: 8}, - "MA": {2020: 9, 2010: 9}, - "MI": {2020: 13, 2010: 14}, - "MN": {2020: 8, 2010: 8}, - "MS": {2020: 4, 2010: 4}, - "MO": {2020: 8, 2010: 8}, - "MT": {2020: 2, 2010: 1}, - "NE": {2020: 3, 2010: 3}, - "NV": {2020: 4, 2010: 4}, - "NH": {2020: 2, 2010: 2}, - "NJ": {2020: 12, 2010: 12}, - "NM": {2020: 3, 2010: 3}, - "NY": {2020: 26, 2010: 27}, - "NC": {2020: 14, 2010: 13}, - "ND": {2020: 1, 2010: 1}, - "OH": {2020: 15, 2010: 16}, - "OK": {2020: 5, 2010: 5}, - "OR": {2020: 6, 2010: 5}, - "PA": {2020: 17, 2010: 18}, - "RI": {2020: 2, 2010: 2}, - "SC": {2020: 7, 2010: 
7}, - "SD": {2020: 1, 2010: 1}, - "TN": {2020: 9, 2010: 9}, - "TX": {2020: 38, 2010: 36}, - "UT": {2020: 4, 2010: 4}, - "VT": {2020: 1, 2010: 1}, - "VA": {2020: 11, 2010: 11}, - "WA": {2020: 10, 2010: 10}, - "WV": {2020: 2, 2010: 3}, - "WI": {2020: 8, 2010: 8}, - "WY": {2020: 1, 2010: 1}, - } - # check state_info validity - assert len(state_info) == 50 + 1 - total = {2010: 0, 2020: 0} - for _, seats in state_info.items(): - total[2010] += seats[2010] - total[2020] += seats[2020] - assert total[2010] == 435 - assert total[2020] == 435 - compare_new_vs_old = False - if compare_new_vs_old: - text = "state,2010cds,2020cds" - for state, seats in state_info.items(): - if seats[2020] != seats[2010]: - print(f"{text}= {state} {seats[2010]:2d} {seats[2020]:2d}") - sys.exit(1) - # conduct series of validity checks on specified area string - # ... check that specified area string has expected length - len_area_str = len(area) - if not 2 <= len_area_str <= 5: - sys.stderr.write(f": area '{area}' length is not in [2,5] range\n") - return False - # ... check first two characters of area string - s_c = area[0:2] - if not re.match(r"[a-z][a-z]", s_c): - emsg = "begin with two lower-case letters" - sys.stderr.write(f": area '{area}' must {emsg}\n") - return False - is_faux_area = re.match(r"[x-z][a-z]", s_c) is not None - if not is_faux_area: - if s_c.upper() not in state_info: - sys.stderr.write(f": state '{s_c}' is unknown\n") - return False - # ... check state area assumption letter which is optional - if len_area_str == 3: - assump_char = area[2:3] - if not re.match(r"[A-Z]", assump_char): - emsg = "assumption character that is not an upper-case letter" - sys.stderr.write(f": area '{area}' has {emsg}\n") - return False - if len_area_str <= 3: - return True - # ... 
check Congressional district area string - if not re.match(r"\d\d", area[2:4]): - emsg = "have two numbers after the state code" - sys.stderr.write(f": area '{area}' must {emsg}\n") - return False - if is_faux_area: - max_cdnum = 99 - else: - max_cdnum = state_info[s_c.upper()][cd_census_year] - cdnum = int(area[2:4]) - if max_cdnum <= 1: - if cdnum != 0: - sys.stderr.write( - f": use area '{s_c}00' for this one-district state\n" - ) - return False - else: # if max_cdnum >= 2 - if cdnum > max_cdnum: - sys.stderr.write(f": cd number '{cdnum}' exceeds {max_cdnum}\n") - return False - # ... check district area assumption character which is optional - if len_area_str == 5: - assump_char = area[4:5] - if not re.match(r"[A-Z]", assump_char): - emsg = "assumption character that is not an upper-case letter" - sys.stderr.write(f": area '{area}' has {emsg}\n") - return False - return True - - -def all_taxcalc_variables(): - """ - Return all read and needed calc Tax-Calculator variables in pd.DataFrame. 
- """ +CACHED_ALLVARS_PATH = STORAGE_FOLDER / "output" / "cached_allvars.csv" + +# Tax-Calculator output variables to load from cached_allvars for targeting +CACHED_TC_OUTPUTS = [ + "c18300", + "c04470", + "c02500", + "c19200", + "c19700", + "eitc", + "ctc_total", +] + +# Default solver parameters +AREA_CONSTRAINT_TOL = 0.005 +AREA_SLACK_PENALTY = 1e6 +AREA_MAX_ITER = 2000 +AREA_MULTIPLIER_MIN = 0.0 +AREA_MULTIPLIER_MAX = 25.0 + +# Default target/weight directories for states +STATE_TARGET_DIR = AREAS_FOLDER / "targets" / "states" +STATE_WEIGHT_DIR = AREAS_FOLDER / "weights" / "states" + + +def _load_taxcalc_data(): + """Load TMD data with cached AGI and selected Tax-Calculator outputs.""" vdf = pd.read_csv(INFILE_PATH) vdf["c00100"] = np.load(TAXCALC_AGI_CACHE) + if CACHED_ALLVARS_PATH.exists(): + allvars = pd.read_csv(CACHED_ALLVARS_PATH, usecols=CACHED_TC_OUTPUTS) + for col in CACHED_TC_OUTPUTS: + if col in allvars.columns: + vdf[col] = allvars[col].values + # Synthetic combined variable for net capital gains targeting + if "p22250" in vdf.columns and "p23250" in vdf.columns: + vdf["capgains_net"] = vdf["p22250"] + vdf["p23250"] assert np.all(vdf.s006 > 0), "Not all weights are positive" return vdf -def prepared_data(area: str, vardf: pd.DataFrame): +def _read_params(area, out, target_dir=None): + """Read optional area-specific parameters YAML file.""" + if target_dir is None: + target_dir = STATE_TARGET_DIR + params = {} + pfile = f"{area}_params.yaml" + params_path = target_dir / pfile + if params_path.exists(): + with open(params_path, "r", encoding="utf-8") as f: + params = yaml.safe_load(f.read()) + exp_params = [ + "target_ratio_tolerance", + "dump_all_target_deviations", + "delta_init_value", + "delta_max_loops", + "iprint", + "constraint_tol", + "slack_penalty", + "max_iter", + "multiplier_min", + "multiplier_max", + ] + act_params = list(params.keys()) + all_ok = len(set(act_params)) == len(act_params) + for param in act_params: + if param not in 
exp_params: + all_ok = False + out.write( + f"WARNING: {pfile} parameter" f" {param} is not expected\n" + ) + if not all_ok: + out.write(f"IGNORING CONTENTS OF {pfile}\n") + params = {} + elif params: + out.write(f"USING CUSTOMIZED PARAMETERS IN {pfile}\n") + return params + + +def _build_constraint_matrix(area, vardf, out, target_dir=None): """ - Construct numpy 2-D target matrix and 1-D target array for - specified area using specified vardf. Also, compute initial - weights scaling factor for specified area. Return all three - as a tuple. + Build constraint matrix B and target array from area targets CSV. + + Returns (B_csc, targets, target_labels, pop_share) where: + - B_csc is a sparse matrix (n_targets x n_records) + - targets is 1-D array of target values + - target_labels is a list of descriptive strings + - pop_share is the area's population share of national population """ + if target_dir is None: + target_dir = STATE_TARGET_DIR national_population = (vardf.s006 * vardf.XTOT).sum() numobs = len(vardf) - targets_file = AREAS_FOLDER / "targets" / f"{area}_targets.csv" + targets_file = target_dir / f"{area}_targets.csv" tdf = pd.read_csv(targets_file, comment="#") - tm_tuple = () - ta_list = [] - row_num = 1 - initial_weights_scale = None - for row in tdf.itertuples(index=False): - row_num += 1 - line = f"{area}:L{row_num}" - # construct target amount for this row - unscaled_target = row.target - if unscaled_target == 0: - unscaled_target = 1.0 - scale = 1.0 / unscaled_target - scaled_target = unscaled_target * scale - ta_list.append(scaled_target) - # confirm that row_num 2 contains the area population target - if row_num == 2: - bool_list = [ - row.varname == "XTOT", - row.count == 0, - row.scope == 0, - row.agilo < -8e99, - row.agihi > 8e99, - row.fstatus == 0, - ] - assert all( - bool_list - ), f"{line} does not contain the area population target" - initial_weights_scale = row.target / national_population + + targets_list = [] + labels_list = [] + columns 
= [] + pop_share = None + + for row_idx, row in enumerate(tdf.itertuples(index=False)): + line = f"{area}:target{row_idx + 1}" + + # extract target value + target_val = row.target + targets_list.append(target_val) + + # build label + label = ( + f"{row.varname}" + f"/cnt={row.count}" + f"/scope={row.scope}" + f"/agi=[{row.agilo},{row.agihi})" + f"/fs={row.fstatus}" + ) + labels_list.append(label) + + # first row must be XTOT population target + if row_idx == 0: + assert row.varname == "XTOT", f"{line}: first target must be XTOT" + assert row.count == 0 and row.scope == 0 + assert row.agilo < -8e99 and row.agihi > 8e99 + assert row.fstatus == 0 + pop_share = row.target / national_population + out.write( + f"pop_share = {row.target:.0f}" + f" / {national_population:.0f}" + f" = {pop_share:.6f}\n" + ) + # construct variable array for this target - assert ( - row.count >= 0 and row.count <= 4 - ), f"count value {row.count} not in [0,4] range on {line}" - if row.count == 0: # tabulate $ variable amount - unmasked_varray = vardf[row.varname].astype(float) - elif row.count == 1: # count units with any variable amount - unmasked_varray = (vardf[row.varname] > -np.inf).astype(float) - elif row.count == 2: # count only units with non-zero variable amount - unmasked_varray = (vardf[row.varname] != 0).astype(float) - elif row.count == 3: # count only units with positive variable amount - unmasked_varray = (vardf[row.varname] > 0).astype(float) - else: # count only units with negative variable amount (row.count==4) - unmasked_varray = (vardf[row.varname] < 0).astype(float) - mask = np.ones(numobs, dtype=int) - assert ( - row.scope >= 0 and row.scope <= 2 - ), f"scope value {row.scope} not in [0,2] range on {line}" + assert 0 <= row.count <= 4, f"count {row.count} not in [0,4] on {line}" + if row.count == 0: + var_array = vardf[row.varname].astype(float).values + elif row.count == 1: + var_array = np.ones(numobs, dtype=float) + elif row.count == 2: + var_array = 
(vardf[row.varname] != 0).astype(float).values + elif row.count == 3: + var_array = (vardf[row.varname] > 0).astype(float).values + else: + var_array = (vardf[row.varname] < 0).astype(float).values + + # construct mask + mask = np.ones(numobs, dtype=float) + assert 0 <= row.scope <= 2, f"scope {row.scope} not in [0,2] on {line}" if row.scope == 1: - mask *= vardf.data_source == 1 # PUF records + mask *= (vardf.data_source == 1).values.astype(float) elif row.scope == 2: - mask *= vardf.data_source == 0 # CPS records + mask *= (vardf.data_source == 0).values.astype(float) + in_agi_bin = (vardf.c00100 >= row.agilo) & (vardf.c00100 < row.agihi) - mask *= in_agi_bin + mask *= in_agi_bin.values.astype(float) + assert ( - row.fstatus >= 0 and row.fstatus <= 5 - ), f"fstatus value {row.fstatus} not in [0,5] range on {line}" + 0 <= row.fstatus <= 5 + ), f"fstatus {row.fstatus} not in [0,5] on {line}" if row.fstatus > 0: - mask *= vardf.MARS == row.fstatus - scaled_masked_varray = mask * unmasked_varray * scale - tm_tuple = tm_tuple + (scaled_masked_varray,) - # construct target matrix and target array and return as tuple - scale_factor = 1.0 # as high as 1e9 works just fine - target_matrix = np.vstack(tm_tuple).T * scale_factor - target_array = np.array(ta_list) * scale_factor - return ( - target_matrix, - target_array, - initial_weights_scale, - ) + mask *= (vardf.MARS == row.fstatus).values.astype(float) + # A[j,i] = mask * var_array (data values for this constraint) + columns.append(mask * var_array) -def target_misses(wght, target_matrix, target_array): - """ - Return number of target misses for the specified weight array and a - string containing size of each actual/expect target miss as a tuple. 
- """ - actual = np.dot(wght, target_matrix) - tratio = actual / target_array - tol = PARAMS.get("target_ratio_tolerance", TARGET_RATIO_TOLERANCE) - lob = 1.0 - tol - hib = 1.0 + tol - num = ((tratio < lob) | (tratio >= hib)).sum() - mstr = "" - if num > 0: - for tnum, ratio in enumerate(tratio): - if ratio < lob or ratio >= hib: - mstr += ( - f" ::::TARGET{(tnum + 1):03d}:ACT/EXP,lob,hib=" - f" {ratio:.6f} {lob:.6f} {hib:.6f}\n" - ) - return (num, mstr) + assert pop_share is not None, "XTOT target not found" + # A matrix: n_records x n_targets (each column is a constraint) + A_dense = np.column_stack(columns) -def target_rmse(wght, target_matrix, target_array, out, delta=None): - """ - Return RMSE of the target deviations given specified arguments. - """ - act = np.dot(wght, target_matrix) - act_minus_exp = act - target_array - ratio = act / target_array - min_ratio = np.min(ratio) - max_ratio = np.max(ratio) - dump = PARAMS.get("dump_all_target_deviations", DUMP_ALL_TARGET_DEVIATIONS) - if dump: - for tnum, ratio_ in enumerate(ratio): + # B = diag(w0) @ A where w0 = pop_share * s006 + w0 = pop_share * vardf.s006.values + B_dense = w0[:, np.newaxis] * A_dense + + # convert to sparse (n_targets x n_records) for Clarabel + B_csc = csc_matrix(B_dense.T) + + targets_arr = np.array(targets_list) + return B_csc, targets_arr, labels_list, pop_share + + +def _drop_impossible_targets(B_csc, targets, labels, out): + """Drop targets where all constraint matrix values are zero.""" + col_sums = np.abs(np.asarray(B_csc.sum(axis=1))).ravel() + all_zero = col_sums == 0 + if all_zero.any(): + n_drop = int(all_zero.sum()) + for i in np.where(all_zero)[0]: out.write( - f"TARGET{(tnum + 1):03d}:ACT-EXP,ACT/EXP= " - f"{act_minus_exp[tnum]:16.9e}, {ratio_:.3f}\n" + f"DROPPING impossible target" f" (all zeros): {labels[i]}\n" ) - # show distribution of target ratios - tol = PARAMS.get("target_ratio_tolerance", TARGET_RATIO_TOLERANCE) - bins = [ - -np.inf, - 0.0, - 0.4, - 0.8, - 0.9, - 
0.99, - 1.0 - tol, - 1.0 + tol, - 1.01, - 1.1, - 1.2, - 1.6, - 2.0, - 3.0, - 4.0, - 5.0, - np.inf, - ] - tot = ratio.size - out.write(f"DISTRIBUTION OF TARGET ACT/EXP RATIOS (n={tot}):\n") - if delta is not None: - out.write(f" with REGULARIZATION_DELTA= {delta:e}\n") - header = ( - "low bin ratio high bin ratio" - " bin # cum # bin % cum %\n" - ) - out.write(header) - cutout = pd.cut(ratio, bins, right=False, precision=6) - count = pd.Series(cutout).value_counts().sort_index().to_dict() - cum = 0 - for interval, num in count.items(): - cum += num - if cum == 0: - continue - line = ( - f">={interval.left:13.6f}, <{interval.right:13.6f}:" - f" {num:6d} {cum:6d} {num / tot:7.2%} {cum / tot:7.2%}\n" + keep = ~all_zero + B_dense = B_csc.toarray() + B_csc = csc_matrix(B_dense[keep, :]) + targets = targets[keep] + labels = [lab for lab, k in zip(labels, keep) if k] + out.write( + f"Dropped {n_drop} impossible targets," + f" {len(targets)} remaining\n" ) - out.write(line) - if cum == tot: - break - # write minimum and maximum ratio values - line = f"MINIMUM VALUE OF TARGET ACT/EXP RATIO = {min_ratio:.3f}\n" - out.write(line) - line = f"MAXIMUM VALUE OF TARGET ACT/EXP RATIO = {max_ratio:.3f}\n" - out.write(line) - # return RMSE of ACT-EXP targets - return np.sqrt(np.mean(np.square(act_minus_exp))) - - -def objective_function(x, *args): + return B_csc, targets, labels + + +def _solve_area_qp( # pylint: disable=unused-argument + B_csc, + targets, + labels, + n_records, + constraint_tol=AREA_CONSTRAINT_TOL, + slack_penalty=AREA_SLACK_PENALTY, + max_iter=AREA_MAX_ITER, + multiplier_min=AREA_MULTIPLIER_MIN, + multiplier_max=AREA_MULTIPLIER_MAX, + weight_penalty=1.0, + out=None, +): """ - Objective function for minimization. - Search for NOTE in this file for methodological details. - https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf#page=320 + Solve the area reweighting QP using Clarabel. 
+ + Parameters + ---------- + weight_penalty : float + Penalty weight on (x-1)^2 relative to constraint + slack. Higher values keep multipliers closer to 1.0 + at the cost of more target violations. + + Returns (x_opt, s_lo, s_hi, info_dict). """ - A, b, delta = args # A is a jax sparse matrix - ssq_target_deviations = jnp.sum(jnp.square(A @ x - b)) - ssq_weight_deviations = jnp.sum(jnp.square(x - 1.0)) - return ssq_target_deviations + delta * ssq_weight_deviations + if out is None: + out = sys.stdout + + m = len(targets) + n_total = n_records + 2 * m # x + s_lo + s_hi + + # constraint bounds: t*(1-eps) <= Bx <= t*(1+eps) + abs_targets = np.abs(targets) + tol_band = abs_targets * constraint_tol + cl = targets - tol_band + cu = targets + tol_band + + # diagonal Hessian: 2*alpha for x, 2*M for slacks + hess_diag = np.empty(n_total) + hess_diag[:n_records] = 2.0 * weight_penalty + hess_diag[n_records:] = 2.0 * slack_penalty + P = spdiags(hess_diag, format="csc") + + # linear term: -2*alpha for x + q = np.zeros(n_total) + q[:n_records] = -2.0 * weight_penalty + + # extended constraint matrix: [B | I_m | -I_m] + I_m = speye(m, format="csc") + A_full = hstack([B_csc, I_m, -I_m], format="csc") + + # constraint scaling for numerical conditioning + target_scale = np.maximum(np.abs(targets), 1.0) + D_inv = spdiags(1.0 / target_scale) + A_scaled = (D_inv @ A_full).tocsc() + cl_scaled = cl / target_scale + cu_scaled = cu / target_scale + + # variable bounds + var_lb = np.empty(n_total) + var_ub = np.empty(n_total) + var_lb[:n_records] = multiplier_min + var_ub[:n_records] = multiplier_max + var_lb[n_records:] = 0.0 + var_ub[n_records:] = 1e20 + + # Clarabel form: Ax + s = b, s in NonnegativeCone + I_n = speye(n_total, format="csc") + A_clar = vstack( + [A_scaled, -A_scaled, I_n, -I_n], + format="csc", + ) + b_clar = np.concatenate([cu_scaled, -cl_scaled, var_ub, -var_lb]) + + m_constraints = len(b_clar) + # pylint: disable=no-member + cones = 
[clarabel.NonnegativeConeT(m_constraints)] + + settings = clarabel.DefaultSettings() + # pylint: enable=no-member + settings.verbose = False + settings.max_iter = max_iter + settings.tol_gap_abs = 1e-7 + settings.tol_gap_rel = 1e-7 + settings.tol_feas = 1e-7 + + # solve + out.write("STARTING CLARABEL SOLVER...\n") + t_start = time.time() + solver = clarabel.DefaultSolver( # pylint: disable=no-member + P, q, A_clar, b_clar, cones, settings + ) + result = solver.solve() + elapsed = time.time() - t_start + status_str = str(result.status) + out.write( + f"Solver status: {status_str}\n" + f"Iterations: {result.iterations}\n" + f"Solve time: {elapsed:.2f}s\n" + ) -JIT_FVAL_AND_GRAD = jax.jit(jax.value_and_grad(objective_function)) + # extract solution + y_opt = np.array(result.x) + x_opt = y_opt[:n_records] + s_lo = y_opt[n_records : n_records + m] + s_hi = y_opt[n_records + m :] + x_opt = np.clip(x_opt, multiplier_min, multiplier_max) -def weight_ratio_distribution(ratio, delta, out): - """ - Print distribution of post-optimized to pre-optimized weight ratios. 
- """ + info = { + "status": status_str, + "iterations": result.iterations, + "solve_time": elapsed, + "clarabel_solve_time": result.solve_time, + } + + return x_opt, s_lo, s_hi, info + + +def _print_target_diagnostics( + x_opt, B_csc, targets, labels, constraint_tol, out +): + """Print target accuracy diagnostics.""" + achieved = np.asarray(B_csc @ x_opt).ravel() + abs_errors = np.abs(achieved - targets) + rel_errors = abs_errors / np.maximum(np.abs(targets), 1.0) + + out.write(f"TARGET ACCURACY ({len(targets)} targets):\n") + out.write(f" mean |relative error|: {rel_errors.mean():.6f}\n") + out.write(f" max |relative error|: {rel_errors.max():.6f}\n") + + eps = 1e-9 + n_violated = int((rel_errors > constraint_tol + eps).sum()) + n_hit = len(targets) - n_violated + out.write( + f" targets hit: {n_hit}/{len(targets)}" + f" (tolerance: +/-{constraint_tol * 100:.1f}% + eps)\n" + ) + if n_violated > 0: + out.write(f" VIOLATED: {n_violated} targets\n") + worst_idx = np.argsort(rel_errors)[::-1] + for idx in worst_idx[: min(10, n_violated)]: + out.write( + f" {rel_errors[idx] * 100:7.3f}%" + f" | target={targets[idx]:15.0f}" + f" | achieved={achieved[idx]:15.0f}" + f" | {labels[idx]}\n" + ) + return n_violated + + +def _print_multiplier_diagnostics(x_opt, out): + """Print weight multiplier distribution diagnostics.""" + out.write("MULTIPLIER DISTRIBUTION:\n") + out.write( + f" min={x_opt.min():.6f}," + f" p5={np.percentile(x_opt, 5):.6f}," + f" median={np.median(x_opt):.6f}," + f" p95={np.percentile(x_opt, 95):.6f}," + f" max={x_opt.max():.6f}\n" + ) + out.write( + f" RMSE from 1.0:" f" {np.sqrt(np.mean((x_opt - 1.0) ** 2)):.6f}\n" + ) + + # distribution bins bins = [ 0.0, 1e-6, 0.1, - 0.2, 0.5, 0.8, - 0.85, 0.9, 0.95, 1.0, 1.05, 1.1, - 1.15, 1.2, + 1.5, 2.0, 5.0, - 1e1, - 1e2, - 1e3, - 1e4, - 1e5, + 10.0, + 100.0, np.inf, ] - tot = ratio.size - out.write(f"DISTRIBUTION OF AREA/US WEIGHT RATIO (n={tot}):\n") - out.write(f" with REGULARIZATION_DELTA= {delta:e}\n") - 
header = ( - "low bin ratio high bin ratio" - " bin # cum # bin % cum %\n" - ) - out.write(header) - cutout = pd.cut(ratio, bins, right=False, precision=6) - count = pd.Series(cutout).value_counts().sort_index().to_dict() - cum = 0 - for interval, num in count.items(): - cum += num - if cum == 0: - continue - line = ( - f">={interval.left:13.6f}, <{interval.right:13.6f}:" - f" {num:6d} {cum:6d} {num / tot:7.2%} {cum / tot:7.2%}\n" - ) - out.write(line) - if cum == tot: - break - # write RMSE of area/us weight ratio deviations from one - rmse = np.sqrt(np.mean(np.square(ratio - 1.0))) - line = f"RMSE OF AREA/US WEIGHT RATIO DEVIATIONS FROM ONE = {rmse:e}\n" - out.write(line) + tot = len(x_opt) + out.write(f" distribution (n={tot}):\n") + for i in range(len(bins) - 1): + count = int(((x_opt >= bins[i]) & (x_opt < bins[i + 1])).sum()) + if count > 0: + out.write( + f" [{bins[i]:10.4f}," + f" {bins[i + 1]:10.4f}):" + f" {count:7d}" + f" ({count / tot:7.2%})\n" + ) -# -- High-level logic of the script: +def _print_slack_diagnostics(s_lo, s_hi, targets, labels, out): + """Print elastic slack diagnostics.""" + total_slack = s_lo + s_hi + n_active = int(np.sum(total_slack > 1e-6)) + if n_active > 0: + out.write( + f"ELASTIC SLACK active on" + f" {n_active}/{len(targets)} constraints:\n" + ) + slack_idx = np.where(total_slack > 1e-6)[0] + for idx in slack_idx[np.argsort(total_slack[slack_idx])[::-1]][:20]: + out.write( + f" slack={total_slack[idx]:12.2f}" + f" | target={targets[idx]:15.0f}" + f" | {labels[idx]}\n" + ) + else: + out.write("ALL CONSTRAINTS SATISFIED WITHOUT SLACK\n") def create_area_weights_file( - area: str, - write_log: bool = True, - write_file: bool = True, + area, + write_log=True, + write_file=True, + target_dir=None, + weight_dir=None, ): """ - Create Tax-Calculator-style weights file for FIRST_YEAR through LAST_YEAR - for specified area using information in area targets CSV file. - Write log file if write_log=True, otherwise log is written to stdout. 
- Write weights file if write_file=True, otherwise just do calculations. + Create area weights file using Clarabel constrained QP solver. + + Returns 0 on success. """ - # remove any existing log or weights files - awpath = AREAS_FOLDER / "weights" / f"{area}_tmd_weights.csv.gz" + if target_dir is None: + target_dir = STATE_TARGET_DIR + if weight_dir is None: + weight_dir = STATE_WEIGHT_DIR + # ensure output directory exists + weight_dir.mkdir(parents=True, exist_ok=True) + + # remove any existing output files + awpath = weight_dir / f"{area}_tmd_weights.csv.gz" awpath.unlink(missing_ok=True) - logpath = AREAS_FOLDER / "weights" / f"{area}.log" + logpath = weight_dir / f"{area}.log" logpath.unlink(missing_ok=True) - # specify log output device + # set up output if write_log: out = open( # pylint: disable=consider-using-with - logpath, - "w", - encoding="utf-8", + logpath, "w", encoding="utf-8" ) else: out = sys.stdout + if write_file: - out.write(f"CREATING WEIGHTS FILE FOR AREA {area} ...\n") + out.write( + f"CREATING WEIGHTS FILE FOR AREA {area}" " (Clarabel solver) ...\n" + ) else: - out.write(f"DOING JUST WEIGHTS FILE CALCS FOR AREA {area} ...\n") + out.write( + f"DOING JUST CALCS FOR AREA {area}" " (Clarabel solver) ...\n" + ) - # configure jax library - jax.config.update("jax_platform_name", "cpu") # ignore GPU/TPU if present - jax.config.update("jax_enable_x64", True) # use double precision floats + # read optional parameters + params = _read_params(area, out, target_dir=target_dir) + constraint_tol = params.get( + "constraint_tol", + params.get("target_ratio_tolerance", AREA_CONSTRAINT_TOL), + ) + slack_penalty = params.get("slack_penalty", AREA_SLACK_PENALTY) + max_iter = params.get("max_iter", AREA_MAX_ITER) + multiplier_min = params.get("multiplier_min", AREA_MULTIPLIER_MIN) + multiplier_max = params.get("multiplier_max", AREA_MULTIPLIER_MAX) + + # load data and build constraint matrix + vdf = _load_taxcalc_data() + out.write(f"Loaded {len(vdf)} 
records\n") + out.write(f"National weight sum: {vdf.s006.sum():.0f}\n") + + B_csc, targets, labels, pop_share = _build_constraint_matrix( + area, vdf, out, target_dir=target_dir + ) + out.write( + f"Built constraint matrix:" + f" {B_csc.shape[0]} targets x" + f" {B_csc.shape[1]} records\n" + ) + out.write(f"Constraint tolerance: +/-{constraint_tol * 100:.1f}%\n") + out.write(f"Multiplier bounds: [{multiplier_min}, {multiplier_max}]\n") - # read optional parameters file - global PARAMS # pylint: disable=global-statement - PARAMS = {} - pfile = f"{area}_params.yaml" - params_file = AREAS_FOLDER / "targets" / pfile - if params_file.exists(): - with open(params_file, "r", encoding="utf-8") as paramfile: - PARAMS = yaml.safe_load(paramfile.read()) - exp_params = [ - "target_ratio_tolerance", - "iprint", - "dump_all_target_deviations", - "delta_init_value", - "delta_max_loops", - ] - if len(PARAMS) > len(exp_params): - nump = len(exp_params) - out.write( - f"ERROR: {pfile} must contain no more than {nump} parameters\n" - f"IGNORING CONTENTS OF {pfile}\n" - ) - PARAMS = {} - else: - act_params = list(PARAMS.keys()) - all_ok = True - if len(set(act_params)) != len(act_params): - all_ok = False - out.write(f"ERROR: {pfile} contains duplicate parameter\n") - for param in act_params: - if param not in exp_params: - all_ok = False - out.write( - f"ERROR: {pfile} parameter {param} is not expected\n" - ) - if not all_ok: - out.write(f"IGNORING CONTENTS OF {pfile}\n") - PARAMS = {} - if PARAMS: - out.write(f"USING CUSTOMIZED PARAMETERS IN {pfile}\n") + # drop impossible targets + B_csc, targets, labels = _drop_impossible_targets( + B_csc, targets, labels, out + ) - # construct variable matrix and target array and weights_scale - vdf = all_taxcalc_variables() - target_matrix, target_array, weights_scale = prepared_data(area, vdf) - wght_us = np.array(vdf.s006 * weights_scale) - out.write("INITIAL WEIGHTS STATISTICS:\n") - out.write(f"sum of national weights = 
{vdf.s006.sum():e}\n") - out.write(f"area weights_scale = {weights_scale:e}\n") - num_weights = len(wght_us) - num_targets = len(target_array) - out.write(f"USING {area}_targets.csv FILE WITH {num_targets} TARGETS\n") - tolstr = "ASSUMING TARGET_RATIO_TOLERANCE" - tol = PARAMS.get("target_ratio_tolerance", TARGET_RATIO_TOLERANCE) - out.write(f"{tolstr} = {tol:.6f}\n") - rmse = target_rmse(wght_us, target_matrix, target_array, out) - out.write(f"US_PROPORTIONALLY_SCALED_TARGET_RMSE= {rmse:.9e}\n") - density = np.count_nonzero(target_matrix) / target_matrix.size - out.write(f"target_matrix sparsity ratio = {(1.0 - density):.3f}\n") - - # optimize weight ratios by minimizing the sum of squared deviations - # of area-to-us weight ratios from one such that the optimized ratios - # hit all of the area targets - # - # NOTE: This is a bi-criterion minimization problem that can be - # solved using regularization methods. For background, - # consult Stephen Boyd and Lieven Vandenberghe, Convex - # Optimization, Cambridge University Press, 2004, in - # particular equation (6.9) on page 306 (see LINK below). - # Our problem is exactly the same as (6.9) except that - # we measure x deviations from one rather than from zero. 
- # LINK: https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf#page=320 - # - A_dense = (target_matrix * wght_us[:, np.newaxis]).T - A = BCOO.from_scipy_sparse(csr_matrix(A_dense)) # A is JAX sparse matrix - b = target_array - delta = PARAMS.get("delta_init_values", DELTA_INIT_VALUE) - max_loop = PARAMS.get("delta_max_loops", DELTA_MAX_LOOPS) - if max_loop > 1: - delta_loop_decrement = delta / (max_loop - 1) - else: - delta_loop_decrement = delta - out.write( - "OPTIMIZE WEIGHT RATIOS POSSIBLY IN A REGULARIZATION LOOP\n" - f" where initial REGULARIZATION DELTA value is {delta:e}\n" + # check what x=1 (population-proportional) achieves + n_records = B_csc.shape[1] + x_ones = np.ones(n_records) + achieved_x1 = np.asarray(B_csc @ x_ones).ravel() + rel_err_x1 = np.abs(achieved_x1 - targets) / np.maximum( + np.abs(targets), 1.0 ) - if max_loop > 1: - out.write(f" and there are at most {max_loop} REGULARIZATION LOOPS\n") - else: - out.write(" and there is only one REGULARIZATION LOOP\n") - out.write(f" and where target_matrix.shape= {target_matrix.shape}\n") - # ... specify possibly customized value of iprint for diagnostic callback - if write_log: - iprint = OPTIMIZE_IPRINT - else: - iprint = PARAMS.get("iprint", OPTIMIZE_IPRINT) - # ... define callback for diagnostic output to replace deprecated iprint - _iter_count = 0 - - def _diagnostic_callback(intermediate_result): - nonlocal _iter_count - _iter_count += 1 - if iprint > 0 and _iter_count % iprint == 0: - fval = intermediate_result.fun - out.write(f" iter {_iter_count}: fval={fval:.6e}\n") - - # ... 
reduce value of regularization delta if not all targets are hit - loop = 1 - delta = DELTA_INIT_VALUE - wght0 = np.ones(num_weights) - while loop <= max_loop: - time0 = time.time() - _iter_count = 0 - res = minimize( - fun=JIT_FVAL_AND_GRAD, # objective function and its gradient - x0=wght0, # initial guess for weight ratios - jac=True, # use gradient from JIT_FVAL_AND_GRAD function - args=(A, b, delta), # fixed arguments of objective function - method="L-BFGS-B", # use L-BFGS-B algorithm - bounds=Bounds(0.0, np.inf), # consider only non-negative weights - options={ - "maxiter": OPTIMIZE_MAXITER, - "ftol": OPTIMIZE_FTOL, - "gtol": OPTIMIZE_GTOL, - }, - callback=_diagnostic_callback if iprint > 0 else None, - ) - time1 = time.time() - wght_area = res.x * wght_us - misses, minfo = target_misses(wght_area, target_matrix, target_array) - if write_log: - out.write( - f" ::loop,delta,misses: {loop}" f" {delta:e} {misses}\n" - ) - else: - out.write( - f" ::loop,delta,misses,exectime(secs): {loop}" - f" {delta:e} {misses} {(time1 - time0):.1f}\n" - ) - if misses == 0 or res.success is False: - break # out of regularization delta loop - # show magnitude of target misses - out.write(minfo) - # prepare for next regularization delta loop - loop += 1 - delta -= delta_loop_decrement - if delta < 1e-20: - delta = 0.0 - # ... 
show regularization/optimization results - if write_log: - res_summary = ( - f">>> final delta loop" - f" iterations={res.nit} success={res.success}\n" - f">>> message: {res.message}\n" - f">>> L-BFGS-B optimized objective function value: {res.fun:.9e}\n" - ) - else: - res_summary = ( - f">>> final delta loop exectime= {(time1 - time0):.1f} secs" - f" iterations={res.nit} success={res.success}\n" - f">>> message: {res.message}\n" - f">>> L-BFGS-B optimized objective function value: {res.fun:.9e}\n" - ) - out.write(res_summary) - if OPTIMIZE_RESULTS: - out.write(">>> full final delta loop optimization results:\n") - for key in res.keys(): - out.write(f" {key}: {res.get(key)}\n") - wght_area = res.x * wght_us - misses, minfo = target_misses(wght_area, target_matrix, target_array) - out.write(f"AREA-OPTIMIZED_TARGET_MISSES= {misses}\n") - if misses > 0: - out.write(minfo) - rmse = target_rmse(wght_area, target_matrix, target_array, out, delta) - out.write(f"AREA-OPTIMIZED_TARGET_RMSE= {rmse:.9e}\n") - weight_ratio_distribution(res.x, delta, out) + out.write("BEFORE OPTIMIZATION (x=1, population-proportional):\n") + out.write(f" mean |relative error|: {rel_err_x1.mean():.6f}\n") + out.write(f" max |relative error|: {rel_err_x1.max():.6f}\n") + + # solve QP + x_opt, s_lo, s_hi, _info = _solve_area_qp( + B_csc, + targets, + labels, + n_records, + constraint_tol=constraint_tol, + slack_penalty=slack_penalty, + max_iter=max_iter, + multiplier_min=multiplier_min, + multiplier_max=multiplier_max, + out=out, + ) + + # diagnostics + _n_violated = _print_target_diagnostics( + x_opt, B_csc, targets, labels, constraint_tol, out + ) + _print_multiplier_diagnostics(x_opt, out) + _print_slack_diagnostics(s_lo, s_hi, targets, labels, out) if write_log: out.close() if not write_file: return 0 - # write area weights file extrapolating using national population forecast - # ... 
get population forecast - with open(POPFILE_PATH, "r", encoding="utf-8") as pfile: - pop = yaml.safe_load(pfile.read()) - # ... set FIRST_YEAR weights - weights = wght_area - # ... construct dictionary of scaled-up weights by year - wdict = {f"WT{FIRST_YEAR}": weights} + # write area weights file with population extrapolation + w0 = pop_share * vdf.s006.values + wght_area = x_opt * w0 + + with open(POPFILE_PATH, "r", encoding="utf-8") as pf: + pop = yaml.safe_load(pf.read()) + + wdict = {f"WT{FIRST_YEAR}": wght_area} cum_pop_growth = 1.0 for year in range(FIRST_YEAR + 1, LAST_YEAR + 1): annual_pop_growth = pop[year] / pop[year - 1] cum_pop_growth *= annual_pop_growth - wght = weights.copy() * cum_pop_growth - wdict[f"WT{year}"] = wght - # ... write weights to CSV-formatted file + wdict[f"WT{year}"] = wght_area * cum_pop_growth + wdf = pd.DataFrame.from_dict(wdict) - wdf.to_csv(awpath, index=False, float_format="%.5f", compression="gzip") + wdf.to_csv( + awpath, + index=False, + float_format="%.5f", + compression="gzip", + ) return 0 - - -if __name__ == "__main__": - if len(sys.argv) != 2: - sys.stderr.write( - "ERROR: exactly one command-line argument is required\n" - ) - sys.exit(1) - area_code = sys.argv[1] - if not valid_area(area_code): - sys.stderr.write(f"ERROR: {area_code} is not valid\n") - sys.exit(1) - tfile = f"{area_code}_targets.csv" - target_file = AREAS_FOLDER / "targets" / tfile - if not target_file.exists(): - sys.stderr.write( - f"ERROR: {tfile} file not in tmd/areas/targets folder\n" - ) - sys.exit(1) - RCODE = create_area_weights_file( - area_code, - write_log=False, - write_file=True, - ) - sys.exit(RCODE) diff --git a/tmd/areas/make_all.py b/tmd/areas/make_all.py index 362b4ab9..a6fe4d7b 100644 --- a/tmd/areas/make_all.py +++ b/tmd/areas/make_all.py @@ -9,7 +9,6 @@ import time from multiprocessing import Pool from tmd.areas.create_area_weights import ( - valid_area, create_area_weights_file, ) from tmd.areas import AREAS_FOLDER @@ -58,9 +57,8 
@@ def to_do_areas(): tpaths = sorted(list(tfolder.glob("*_targets.csv"))) for tpath in tpaths: area = tpath.name.split("_")[0] - if not valid_area(area): - print(f"Skipping invalid area name {area}") - continue # skip this area + if area.startswith("."): + continue # skip hidden files wpath = AREAS_FOLDER / "weights" / f"{area}_tmd_weights.csv.gz" if wpath.exists(): wtime = wpath.stat().st_mtime diff --git a/tmd/areas/quality_report.py b/tmd/areas/quality_report.py new file mode 100644 index 00000000..1531cb2e --- /dev/null +++ b/tmd/areas/quality_report.py @@ -0,0 +1,771 @@ +# pylint: disable=import-outside-toplevel,inconsistent-quotes +""" +Cross-state quality summary report. + +Parses solver logs for all states and produces a summary showing: + - Solve status and timing + - Target accuracy (hit rate, mean/max error) + - Weight distortion (RMSE, percentiles) + - Violated targets by variable + - Weight exhaustion and cross-state aggregation diagnostics + +Usage: + python -m tmd.areas.quality_report + python -m tmd.areas.quality_report --scope CA,WY +""" + +import argparse +import re +from pathlib import Path + +import numpy as np +import pandas as pd + +from tmd.areas.create_area_weights import ( + AREA_CONSTRAINT_TOL, + STATE_WEIGHT_DIR, +) +from tmd.areas.prepare.constants import ALL_STATES +from tmd.imputation_assumptions import TAXYEAR + +_WT_COL = f"WT{TAXYEAR}" + +# Decode raw constraint descriptions into human-readable labels +_CNT_LABELS = {0: "amt", 1: "returns", 2: "nz-count"} +_FS_LABELS = {0: "all", 1: "single", 2: "MFJ", 4: "HoH"} + + +def _humanize_desc(desc: str) -> str: + """ + Turn 'c00100/cnt=1/scope=1/agi=[500000.0,1000000.0)/fs=4' + into 'c00100 returns HoH $500K-$1M'. 
+ """ + parts = desc.split("/") + varname = parts[0] + attrs = {} + for p in parts[1:]: + if "=" in p: + k, v = p.split("=", 1) + attrs[k] = v + + cnt = int(attrs.get("cnt", -1)) + fs = int(attrs.get("fs", 0)) + agi_raw = attrs.get("agi", "") + + cnt_label = _CNT_LABELS.get(cnt, f"cnt{cnt}") + fs_label = _FS_LABELS.get(fs, f"fs{fs}") + + # Parse AGI range like [500000.0,1000000.0) + agi_label = "" + m = re.match(r"\[([^,]+),([^)]+)\)", agi_raw) + if m: + lo_s, hi_s = m.group(1), m.group(2) + lo = float(lo_s) + hi = float(hi_s) + if lo < -1e10: + agi_label = f"<${hi / 1000:.0f}K" + elif hi > 1e10: + agi_label = f"${lo / 1000:.0f}K+" + else: + agi_label = f"${lo / 1000:.0f}K-${hi / 1000:.0f}K" + + pieces = [varname, cnt_label] + if fs != 0: + pieces.append(fs_label) + if agi_label: + pieces.append(agi_label) + return " ".join(pieces) + + +def parse_log(logpath: Path) -> dict: + """Parse a single area solver log file into a summary dict.""" + if not logpath.exists(): + return {"status": "NO LOG"} + log = logpath.read_text() + + result = {"status": "UNKNOWN"} + + # Solve status + m = re.search(r"Solver status: (\S+)", log) + if m: + result["status"] = m.group(1) + if "PrimalInfeasible" in log or "FAILED" in log: + result["status"] = "FAILED" + + # Solve time + m = re.search(r"Solve time: ([\d.]+)s", log) + if m: + result["solve_time"] = float(m.group(1)) + + # Target accuracy + m = re.search(r"mean \|relative error\|: ([\d.]+)", log) + if m: + result["mean_err"] = float(m.group(1)) + m = re.search(r"max \|relative error\|: ([\d.]+)", log) + if m: + result["max_err"] = float(m.group(1)) + m = re.search(r"targets hit: (\d+)/(\d+)", log) + if m: + result["targets_hit"] = int(m.group(1)) + result["targets_total"] = int(m.group(2)) + m_viol = re.search(r"VIOLATED: (\d+) targets", log) + result["n_violated"] = int(m_viol.group(1)) if m_viol else 0 + + # Weight distortion + m = re.search( + r"min=([\d.]+), p5=([\d.]+), median=([\d.]+), " + r"p95=([\d.]+), max=([\d.]+)", + 
log, + ) + if m: + result["w_min"] = float(m.group(1)) + result["w_p5"] = float(m.group(2)) + result["w_median"] = float(m.group(3)) + result["w_p95"] = float(m.group(4)) + result["w_max"] = float(m.group(5)) + m = re.search(r"RMSE from 1.0: ([\d.]+)", log) + if m: + result["w_rmse"] = float(m.group(1)) + + # Weight distribution histogram + dist_bins = {} + for line in log.splitlines(): + m_bin = re.match( + r"\s+\[\s*([\d.]+),\s*([\d.]+)\):\s+(\d+)\s+\(\s*([\d.]+)%\)", + line, + ) + if m_bin: + lo, hi = float(m_bin.group(1)), float(m_bin.group(2)) + cnt, pct = int(m_bin.group(3)), float(m_bin.group(4)) + dist_bins[(lo, hi)] = {"count": cnt, "pct": pct} + result["dist_bins"] = dist_bins + + m_n = re.search(r"distribution \(n=(\d+)\)", log) + if m_n: + result["n_records"] = int(m_n.group(1)) + + # Violated target details + violated = [] + in_violated = False + for line in log.splitlines(): + if "VIOLATED:" in line and "targets" in line: + in_violated = True + continue + if in_violated: + m_det = re.match( + r"\s+([\d.]+)%\s*\|\s*target=\s*([\d.]+)\s*\|" + r"\s*achieved=\s*([\d.]+)\s*\|" + r"\s*(\S+/cnt=\d+/scope=\d+/agi=.*?/fs=\d+)", + line, + ) + if m_det: + violated.append( + { + "pct_err": float(m_det.group(1)), + "target": float(m_det.group(2)), + "achieved": float(m_det.group(3)), + "desc": m_det.group(4), + } + ) + else: + in_violated = False + result["violated_details"] = violated + + return result + + +def generate_report(areas=None, weight_dir=None): + """Generate cross-state quality summary report.""" + if areas is None: + areas = ALL_STATES + if weight_dir is None: + weight_dir = STATE_WEIGHT_DIR + + rows = [] + for st in areas: + logpath = weight_dir / f"{st.lower()}.log" + info = parse_log(logpath) + info["state"] = st + rows.append(info) + + df = pd.DataFrame(rows) + + # Summary statistics + solved = df[df["status"].isin(["Solved", "AlmostSolved"])] + failed = df[df["status"] == "FAILED"] + n_states = len(df) + n_solved = len(solved) + n_failed = 
len(failed) + n_violated_states = (solved["n_violated"] > 0).sum() + total_violated = solved["n_violated"].sum() + + tol_pct = AREA_CONSTRAINT_TOL * 100 + + lines = [] + lines.append("=" * 80) + lines.append("CROSS-STATE QUALITY SUMMARY REPORT") + lines.append("=" * 80) + lines.append("") + + # Overall + lines.append(f"States: {n_states}") + lines.append(f"Solved: {n_solved}") + lines.append(f"Failed: {n_failed}") + if n_failed > 0: + lines.append(f" Failed: {', '.join(failed['state'].tolist())}") + lines.append( + f"States with violated targets: {n_violated_states}/{n_solved}" + ) + if not solved.empty and "targets_total" in solved.columns: + tpt = int(solved["targets_total"].iloc[0]) + tpt_sum = int(solved["targets_total"].sum()) + else: + tpt, tpt_sum = "?", "?" + lines.append(f"Total targets: {n_solved} states \u00d7 {tpt} = {tpt_sum}") + lines.append(f"Total violated targets: {int(total_violated)}") + lines.append("") + + # Target accuracy + if not solved.empty and "mean_err" in solved.columns: + lines.append("TARGET ACCURACY:") + lines.append( + f" Per-state mean error: " + f"avg across states={solved['mean_err'].mean():.4f}, " + f"worst state={solved['mean_err'].max():.4f}" + ) + lines.append( + f" Per-state max error: " + f"avg across states={solved['max_err'].mean():.4f}, " + f"worst state={solved['max_err'].max():.4f}" + ) + if "targets_hit" in solved.columns: + total_t = solved["targets_total"].iloc[0] + hit_pcts = solved["targets_hit"] / solved["targets_total"] * 100 + lines.append( + f" Hit rate: " + f"avg={hit_pcts.mean():.1f}%, " + f"min={hit_pcts.min():.1f}% " + f"(out of {total_t} targets, " + f"tolerance: +/-{tol_pct:.1f}% + eps)" + ) + lines.append("") + + # Weight distortion + if not solved.empty and "w_rmse" in solved.columns: + lines.append("WEIGHT DISTORTION (multiplier from 1.0):") + lines.append( + f" RMSE: avg={solved['w_rmse'].mean():.3f}, " + f"max={solved['w_rmse'].max():.3f}" + ) + lines.append( + f" Min: 
avg={solved['w_min'].mean():.3f}, " + f"min={solved['w_min'].min():.3f}" + ) + lines.append( + f" P05: avg={solved['w_p5'].mean():.3f}, " + f"min={solved['w_p5'].min():.3f}" + ) + lines.append( + f" Median: avg={solved['w_median'].mean():.3f}, " + f"range=[{solved['w_median'].min():.3f}, " + f"{solved['w_median'].max():.3f}]" + ) + lines.append( + f" P95: avg={solved['w_p95'].mean():.3f}, " + f"max={solved['w_p95'].max():.3f}" + ) + lines.append( + f" Max: avg={solved['w_max'].mean():.1f}, " + f"max={solved['w_max'].max():.1f}" + ) + lines.append("") + + # Near-zero weight summary + if not solved.empty: + zero_pcts = [] + lt01_pcts = [] + for _, row in solved.iterrows(): + dist = row.get("dist_bins", {}) + n_rec = row.get("n_records", 0) + if not dist or n_rec == 0: + continue + n_zero = dist.get((0.0, 0.0), {}).get("count", 0) + n_lt01 = n_zero + dist.get((0.0, 0.1), {}).get("count", 0) + zero_pcts.append(100 * n_zero / n_rec) + lt01_pcts.append(100 * n_lt01 / n_rec) + if zero_pcts: + lines.append("NEAR-ZERO WEIGHT MULTIPLIERS (% of records):") + lines.append( + f" Exact zero (x=0): " + f"avg={np.mean(zero_pcts):.1f}%, " + f"max={np.max(zero_pcts):.1f}%" + ) + lines.append( + f" Below 0.1 (x<0.1): " + f"avg={np.mean(lt01_pcts):.1f}%, " + f"max={np.max(lt01_pcts):.1f}%" + ) + lines.append("") + + # Per-state table + lines.append("PER-STATE DETAIL:") + lines.append( + " Err cols = |relative error| (fraction); " + "weight cols = multiplier on national weight (1.0 = unchanged)" + ) + header = ( + f"{'St':<4} {'Status':<14} {'Hit':>5} {'Tot':>5} " + f"{'Viol':>5} {'MeanErr':>8} {'MaxErr':>8} " + f"{'wRMSE':>7} {'wP05':>7} {'wMed':>7} " + f"{'wP95':>7} {'wMax':>8} {'%zero':>6}" + ) + lines.append(header) + lines.append("-" * len(header)) + for _, row in df.iterrows(): + hit = int(row.get("targets_hit", 0)) + tot = int(row.get("targets_total", 0)) + viol = int(row.get("n_violated", 0)) + me = row.get("mean_err", 0) + mx = row.get("max_err", 0) + rmse = row.get("w_rmse", 
0) + p5 = row.get("w_p5", 0) + med = row.get("w_median", 0) + p95 = row.get("w_p95", 0) + wmax = row.get("w_max", 0) + dist = row.get("dist_bins", {}) + n_rec = row.get("n_records", 0) + n_zero = 0 + if isinstance(dist, dict): + n_zero = dist.get((0.0, 0.0), {}).get("count", 0) + pct_zero = 100 * n_zero / n_rec if n_rec > 0 else 0 + lines.append( + f"{row['state']:<4} {row['status']:<14} {hit:>5} {tot:>5} " + f"{viol:>5} {me:>8.4f} {mx:>8.4f} " + f"{rmse:>7.3f} {p5:>7.3f} {med:>7.3f} " + f"{p95:>7.3f} {wmax:>8.1f} {pct_zero:>5.1f}%" + ) + lines.append("") + + # Violated targets by variable + all_violated = [] + for _, row in df.iterrows(): + for v in row.get("violated_details", []): + desc = v["desc"] + varname = desc.split("/")[0] + cnt_m = re.search(r"cnt=(\d+)", desc) + cnt_type = int(cnt_m.group(1)) if cnt_m else -1 + abs_miss = abs(v["achieved"] - v["target"]) + all_violated.append( + { + "state": row["state"], + "varname": varname, + "cnt_type": cnt_type, + "pct_err": v["pct_err"], + "target": v["target"], + "achieved": v["achieved"], + "abs_miss": abs_miss, + "desc": desc, + } + ) + if all_violated: + vdf = pd.DataFrame(all_violated) + var_counts = vdf["varname"].value_counts() + lines.append("VIOLATIONS BY VARIABLE:") + for var, cnt in var_counts.items(): + states_with = sorted(vdf[vdf["varname"] == var]["state"].unique()) + lines.append( + f" {var}: {cnt} violations across " + f"{len(states_with)} states" + ) + lines.append("") + + state_counts = vdf["state"].value_counts().head(10) + lines.append("STATES WITH MOST VIOLATIONS:") + for st, cnt in state_counts.items(): + lines.append(f" {st}: {cnt} violated") + lines.append("") + + amt_viol = vdf[vdf["cnt_type"] == 0].sort_values( + ["pct_err", "abs_miss"], ascending=[False, False] + ) + lines.append("WORST 5 AMOUNT VIOLATIONS (by % error):") + if amt_viol.empty: + lines.append(" (none — all amount targets met)") + else: + for _, r in amt_viol.head(5).iterrows(): + lines.append( + f" {r['state']:<4} 
{r['pct_err']:.3f}% " + f"target=${r['target']:>15,.0f} " + f"achieved=${r['achieved']:>15,.0f} " + f"miss=${r['abs_miss']:>12,.0f} " + f"{_humanize_desc(r['desc'])}" + ) + lines.append("") + + cnt_viol = vdf[vdf["cnt_type"].isin([1, 2])].sort_values( + ["pct_err", "abs_miss"], ascending=[False, False] + ) + lines.append("WORST 5 COUNT VIOLATIONS (by % error):") + if cnt_viol.empty: + lines.append(" (none — all count targets met)") + else: + for _, r in cnt_viol.head(5).iterrows(): + lines.append( + f" {r['state']:<4} {r['pct_err']:.3f}% " + f"target={r['target']:>12,.0f} " + f"achieved={r['achieved']:>12,.0f} " + f"miss={r['abs_miss']:>8,.0f} " + f"{_humanize_desc(r['desc'])}" + ) + lines.append("") + + # Weight diagnostics + lines.extend(_weight_diagnostics(areas, weight_dir)) + + report = "\n".join(lines) + return report + + +def _weight_diagnostics(areas, weight_dir=None): + """ + Combined weight diagnostics: exhaustion + national aggregation. + + Loads TMD data and state weight files once, reuses for both. + Only reads the specific columns needed (not the full allvars). 
+ """ + if weight_dir is None: + weight_dir = STATE_WEIGHT_DIR + + from tmd.storage import STORAGE_FOLDER + + lines = [] + + # Load national data + tmd_path = STORAGE_FOLDER / "output" / "tmd.csv.gz" + tmd_cols = [ + "RECID", + "s006", + "MARS", + "XTOT", + "data_source", + "e00200", + "e00300", + "e00400", + "e00600", + "e00650", + "e00900", + "e01400", + "e01700", + "e02000", + "e02300", + "e02400", + "e17500", + "e19200", + "e26270", + "n24", + "p22250", + "p23250", + ] + tmd = pd.read_csv(tmd_path, usecols=tmd_cols) + tmd["capgains_net"] = tmd["p22250"] + tmd["p23250"] + n_records = len(tmd) + s006 = tmd["s006"].values + + agi_path = STORAGE_FOLDER / "output" / "cached_c00100.npy" + if agi_path.exists(): + tmd["c00100"] = np.load(agi_path) + + allvars_path = STORAGE_FOLDER / "output" / "cached_allvars.csv" + if allvars_path.exists(): + needed_tc = [ + "c04470", + "c07100", + "c09600", + "c18300", + "c19200", + "c19700", + "iitax", + "payrolltax", + "standard", + ] + avail = pd.read_csv(allvars_path, nrows=0).columns + load_tc = [c for c in needed_tc if c in avail] + if load_tc: + allvars = pd.read_csv(allvars_path, usecols=load_tc) + for col in load_tc: + tmd[col] = allvars[col].values + + # Load all state weights once + weight_sum = np.zeros(n_records) + state_weights = {} + for st in areas: + wpath = weight_dir / f"{st.lower()}_tmd_weights.csv.gz" + if not wpath.exists(): + continue + w = pd.read_csv(wpath, usecols=[_WT_COL])[_WT_COL].values + weight_sum += w + state_weights[st] = w + + n_loaded = len(state_weights) + if n_loaded == 0: + return lines + + # Weight exhaustion + usage = weight_sum / s006 + + _sub = "sum of state weights / national weight" + lines.append(f"WEIGHT EXHAUSTION ({_sub}):") + lines.append( + f" A ratio of 1.0 means the record's national" + f" weight is fully allocated across" + f" {n_loaded} states." 
+ ) + pcts = [0, 1, 5, 10, 25, 50, 75, 90, 95, 99, 100] + quantiles = np.percentile(usage, pcts) + parts = [] + for p, q in zip(pcts, quantiles): + label = {0: "min", 50: "median", 100: "max"}.get(p, f"p{p}") + parts.append(f"{label}={q:.4f}") + lines.append(" " + ", ".join(parts)) + lines.append(f" Mean: {usage.mean():.4f}, Std: {usage.std():.4f}") + + n_over = int((usage > 1.10).sum()) + n_under = int((usage < 0.90).sum()) + lines.append( + f" Over-used (>1.10):" + f" {n_over} ({100 * n_over / n_records:.1f}%) " + f"Under-used (<0.90):" + f" {n_under} ({100 * n_under / n_records:.1f}%)" + ) + for thresh in [2, 5, 10]: + ct = int((usage > thresh).sum()) + if ct > 0: + lines.append(f" Exhaustion > {thresh}x: {ct} records") + lines.append("") + + # Most exhausted records — profile and top states + _mars = {1: "Single", 2: "MFJ", 3: "MFS", 4: "HoH", 5: "Wid"} + _ds = {0: "CPS", 1: "PUF"} + top_idx = np.argsort(usage)[::-1][:5] + lines.append("MOST EXHAUSTED RECORDS (top 5):") + for rank, idx in enumerate(top_idx, 1): + r = tmd.iloc[idx] + exh = usage[idx] + mars = _mars.get(int(r.MARS), f"fs{int(r.MARS)}") + ds = _ds.get(int(r.data_source), "?") + agi = tmd["c00100"].iloc[idx] if "c00100" in tmd else 0 + # Top 3 states by weight for this record + st_wts = [] + for st, w in state_weights.items(): + if w[idx] > 0: + st_wts.append((st, w[idx])) + st_wts.sort(key=lambda x: -x[1]) + top3 = ", ".join(f"{st}={wt:.1f}" for st, wt in st_wts[:3]) + recid = int(r.RECID) + lines.append( + f" {rank}. 
RECID {recid}: exh={exh:.1f}x," + f" s006={r.s006:.1f}," + f" {ds} {mars}," + f" AGI=${agi:,.0f}" + ) + lines.append( + f" wages=${r.e00200:,.0f}," + f" int=${r.e00300:,.0f}," + f" div=${r.e00600:,.0f}," + f" ptshp=${r.e26270:,.0f}" + ) + lines.append( + f" top states: {top3}" + f" ({len(st_wts)} nonzero of {n_loaded})" + ) + lines.append("") + + # Cross-state aggregation vs national totals + check_vars = [ + ("Returns (s006)", "s006", True), + ("AGI (c00100)", "c00100", False), + ("Wages (e00200)", "e00200", False), + ("Capital gains (capgains_net)", "capgains_net", False), + ("SALT ded (c18300)", "c18300", False), + ("Income tax (iitax)", "iitax", False), + ] + + national = {} + for label, var, is_count in check_vars: + if var not in tmd.columns: + national[var] = None + continue + if var == "s006": + national[var] = float(s006.sum()) + else: + national[var] = float((s006 * tmd[var].values).sum()) + + state_sums = {var: 0.0 for _, var, _ in check_vars} + for _st, w in state_weights.items(): + for _label, var, _is_count in check_vars: + if var not in tmd.columns: + continue + if var == "s006": + state_sums[var] += float(w.sum()) + else: + state_sums[var] += float((w * tmd[var].values).sum()) + + lines.append( + f"CROSS-STATE AGGREGATION vs NATIONAL TOTALS" + f" for SELECTED VARIABLES ({n_loaded} states):" + ) + lines.append( + f" {'Variable':<30} {'National':>16}" + f" {'Sum-of-States':>16} {'Diff%':>8}" + ) + lines.append(" " + "-" * 72) + for label, var, is_count in check_vars: + nat = national[var] + sos = state_sums[var] + if nat is None or nat == 0: + continue + diff_pct = (sos / nat - 1) * 100 + if is_count: + lines.append( + f" {label:<30} {nat:>16,.0f}" + f" {sos:>16,.0f} {diff_pct:>+7.2f}%" + ) + else: + lines.append( + f" {label:<30}" + f" ${nat / 1e9:>14.1f}B" + f" ${sos / 1e9:>14.1f}B" + f" {diff_pct:>+7.2f}%" + ) + lines.append("") + + # Bystander check: untargeted variables + lines.extend(_bystander_check(tmd, s006, state_weights, n_loaded)) + + 
return lines + + +def _bystander_check(tmd, s006, state_weights, n_loaded): + """ + Check untargeted variables for cross-state aggregation + distortion. These are 'innocent bystanders' that may be + jerked around by weight adjustments aimed at targeted + variables. + """ + lines = [] + + # Untargeted variables to check, grouped by category + # Format: (label, varname, is_count) + bystander_vars = [ + # Tax liability / credits + ("Income tax (iitax)", "iitax", False), + ("Payroll tax", "payrolltax", False), + ("AMT (c09600)", "c09600", False), + ("Total credits (c07100)", "c07100", False), + # Deductions + ("Medical expenses (e17500)", "e17500", False), + ("Student loan int (e19200)", "e19200", False), + ("Itemized ded (c04470)", "c04470", False), + ("Standard deduction", "standard", False), + ("Mortgage int (c19200)", "c19200", False), + ("Charitable (c19700)", "c19700", False), + # Income not directly targeted + ("Tax-exempt int (e00400)", "e00400", False), + ("Qual dividends (e00650)", "e00650", False), + ("Sch C income (e00900)", "e00900", False), + ("IRA distrib (e01400)", "e01400", False), + ("Taxable pensions (e01700)", "e01700", False), + ("Sch E net (e02000)", "e02000", False), + ("Unemployment (e02300)", "e02300", False), + # Demographics + ("Total persons (XTOT)", "XTOT", True), + ("Children <17 (n24)", "n24", True), + ] + + # Compute national and sum-of-states for each + results = [] + for label, var, is_count in bystander_vars: + if var not in tmd.columns: + continue + if var == "XTOT": + nat = float((s006 * tmd[var].values).sum()) + elif is_count: + nat = float((s006 * tmd[var].values).sum()) + else: + nat = float((s006 * tmd[var].values).sum()) + if abs(nat) < 1: + continue + sos = 0.0 + for _st, w in state_weights.items(): + sos += float((w * tmd[var].values).sum()) + diff_pct = (sos / nat - 1) * 100 + results.append((label, var, is_count, nat, sos, diff_pct)) + + # Sort by absolute distortion + results.sort(key=lambda x: -abs(x[5])) + + 
lines.append( + "BYSTANDER CHECK: UNTARGETED VARIABLES" f" ({n_loaded} states):" + ) + lines.append( + " Variables NOT directly targeted — distortion" + " from weight adjustments." + ) + lines.append( + f" {'Variable':<30} {'National':>16}" + f" {'Sum-of-States':>16} {'Diff%':>8}" + ) + lines.append(" " + "-" * 72) + for label, _var, is_count, nat, sos, diff_pct in results: + flag = " ***" if abs(diff_pct) > 2 else "" + if is_count: + lines.append( + f" {label:<30} {nat:>16,.0f}" + f" {sos:>16,.0f}" + f" {diff_pct:>+7.2f}%{flag}" + ) + else: + lines.append( + f" {label:<30}" + f" ${nat / 1e9:>14.1f}B" + f" ${sos / 1e9:>14.1f}B" + f" {diff_pct:>+7.2f}%{flag}" + ) + + n_flagged = sum(1 for *_, d in results if abs(d) > 2) + lines.append("") + if n_flagged: + lines.append( + f" *** = {n_flagged} variables with" + f" >2% aggregation distortion" + ) + else: + lines.append( + " All untargeted variables within" + " 2% aggregation tolerance." + ) + lines.append("") + + return lines + + +def main(): + parser = argparse.ArgumentParser( + description="Cross-state quality summary report", + ) + parser.add_argument( + "--scope", + default=None, + help="Comma-separated state codes (default: all states)", + ) + args = parser.parse_args() + + areas = None + if args.scope: + areas = [s.strip().upper() for s in args.scope.split(",")] + + report = generate_report(areas) + print(report) + + +if __name__ == "__main__": + main() diff --git a/tmd/areas/solve_weights.py b/tmd/areas/solve_weights.py new file mode 100644 index 00000000..ce8af3b0 --- /dev/null +++ b/tmd/areas/solve_weights.py @@ -0,0 +1,296 @@ +# pylint: disable=import-outside-toplevel +""" +Solve for state weights using Clarabel QP optimizer. + +Reads per-state target CSV files (produced by prepare_targets.py) +and runs the Clarabel constrained QP solver to find weight +multipliers that hit area-specific targets within tolerance. 
+ +Optional exhaustion limiting (--max-exhaustion) runs a two-pass +solve: first unconstrained, then with per-record multiplier caps +to keep cross-state weight exhaustion within bounds. + +Usage: + # All states, 8 parallel workers: + python -m tmd.areas.solve_weights --scope states --workers 8 + + # With exhaustion cap of 5x: + python -m tmd.areas.solve_weights --scope states --workers 8 \ + --max-exhaustion 5 + + # Specific states: + python -m tmd.areas.solve_weights --scope MN,CA,TX --workers 4 +""" + +import argparse +import time + +import numpy as np +import pandas as pd + +from tmd.areas.create_area_weights import ( + AREA_MULTIPLIER_MAX, + STATE_TARGET_DIR, + STATE_WEIGHT_DIR, +) +from tmd.imputation_assumptions import TAXYEAR + +_WT_COL = f"WT{TAXYEAR}" +_MAX_EXHAUST_ITERATIONS = 5 + + +def solve_state_weights( + scope="states", + num_workers=1, + force=True, + max_exhaustion=None, +): + """ + Run the Clarabel solver for the specified areas. + + Parameters + ---------- + scope : str + 'states' or comma-separated state codes. + num_workers : int + Number of parallel worker processes. + force : bool + Recompute all areas even if weight files are up-to-date. + max_exhaustion : float or None + If set, limit per-record cross-state weight exhaustion + to this multiple of the national weight. Runs iterative + two-pass solve. 
+ """ + from tmd.areas.batch_weights import run_batch + + specific = _parse_scope(scope) + if specific: + area_filter = ",".join(a.lower() for a in specific) + else: + area_filter = "states" + + t0 = time.time() + + # --- Pass 1: unconstrained solve --- + print("Pass 1: solving state weights...") + run_batch( + num_workers=num_workers, + area_filter=area_filter, + force=force, + target_dir=STATE_TARGET_DIR, + weight_dir=STATE_WEIGHT_DIR, + ) + + if max_exhaustion is None: + elapsed = time.time() - t0 + print(f"Total solve time: {elapsed:.1f}s") + return + + # --- Exhaustion-limited iterative passes --- + for iteration in range(1, _MAX_EXHAUST_ITERATIONS + 1): + exhaustion, state_weights = _compute_exhaustion(STATE_WEIGHT_DIR) + n_over = int((exhaustion > max_exhaustion).sum()) + max_exh = exhaustion.max() + print( + f"\nExhaustion check (pass {iteration}):" + f" max={max_exh:.2f}," + f" {n_over} records > {max_exhaustion}x" + ) + if n_over == 0: + print("All records within exhaustion limit.") + break + + # Compute and write per-record caps + affected = _write_exhaustion_caps( + exhaustion, + state_weights, + max_exhaustion, + STATE_WEIGHT_DIR, + STATE_TARGET_DIR, + ) + if not affected: + break + + # Re-solve affected states + af = ",".join(affected) + print( + f"Pass {iteration + 1}: re-solving" + f" {len(affected)} states with caps..." 
+ ) + run_batch( + num_workers=num_workers, + area_filter=af, + force=True, + target_dir=STATE_TARGET_DIR, + weight_dir=STATE_WEIGHT_DIR, + ) + + # Clean up cap files for this iteration + _cleanup_caps(STATE_WEIGHT_DIR, affected) + else: + exhaustion, _ = _compute_exhaustion(STATE_WEIGHT_DIR) + n_still = int((exhaustion > max_exhaustion).sum()) + if n_still > 0: + print( + f"Warning: {n_still} records still exceed" + f" {max_exhaustion}x after" + f" {_MAX_EXHAUST_ITERATIONS} iterations" + f" (max={exhaustion.max():.2f}x)" + ) + + elapsed = time.time() - t0 + print(f"Total solve time: {elapsed:.1f}s") + + +def _compute_exhaustion(weight_dir): + """ + Compute per-record exhaustion across all state weight files. + + Returns (exhaustion_array, state_weights_dict). + """ + from tmd.storage import STORAGE_FOLDER + + s006 = pd.read_csv( + STORAGE_FOLDER / "output" / "tmd.csv.gz", + usecols=["s006"], + )["s006"].values + n = len(s006) + + weight_sum = np.zeros(n) + state_weights = {} + for wpath in sorted(weight_dir.glob("*_tmd_weights.csv.gz")): + area = wpath.name.split("_")[0] + w = pd.read_csv(wpath, usecols=[_WT_COL])[_WT_COL].values + weight_sum += w + state_weights[area] = w + + exhaustion = weight_sum / s006 + return exhaustion, state_weights + + +def _write_exhaustion_caps( + exhaustion, + state_weights, + max_exhaustion, + weight_dir, + target_dir, +): + """ + Compute per-record multiplier caps and write cap files. + + For over-exhausted records, scale each state's multiplier + cap proportionally to current usage so total exhaustion + equals max_exhaustion. + + Returns list of affected area codes. 
+ """ + from tmd.storage import STORAGE_FOLDER + + s006 = pd.read_csv( + STORAGE_FOLDER / "output" / "tmd.csv.gz", + usecols=["s006"], + )["s006"].values + nat_pop = pd.read_csv( + STORAGE_FOLDER / "output" / "tmd.csv.gz", + usecols=["s006", "XTOT"], + ) + nat_pop = (nat_pop["s006"] * nat_pop["XTOT"]).sum() + + over_mask = exhaustion > max_exhaustion + if not over_mask.any(): + return [] + + # Scale factor per record: how much to shrink + scale = np.ones_like(exhaustion) + scale[over_mask] = max_exhaustion / exhaustion[over_mask] + + affected_areas = set() + for area, weights in state_weights.items(): + # Get pop_share for this area + tpath = target_dir / f"{area}_targets.csv" + if not tpath.exists(): + continue + targets_df = pd.read_csv(tpath, comment="#", nrows=1) + xtot_target = targets_df.iloc[0]["target"] + pop_share = xtot_target / nat_pop + + # Current multiplier for each record + w0 = pop_share * s006 + with np.errstate(divide="ignore", invalid="ignore"): + current_x = np.where(w0 > 0, weights / w0, 0.0) + + # New cap = current_x * scale (only tighten, never loosen) + new_caps = current_x * scale + # Don't exceed global max + new_caps = np.minimum(new_caps, AREA_MULTIPLIER_MAX) + + # Only write if some records are actually capped + n_capped = int((new_caps < AREA_MULTIPLIER_MAX - 1e-6).sum()) + if n_capped > 0: + caps_path = weight_dir / f"{area}_record_caps.npy" + np.save(caps_path, new_caps) + affected_areas.add(area) + + return sorted(affected_areas) + + +def _cleanup_caps(weight_dir, areas): + """Remove cap files after a solve pass.""" + for area in areas: + caps_path = weight_dir / f"{area}_record_caps.npy" + caps_path.unlink(missing_ok=True) + + +def _parse_scope(scope): + """Parse scope string into list of state codes or None.""" + _EXCLUDE = {"US", "PR", "OA"} + scope_lower = scope.lower().strip() + if scope_lower in ("states", "all"): + return None + codes = [c.strip().upper() for c in scope.split(",") if c.strip()] + return [c for c in codes if 
len(c) == 2 and c not in _EXCLUDE] + + +def main(): + """CLI entry point.""" + parser = argparse.ArgumentParser( + description=("Solve for state weights using Clarabel QP optimizer"), + ) + parser.add_argument( + "--scope", + default="states", + help=("'states' or comma-separated state codes" " (e.g., 'MN,CA,TX')"), + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="Number of parallel solver workers (default: 1)", + ) + parser.add_argument( + "--force", + action="store_true", + default=True, + help="Recompute all areas even if up-to-date", + ) + parser.add_argument( + "--max-exhaustion", + type=float, + default=None, + help=( + "Max per-record cross-state weight exhaustion" + " (e.g., 5.0). Runs iterative solve to enforce." + ), + ) + args = parser.parse_args() + + solve_state_weights( + scope=args.scope, + num_workers=args.workers, + force=args.force, + max_exhaustion=args.max_exhaustion, + ) + + +if __name__ == "__main__": + main() diff --git a/tmd/areas/sweep_params.py b/tmd/areas/sweep_params.py new file mode 100644 index 00000000..6098828e --- /dev/null +++ b/tmd/areas/sweep_params.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +""" +Parameter sweep for state weight solver. 

Tests combinations of multiplier_max and weight_penalty,
solving all 51 states for each combo and reporting:
 - Target violations
 - Weight exhaustion
 - Solve time

Usage:
    python -m tmd.areas.sweep_params
"""

import io
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed

import numpy as np
import yaml

from tmd.areas.create_area_weights import (
    AREA_CONSTRAINT_TOL,
    AREA_MAX_ITER,
    AREA_MULTIPLIER_MIN,
    AREA_SLACK_PENALTY,
    POPFILE_PATH,
    STATE_TARGET_DIR,
    _build_constraint_matrix,
    _drop_impossible_targets,
    _load_taxcalc_data,
    _solve_area_qp,
)
from tmd.areas.prepare.constants import ALL_STATES
from tmd.imputation_assumptions import TAXYEAR

# Name of the weight column for the analysis year, e.g. "WT2021".
_WT_COL = f"WT{TAXYEAR}"

# --- Parameter grid ---
GRID_MULTIPLIER_MAX = [10, 15, 25, 100]
GRID_WEIGHT_PENALTY = [1.0, 10.0, 100.0]

# NOTE(review): hard-coded worker count; consider a CLI flag.
NUM_WORKERS = 8

# Module-level cache, populated once per process by _init() (which also
# serves as the ProcessPoolExecutor initializer in run_sweep).
_VDF = None
_POP = None


def _init() -> None:
    """Load data once."""
    global _VDF, _POP  # pylint: disable=global-statement
    if _VDF is not None:
        return
    _VDF = _load_taxcalc_data()
    with open(POPFILE_PATH, "r", encoding="utf-8") as pf:
        _POP = yaml.safe_load(pf.read())


def _solve_one(args) -> dict:
    """Solve one state with given params.

    args is a single (area, mult_max, wt_pen) tuple so the function can be
    submitted to an executor with one argument.

    Returns stats dict."""
    area, mult_max, wt_pen = args
    _init()
    vdf = _VDF
    # Helper diagnostics are captured here and discarded (never printed).
    out = io.StringIO()

    B_csc, targets, labels, pop_share = _build_constraint_matrix(
        area, vdf, out, target_dir=STATE_TARGET_DIR
    )
    B_csc, targets, labels = _drop_impossible_targets(
        B_csc, targets, labels, out
    )

    n_records = B_csc.shape[1]
    t0 = time.time()
    x_opt, _s_lo, _s_hi, info = _solve_area_qp(
        B_csc,
        targets,
        labels,
        n_records,
        constraint_tol=AREA_CONSTRAINT_TOL,
        slack_penalty=AREA_SLACK_PENALTY,
        max_iter=AREA_MAX_ITER,
        multiplier_min=AREA_MULTIPLIER_MIN,
        multiplier_max=mult_max,
        weight_penalty=wt_pen,
        out=out,
    )
    elapsed = time.time() - t0

    # Compute stats
    achieved = np.asarray(B_csc @ x_opt).ravel()
    # Denominator floored at 1.0 so near-zero targets do not blow up the
    # relative error.
    rel_errors = np.abs(achieved - targets) / np.maximum(np.abs(targets), 1.0)
    eps = 1e-9
    n_violated = int((rel_errors > AREA_CONSTRAINT_TOL + eps).sum())
    max_viol = float(rel_errors.max() * 100)

    # Weight stats: final weight per record = multiplier * pop_share * s006.
    w0 = pop_share * vdf.s006.values
    final_weights = x_opt * w0

    return {
        "area": area,
        "mult_max": mult_max,
        "wt_pen": wt_pen,
        "n_targets": len(targets),
        "n_violated": n_violated,
        "max_viol_pct": max_viol,
        "x_max": float(x_opt.max()),
        "x_rmse": float(np.sqrt(np.mean((x_opt - 1.0) ** 2))),
        "pct_zero": float((x_opt < 1e-6).mean() * 100),
        "status": info["status"],
        "elapsed": elapsed,
        "pop_share": pop_share,
        "final_weights": final_weights,
    }


def run_sweep() -> None:
    """Run parameter sweep and print results."""
    print("Loading TMD data...")
    _init()

    areas = [s.lower() for s in ALL_STATES]
    combos = [
        (mm, wp) for mm in GRID_MULTIPLIER_MAX for wp in GRID_WEIGHT_PENALTY
    ]

    s006 = _VDF.s006.values
    n_records = len(s006)

    print(
        f"Sweep: {len(combos)} parameter combos"
        f" x {len(areas)} states"
        f" = {len(combos) * len(areas)} solves"
    )
    print(
        f"Grid: mult_max={GRID_MULTIPLIER_MAX},"
        f" weight_penalty={GRID_WEIGHT_PENALTY}"
    )
    print(f"Workers: {NUM_WORKERS}")
    print()

    results = []

    for combo_idx, (mm, wp) in enumerate(combos):
        label = f"mult_max={mm:>3}, wt_pen={wp:>5.1f}"
        sys.stdout.write(f"[{combo_idx + 1}/{len(combos)}] {label}...")
        sys.stdout.flush()

        t0 = time.time()
        tasks = [(area, mm, wp) for area in areas]

        # Solve all 51 states for this parameter combo in parallel; _init
        # preloads the data in each worker process.
        combo_results = []
        with ProcessPoolExecutor(
            max_workers=NUM_WORKERS,
            initializer=_init,
        ) as executor:
            futures = {
                executor.submit(_solve_one, task): task for task in tasks
            }
            for future in as_completed(futures):
                combo_results.append(future.result())

        elapsed = time.time() - t0

        # Aggregate stats
        total_violated = sum(r["n_violated"] for r in combo_results)
        max_viol = max(r["max_viol_pct"] for r in combo_results)
        avg_rmse = np.mean([r["x_rmse"] for r in combo_results])
        avg_zero = np.mean([r["pct_zero"] for r in combo_results])

        # Compute exhaustion: per-record sum of state weights over all
        # states, relative to the national weight s006.
        weight_sum = np.zeros(n_records)
        for r in combo_results:
            weight_sum += r["final_weights"]
        exhaustion = weight_sum / s006
        max_exh = float(exhaustion.max())
        p99_exh = float(np.percentile(exhaustion, 99))
        n_over5 = int((exhaustion > 5).sum())
        n_over10 = int((exhaustion > 10).sum())

        summary = {
            "mult_max": mm,
            "wt_pen": wp,
            "total_violated": total_violated,
            "max_viol_pct": max_viol,
            "avg_rmse": avg_rmse,
            "avg_pct_zero": avg_zero,
            "max_exhaustion": max_exh,
            "p99_exhaustion": p99_exh,
            "n_over5x": n_over5,
            "n_over10x": n_over10,
            "elapsed": elapsed,
        }
        results.append(summary)

        sys.stdout.write(
            f" {elapsed:.0f}s |"
            f" viol={total_violated:>4}"
            f" maxV={max_viol:.2f}%"
            f" wRMSE={avg_rmse:.3f}"
            f" %zero={avg_zero:.1f}%"
            f" maxExh={max_exh:.1f}"
            f" >5x={n_over5}"
            f" >10x={n_over10}\n"
        )

    # Summary table
    print("\n" + "=" * 100)
    print("PARAMETER SWEEP RESULTS")
    print("=" * 100)
    cols = [
        "mMax",
        "wPen",
        "Viol",
        "MaxV%",
        "wRMSE",
        "%zero",
        "MaxExh",
        "p99Exh",
        "O5x",
        "O10x",
        "Time",
    ]
    # Column widths aligned with the row format strings below.
    widths = [5, 6, 5, 6, 6, 6, 7, 7, 5, 5, 6]
    hdr = " ".join(f"{c:>{w}}" for c, w in zip(cols, widths))
    print(hdr)
    print("-" * len(hdr))
    for r in results:
        mm = r["mult_max"]
        wp = r["wt_pen"]
        tv = r["total_violated"]
        mv = r["max_viol_pct"]
        rm = r["avg_rmse"]
        pz = r["avg_pct_zero"]
        me = r["max_exhaustion"]
        pe = r["p99_exhaustion"]
        o5 = r["n_over5x"]
        o10 = r["n_over10x"]
        et = r["elapsed"]
        print(
            f"{mm:>5} {wp:>6.1f}"
            f" {tv:>5} {mv:>6.2f}"
            f" {rm:>6.3f} {pz:>6.1f}"
            f" {me:>7.1f} {pe:>7.3f}"
            f" {o5:>5} {o10:>5}"
            f" {et:>5.0f}s"
        )
    print()
    print("Baseline: mult_max=100, wt_pen=1.0")


if __name__ == "__main__":
    run_sweep()