From a5d6ec27c630a511afd4c03afb8648bea80a6e57 Mon Sep 17 00:00:00 2001 From: Don Boyd Date: Thu, 19 Mar 2026 11:07:46 -0400 Subject: [PATCH 01/10] Add Clarabel QP solver for state weight optimization Port state weight solver pipeline from state-weights-clarabel branch: - Clarabel constrained QP solver with elastic slack (0.5% tolerance) - Parallel batch runner with worker-cached TMD data - Cross-state quality report (log parsing, weight exhaustion, aggregation) - Standalone CLI: python -m tmd.areas.solve_weights --scope states --workers 8 - 11 tests covering solver, log parser, scope parsing, and area filtering Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_solve_weights.py | 255 +++++++++ tmd/areas/batch_weights.py | 355 +++++++++++++ tmd/areas/create_area_weights_clarabel.py | 601 ++++++++++++++++++++++ tmd/areas/quality_report.py | 589 +++++++++++++++++++++ tmd/areas/solve_weights.py | 112 ++++ 5 files changed, 1912 insertions(+) create mode 100644 tests/test_solve_weights.py create mode 100644 tmd/areas/batch_weights.py create mode 100644 tmd/areas/create_area_weights_clarabel.py create mode 100644 tmd/areas/quality_report.py create mode 100644 tmd/areas/solve_weights.py diff --git a/tests/test_solve_weights.py b/tests/test_solve_weights.py new file mode 100644 index 00000000..2b52af17 --- /dev/null +++ b/tests/test_solve_weights.py @@ -0,0 +1,255 @@ +""" +Tests for the Clarabel-based state weight solver pipeline. + +Tests: + - Clarabel solver on faux xx area + - Quality report log parser + - CLI scope parsing + - Batch area filtering +""" + +import io +import tempfile +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from tmd.areas import AREAS_FOLDER +from tmd.areas.batch_weights import _filter_areas +from tmd.areas.create_area_weights_clarabel import ( + AREA_CONSTRAINT_TOL, + _build_constraint_matrix, + _drop_impossible_targets, + _load_taxcalc_data, + _solve_area_qp, + create_area_weights_file_clarabel, +) +from tmd.areas.quality_report import ( + _humanize_desc, + parse_log, +) +from tmd.areas.solve_weights import ( + _parse_scope, +) +from tmd.imputation_assumptions import TAXYEAR + +# --- Solver test on faux xx area --- + + +def test_clarabel_solver_xx(): + """ + Solve faux xx area with Clarabel and verify targets are hit. + """ + # xx targets are in the flat targets/ directory + target_dir = AREAS_FOLDER / "targets" + + with tempfile.TemporaryDirectory() as tmpdir: + weight_dir = Path(tmpdir) + rc = create_area_weights_file_clarabel( + "xx", + write_log=True, + write_file=True, + target_dir=target_dir, + weight_dir=weight_dir, + ) + assert rc == 0, "create_area_weights_file_clarabel returned non-zero" + + # Verify weights file was created + wpath = weight_dir / "xx_tmd_weights.csv.gz" + assert wpath.exists(), "Weight file not created" + + # Verify weights file has expected columns + wdf = pd.read_csv(wpath) + assert f"WT{TAXYEAR}" in wdf.columns + assert f"WT{TAXYEAR + 1}" in wdf.columns + + # Verify all weights are non-negative + assert (wdf[f"WT{TAXYEAR}"] >= 0).all(), "Negative weights found" + + # Verify log file was created + logpath = weight_dir / "xx.log" + assert logpath.exists(), "Log file not created" + + # Verify log contains solver status + log_text = logpath.read_text() + assert "Solver status:" in log_text + assert "TARGET ACCURACY" in log_text + + +def test_clarabel_solver_xx_targets_hit(): + """ + Verify the Clarabel solver actually hits the xx area targets + within the constraint tolerance. + """ + target_dir = AREAS_FOLDER / "targets" + out = io.StringIO() + + vdf = _load_taxcalc_data() + B_csc, targets, labels, _pop_share = _build_constraint_matrix( + "xx", vdf, out, target_dir=target_dir + ) + B_csc, targets, labels = _drop_impossible_targets( + B_csc, targets, labels, out + ) + + n_records = B_csc.shape[1] + x_opt, _s_lo, _s_hi, info = _solve_area_qp( + B_csc, + targets, + labels, + n_records, + out=out, + ) + + # Check that solver succeeded + status = info["status"] + assert "Solved" in status, f"Solver status: {status}" + + # Check all targets hit within tolerance + achieved = np.asarray(B_csc @ x_opt).ravel() + rel_errors = np.abs(achieved - targets) / np.maximum(np.abs(targets), 1.0) + eps = 1e-9 + n_violated = int((rel_errors > AREA_CONSTRAINT_TOL + eps).sum()) + assert n_violated == 0, ( + f"{n_violated} targets violated; " + f"max error = {rel_errors.max() * 100:.3f}%" + ) + + +# --- Quality report log parser --- + + +_V = "c00100/cnt=1/scope=0" +_V1 = ( # noqa: E501 + f" 0.500% | target= 12345 | achieved= 12407" + f" | {_V}/agi=[-9e+99,1.0)/fs=0" +) +_V2 = ( + f" 0.489% | target= 567 | achieved= 570" + f" | {_V}/agi=[1e+06,9e+99)/fs=1" +) +_V3 = ( + f" 0.478% | target= 1234 | achieved= 1240" + f" | {_V}/agi=[1e+06,9e+99)/fs=2" +) + + +def test_parse_log_solved(): + """Test log parser with a synthetic solved log.""" + _hit = " targets hit: 175/178 (tolerance: +/-0.5% + eps)" + _mult = ( + " min=0.000000, p5=0.450000," + + " median=0.980000," + + " p95=1.650000, max=45.000000" + ) + _d1 = " [ 0.0000, 0.0000): 2662 ( 1.24%)" + _d2 = " [ 0.0000, 0.1000): 5000 ( 2.33%)" + log_content = "\n".join( + [ + "Solver status: Solved", + "Iterations: 42", + "Solve time: 3.14s", + "TARGET ACCURACY (178 targets):", + " mean |relative error|: 0.001234", + " max |relative error|: 0.004567", + _hit, + " VIOLATED: 3 targets", + _V1, + _V2, + _V3, + "MULTIPLIER DISTRIBUTION:", + _mult, + " RMSE from 1.0: 0.550000", + " distribution (n=214377):", + _d1, + _d2, + "ALL CONSTRAINTS SATISFIED WITHOUT SLACK", + "", + ] + ) + with tempfile.NamedTemporaryFile( + mode="w", suffix=".log", delete=False + ) as f: + f.write(log_content) + f.flush() + result = parse_log(Path(f.name)) + + assert result["status"] == "Solved" + assert result["solve_time"] == pytest.approx(3.14) + assert result["mean_err"] == pytest.approx(0.001234) + assert result["max_err"] == pytest.approx(0.004567) + assert result["targets_hit"] == 175 + assert result["targets_total"] == 178 + assert result["n_violated"] == 3 + assert result["w_min"] == pytest.approx(0.0) + assert result["w_median"] == pytest.approx(0.98) + assert result["w_rmse"] == pytest.approx(0.55) + assert result["n_records"] == 214377 + assert len(result["violated_details"]) == 3 + + +def test_parse_log_missing(): + """Test log parser with non-existent file.""" + result = parse_log(Path("/nonexistent/path.log")) + assert result["status"] == "NO LOG" + + +# --- Scope parsing --- + + +def test_solve_weights_parse_scope_states(): + """Test scope parsing for 'states'.""" + assert _parse_scope("states") is None + assert _parse_scope("all") is None + + +def test_solve_weights_parse_scope_specific(): + """Test scope parsing for specific states.""" + result = _parse_scope("MN,CA,TX") + assert result == ["MN", "CA", "TX"] + + +def test_solve_weights_parse_scope_excludes(): + """Test scope parsing excludes PR, US, OA.""" + result = _parse_scope("MN,PR,US,OA,CA") + assert result == ["MN", "CA"] + + +# --- Batch area filtering --- + + +def test_filter_areas_states(): + """Test batch filter for states.""" + areas = ["al", "ca", "mn", "mn01", "xx"] + assert _filter_areas(areas, "states") == [ + "al", + "ca", + "mn", + ] + + +def test_filter_areas_cds(): + """Test batch filter for CDs.""" + areas = ["al", "ca", "mn01", "mn02"] + assert _filter_areas(areas, "cds") == ["mn01", "mn02"] + + +def test_filter_areas_specific(): + """Test batch filter for specific areas.""" + areas = ["al", "ca", "mn", "tx"] + assert _filter_areas(areas, "mn,tx") == ["mn", "tx"] + + +# --- Humanize description --- + + +def test_humanize_desc(): + """Test human-readable label generation.""" + desc = "c00100/cnt=1/scope=1" + "/agi=[500000.0,1000000.0)/fs=4" + result = _humanize_desc(desc) + assert "c00100" in result + assert "returns" in result + assert "HoH" in result + assert "$500K" in result diff --git a/tmd/areas/batch_weights.py b/tmd/areas/batch_weights.py new file mode 100644 index 00000000..9732da3f --- /dev/null +++ b/tmd/areas/batch_weights.py @@ -0,0 +1,355 @@ +# pylint: disable=import-outside-toplevel,global-statement +""" +Batch area weight optimization — parallel processing for all areas. + +TMD data loaded once per worker (not once per area). +Progress reporting with ETA. +Uses concurrent.futures for clean parallel execution. + +Usage: + python -m tmd.areas.batch_weights --scope states --workers 8 + python -m tmd.areas.batch_weights --scope MN,CA,TX --workers 4 +""" + +import argparse +import io +import re +import sys +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path + +import numpy as np +import pandas as pd +import yaml + +from tmd.areas.create_area_weights import valid_area + +# Module-level cache for TMD data (one per worker process) +_WORKER_VDF = None +_WORKER_POP = None +_WORKER_TARGET_DIR = None +_WORKER_WEIGHT_DIR = None + + +def _init_worker(target_dir=None, weight_dir=None): + """Load TMD data once per worker process.""" + global _WORKER_VDF, _WORKER_POP + global _WORKER_TARGET_DIR, _WORKER_WEIGHT_DIR + if target_dir is not None: + _WORKER_TARGET_DIR = Path(target_dir) + if weight_dir is not None: + _WORKER_WEIGHT_DIR = Path(weight_dir) + if _WORKER_VDF is not None: + return + from tmd.areas.create_area_weights_clarabel import ( + POPFILE_PATH, + _load_taxcalc_data, + ) + + _WORKER_VDF = _load_taxcalc_data() + with open(POPFILE_PATH, "r", encoding="utf-8") as pf: + _WORKER_POP = yaml.safe_load(pf.read()) + + +def _solve_one_area(area): + """ + Solve area weights for one area using cached worker data. + + Returns (area, elapsed, n_targets, n_violated, status, max_viol_pct). + """ + _init_worker() + from tmd.areas.create_area_weights_clarabel import ( + AREA_CONSTRAINT_TOL, + AREA_MAX_ITER, + AREA_MULTIPLIER_MAX, + AREA_MULTIPLIER_MIN, + AREA_SLACK_PENALTY, + FIRST_YEAR, + LAST_YEAR, + _build_constraint_matrix, + _drop_impossible_targets, + _print_multiplier_diagnostics, + _print_slack_diagnostics, + _print_target_diagnostics, + _read_params, + _solve_area_qp, + ) + + t0 = time.time() + out = io.StringIO() + + # Build and solve + vdf = _WORKER_VDF + tgt_dir = _WORKER_TARGET_DIR + wgt_dir = _WORKER_WEIGHT_DIR + params = _read_params(area, out, target_dir=tgt_dir) + constraint_tol = params.get( + "constraint_tol", + params.get("target_ratio_tolerance", AREA_CONSTRAINT_TOL), + ) + slack_penalty = params.get("slack_penalty", AREA_SLACK_PENALTY) + max_iter = params.get("max_iter", AREA_MAX_ITER) + multiplier_min = params.get("multiplier_min", AREA_MULTIPLIER_MIN) + multiplier_max = params.get("multiplier_max", AREA_MULTIPLIER_MAX) + + B_csc, targets, labels, pop_share = _build_constraint_matrix( + area, vdf, out, target_dir=tgt_dir + ) + B_csc, targets, labels = _drop_impossible_targets( + B_csc, targets, labels, out + ) + + n_records = B_csc.shape[1] + x_opt, s_lo, s_hi, info = _solve_area_qp( + B_csc, + targets, + labels, + n_records, + constraint_tol=constraint_tol, + slack_penalty=slack_penalty, + max_iter=max_iter, + multiplier_min=multiplier_min, + multiplier_max=multiplier_max, + out=out, + ) + + # Diagnostics + n_violated = _print_target_diagnostics( + x_opt, B_csc, targets, labels, constraint_tol, out + ) + _print_multiplier_diagnostics(x_opt, out) + _print_slack_diagnostics(s_lo, s_hi, targets, labels, out) + + # Compute max violation percentage for summary + achieved = np.asarray(B_csc @ x_opt).ravel() + rel_errors = np.abs(achieved - targets) / np.maximum(np.abs(targets), 1.0) + eps = 1e-9 + viol_mask = rel_errors > constraint_tol + eps + max_viol_pct = ( + float(rel_errors[viol_mask].max() * 100) if viol_mask.any() else 0.0 + ) + + # Write log + logpath = wgt_dir / f"{area}.log" + logpath.parent.mkdir(parents=True, exist_ok=True) + with open(logpath, "w", encoding="utf-8") as f: + f.write(out.getvalue()) + + # Write weights file + w0 = pop_share * vdf.s006.values + wght_area = x_opt * w0 + + wdict = {f"WT{FIRST_YEAR}": wght_area} + cum_pop_growth = 1.0 + pop = _WORKER_POP + for year in range(FIRST_YEAR + 1, LAST_YEAR + 1): + annual_pop_growth = pop[year] / pop[year - 1] + cum_pop_growth *= annual_pop_growth + wdict[f"WT{year}"] = wght_area * cum_pop_growth + + wdf = pd.DataFrame.from_dict(wdict) + awpath = wgt_dir / f"{area}_tmd_weights.csv.gz" + wdf.to_csv( + awpath, + index=False, + float_format="%.5f", + compression="gzip", + ) + + elapsed = time.time() - t0 + return ( + area, + elapsed, + len(targets), + n_violated, + info["status"], + max_viol_pct, + ) + + +def _list_target_areas(target_dir=None): + """Return sorted list of area codes with target files.""" + from tmd.areas.create_area_weights_clarabel import STATE_TARGET_DIR + + tfolder = target_dir or STATE_TARGET_DIR + tpaths = sorted(tfolder.glob("*_targets.csv")) + areas = [] + for tpath in tpaths: + area = tpath.name.split("_")[0] + old_stderr = sys.stderr + sys.stderr = io.StringIO() + ok = valid_area(area) + sys.stderr = old_stderr + if ok: + areas.append(area) + return areas + + +def _filter_areas(areas, area_filter): + """Filter areas by type: 'states', 'cds', or 'all'.""" + if area_filter == "all": + return areas + if area_filter == "states": + return [a for a in areas if len(a) == 2 and not re.match(r"[x-z]", a)] + if area_filter == "cds": + return [a for a in areas if len(a) > 2] + # Treat as comma-separated list + requested = [a.strip() for a in area_filter.split(",")] + return [a for a in areas if a in requested] + + +def run_batch( + num_workers=1, + area_filter="all", + force=False, + target_dir=None, + weight_dir=None, +): + """ + Run area weight optimization for multiple areas in parallel. + + Parameters + ---------- + num_workers : int + Number of parallel worker processes. + area_filter : str + 'states', 'cds', 'all', or comma-separated area codes. + force : bool + If True, recompute all areas even if up-to-date. + target_dir : Path, optional + Directory containing target CSVs. + weight_dir : Path, optional + Directory for weight output. + """ + from tmd.areas.create_area_weights_clarabel import ( + STATE_TARGET_DIR, + STATE_WEIGHT_DIR, + ) + + if target_dir is None: + target_dir = STATE_TARGET_DIR + if weight_dir is None: + weight_dir = STATE_WEIGHT_DIR + all_areas = _list_target_areas(target_dir=target_dir) + areas = _filter_areas(all_areas, area_filter) + + if not areas: + print("No areas to process.") + return + + # Filter to out-of-date areas unless force=True + if not force: + from tmd.areas.make_all import time_of_newest_other_dependency + + newest_dep = time_of_newest_other_dependency() + todo = [] + for area in areas: + wpath = weight_dir / f"{area}_tmd_weights.csv.gz" + tpath = target_dir / f"{area}_targets.csv" + if wpath.exists(): + wtime = wpath.stat().st_mtime + ttime = tpath.stat().st_mtime + if wtime > max(newest_dep, ttime): + continue + todo.append(area) + areas = todo + + if not areas: + print("All areas up-to-date. Use --force to recompute.") + return + + n = len(areas) + print(f"Processing {n} areas with {num_workers} workers...") + print( + "(Areas shown in completion order, which varies with" + " parallel workers.)" + ) + + # Ensure weights directory exists + weight_dir.mkdir(parents=True, exist_ok=True) + + t_start = time.time() + completed = 0 + violated_areas = [] + worst_viol_pct = 0.0 + max_id_width = max(len(a) for a in areas) + + with ProcessPoolExecutor( + max_workers=num_workers, + initializer=_init_worker, + initargs=(str(target_dir), str(weight_dir)), + ) as executor: + futures = { + executor.submit(_solve_one_area, area): area for area in areas + } + for future in as_completed(futures): + area = futures[future] + try: + area_code, _, _, n_viol, _, mv_pct = future.result() + completed += 1 + + if n_viol > 0: + violated_areas.append((area_code, n_viol)) + worst_viol_pct = max(worst_viol_pct, mv_pct) + + # Start new line every 10 areas with count prefix + if (completed - 1) % 10 == 0: + sys.stdout.write(f"\n{completed:4d} ") + sys.stdout.write(f" {area_code.ljust(max_id_width)}") + sys.stdout.flush() + + # After every 10th area (or the last), print elapsed + if completed % 10 == 0 or completed == n: + elapsed_total = time.time() - t_start + sys.stdout.write(f" [{elapsed_total:.0f}s elapsed]") + sys.stdout.flush() + + except Exception: + completed += 1 + sys.stdout.write(f" {area}:FAIL") + sys.stdout.flush() + sys.stdout.write("\n") + + total = time.time() - t_start + print(f"\nCompleted {completed}/{n} areas in {total:.1f}s") + if violated_areas: + total_viol = sum(v for _, v in violated_areas) + print( + f"{len(violated_areas)} areas had violated targets" + f" ({total_viol} targets total)." + f" Largest violation: {worst_viol_pct:.2f}%." + ) + cmd = "python -m tmd.areas.quality_report" + print(f"Run: {cmd} for full details.") + else: + print("All targets hit within tolerance.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Batch area weight optimization" + ) + parser.add_argument( + "--scope", + type=str, + default="states", + help="'states', 'cds', 'all', or comma-separated codes", + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="Number of parallel workers (default: 1)", + ) + parser.add_argument( + "--force", + action="store_true", + help="Recompute all areas even if up-to-date", + ) + args = parser.parse_args() + run_batch( + num_workers=args.workers, + area_filter=args.scope, + force=args.force, + ) diff --git a/tmd/areas/create_area_weights_clarabel.py b/tmd/areas/create_area_weights_clarabel.py new file mode 100644 index 00000000..40a1d4c5 --- /dev/null +++ b/tmd/areas/create_area_weights_clarabel.py @@ -0,0 +1,601 @@ +""" +Clarabel-based constrained QP area weight optimization. + +Finds weight multipliers (x) for each record such that area-specific +weighted sums match area targets within a tolerance band, while +minimizing deviation from population-proportional weights. + +Formulation: + minimize sum((x_i - 1)^2) + subject to t_j*(1-eps) <= (Bx)_j <= t_j*(1+eps) [target bounds] + x_min <= x_i <= x_max [multiplier bounds] + +where: + x_i = area_weight_i / (pop_share * national_weight_i) + pop_share = area_population / national_population + B[j,i] = (pop_share * national_weight_i) * A[j,i] + + x_i = 1 means record i gets its population-proportional share. + The optimizer adjusts x_i to hit area-specific targets. + +Elastic slack variables handle infeasibility gracefully: + minimize sum((x_i - 1)^2) + M * sum(s^2) + subject to lb_j <= (Bx)_j + s_lo - s_hi <= ub_j, s >= 0 + +Follows the same QP construction as tmd/utils/reweight.py. +""" + +import sys +import time + +import clarabel +import numpy as np +import pandas as pd +import yaml +from scipy.sparse import ( + csc_matrix, + diags as spdiags, + eye as speye, + hstack, + vstack, +) + +from tmd.areas import AREAS_FOLDER +from tmd.imputation_assumptions import POPULATION_FILE, TAXYEAR +from tmd.storage import STORAGE_FOLDER + +FIRST_YEAR = TAXYEAR +LAST_YEAR = 2034 +INFILE_PATH = STORAGE_FOLDER / "output" / "tmd.csv.gz" +POPFILE_PATH = STORAGE_FOLDER / "input" / POPULATION_FILE +TAXCALC_AGI_CACHE = STORAGE_FOLDER / "output" / "cached_c00100.npy" +CACHED_ALLVARS_PATH = STORAGE_FOLDER / "output" / "cached_allvars.csv" + +# Tax-Calculator output variables to load from cached_allvars for targeting +CACHED_TC_OUTPUTS = [ + "c18300", + "c04470", + "c02500", + "c19200", + "c19700", + "eitc", + "ctc_total", +] + +# Default solver parameters +AREA_CONSTRAINT_TOL = 0.005 +AREA_SLACK_PENALTY = 1e6 +AREA_MAX_ITER = 2000 +AREA_MULTIPLIER_MIN = 0.0 +AREA_MULTIPLIER_MAX = 100.0 + +# Default target/weight directories for states +STATE_TARGET_DIR = AREAS_FOLDER / "targets" / "states" +STATE_WEIGHT_DIR = AREAS_FOLDER / "weights" / "states" + + +def _load_taxcalc_data(): + """Load TMD data with cached AGI and selected Tax-Calculator outputs.""" + vdf = pd.read_csv(INFILE_PATH) + vdf["c00100"] = np.load(TAXCALC_AGI_CACHE) + if CACHED_ALLVARS_PATH.exists(): + allvars = pd.read_csv(CACHED_ALLVARS_PATH, usecols=CACHED_TC_OUTPUTS) + for col in CACHED_TC_OUTPUTS: + if col in allvars.columns: + vdf[col] = allvars[col].values + # Synthetic combined variable for net capital gains targeting + if "p22250" in vdf.columns and "p23250" in vdf.columns: + vdf["capgains_net"] = vdf["p22250"] + vdf["p23250"] + assert np.all(vdf.s006 > 0), "Not all weights are positive" + return vdf + + +def _read_params(area, out, target_dir=None): + """Read optional area-specific parameters YAML file.""" + if target_dir is None: + target_dir = STATE_TARGET_DIR + params = {} + pfile = f"{area}_params.yaml" + params_path = target_dir / pfile + if params_path.exists(): + with open(params_path, "r", encoding="utf-8") as f: + params = yaml.safe_load(f.read()) + exp_params = [ + "target_ratio_tolerance", + "dump_all_target_deviations", + "delta_init_value", + "delta_max_loops", + "iprint", + "constraint_tol", + "slack_penalty", + "max_iter", + "multiplier_min", + "multiplier_max", + ] + act_params = list(params.keys()) + all_ok = len(set(act_params)) == len(act_params) + for param in act_params: + if param not in exp_params: + all_ok = False + out.write( + f"WARNING: {pfile} parameter" f" {param} is not expected\n" + ) + if not all_ok: + out.write(f"IGNORING CONTENTS OF {pfile}\n") + params = {} + elif params: + out.write(f"USING CUSTOMIZED PARAMETERS IN {pfile}\n") + return params + + +def _build_constraint_matrix(area, vardf, out, target_dir=None): + """ + Build constraint matrix B and target array from area targets CSV. + + Returns (B_csc, targets, target_labels, pop_share) where: + - B_csc is a sparse matrix (n_targets x n_records) + - targets is 1-D array of target values + - target_labels is a list of descriptive strings + - pop_share is the area's population share of national population + """ + if target_dir is None: + target_dir = STATE_TARGET_DIR + national_population = (vardf.s006 * vardf.XTOT).sum() + numobs = len(vardf) + targets_file = target_dir / f"{area}_targets.csv" + tdf = pd.read_csv(targets_file, comment="#") + + targets_list = [] + labels_list = [] + columns = [] + pop_share = None + + for row_idx, row in enumerate(tdf.itertuples(index=False)): + line = f"{area}:target{row_idx + 1}" + + # extract target value + target_val = row.target + targets_list.append(target_val) + + # build label + label = ( + f"{row.varname}" + f"/cnt={row.count}" + f"/scope={row.scope}" + f"/agi=[{row.agilo},{row.agihi})" + f"/fs={row.fstatus}" + ) + labels_list.append(label) + + # first row must be XTOT population target + if row_idx == 0: + assert row.varname == "XTOT", f"{line}: first target must be XTOT" + assert row.count == 0 and row.scope == 0 + assert row.agilo < -8e99 and row.agihi > 8e99 + assert row.fstatus == 0 + pop_share = row.target / national_population + out.write( + f"pop_share = {row.target:.0f}" + f" / {national_population:.0f}" + f" = {pop_share:.6f}\n" + ) + + # construct variable array for this target + assert 0 <= row.count <= 4, f"count {row.count} not in [0,4] on {line}" + if row.count == 0: + var_array = vardf[row.varname].astype(float).values + elif row.count == 1: + var_array = np.ones(numobs, dtype=float) + elif row.count == 2: + var_array = (vardf[row.varname] != 0).astype(float).values + elif row.count == 3: + var_array = (vardf[row.varname] > 0).astype(float).values + else: + var_array = (vardf[row.varname] < 0).astype(float).values + + # construct mask + mask = np.ones(numobs, dtype=float) + assert 0 <= row.scope <= 2, f"scope {row.scope} not in [0,2] on {line}" + if row.scope == 1: + mask *= (vardf.data_source == 1).values.astype(float) + elif row.scope == 2: + mask *= (vardf.data_source == 0).values.astype(float) + + in_agi_bin = (vardf.c00100 >= row.agilo) & (vardf.c00100 < row.agihi) + mask *= in_agi_bin.values.astype(float) + + assert ( + 0 <= row.fstatus <= 5 + ), f"fstatus {row.fstatus} not in [0,5] on {line}" + if row.fstatus > 0: + mask *= (vardf.MARS == row.fstatus).values.astype(float) + + # A[j,i] = mask * var_array (data values for this constraint) + columns.append(mask * var_array) + + assert pop_share is not None, "XTOT target not found" + + # A matrix: n_records x n_targets (each column is a constraint) + A_dense = np.column_stack(columns) + + # B = diag(w0) @ A where w0 = pop_share * s006 + w0 = pop_share * vardf.s006.values + B_dense = w0[:, np.newaxis] * A_dense + + # convert to sparse (n_targets x n_records) for Clarabel + B_csc = csc_matrix(B_dense.T) + + targets_arr = np.array(targets_list) + return B_csc, targets_arr, labels_list, pop_share + + +def _drop_impossible_targets(B_csc, targets, labels, out): + """Drop targets where all constraint matrix values are zero.""" + col_sums = np.abs(np.asarray(B_csc.sum(axis=1))).ravel() + all_zero = col_sums == 0 + if all_zero.any(): + n_drop = int(all_zero.sum()) + for i in np.where(all_zero)[0]: + out.write( + f"DROPPING impossible target" f" (all zeros): {labels[i]}\n" + ) + keep = ~all_zero + B_dense = B_csc.toarray() + B_csc = csc_matrix(B_dense[keep, :]) + targets = targets[keep] + labels = [lab for lab, k in zip(labels, keep) if k] + out.write( + f"Dropped {n_drop} impossible targets," + f" {len(targets)} remaining\n" + ) + return B_csc, targets, labels + + +def _solve_area_qp( # pylint: disable=unused-argument + B_csc, + targets, + labels, + n_records, + constraint_tol=AREA_CONSTRAINT_TOL, + slack_penalty=AREA_SLACK_PENALTY, + max_iter=AREA_MAX_ITER, + multiplier_min=AREA_MULTIPLIER_MIN, + multiplier_max=AREA_MULTIPLIER_MAX, + out=None, +): + """ + Solve the area reweighting QP using Clarabel. + + Returns (x_opt, s_lo, s_hi, info_dict). + """ + if out is None: + out = sys.stdout + + m = len(targets) + n_total = n_records + 2 * m # x + s_lo + s_hi + + # constraint bounds: t*(1-eps) <= Bx <= t*(1+eps) + abs_targets = np.abs(targets) + tol_band = abs_targets * constraint_tol + cl = targets - tol_band + cu = targets + tol_band + + # diagonal Hessian: 2 for x, 2*M for slacks + hess_diag = np.empty(n_total) + hess_diag[:n_records] = 2.0 + hess_diag[n_records:] = 2.0 * slack_penalty + P = spdiags(hess_diag, format="csc") + + # linear term: -2 for x (from (x-1)^2 = x^2 - 2x + 1) + q = np.zeros(n_total) + q[:n_records] = -2.0 + + # extended constraint matrix: [B | I_m | -I_m] + I_m = speye(m, format="csc") + A_full = hstack([B_csc, I_m, -I_m], format="csc") + + # constraint scaling for numerical conditioning + target_scale = np.maximum(np.abs(targets), 1.0) + D_inv = spdiags(1.0 / target_scale) + A_scaled = (D_inv @ A_full).tocsc() + cl_scaled = cl / target_scale + cu_scaled = cu / target_scale + + # variable bounds + var_lb = np.empty(n_total) + var_ub = np.empty(n_total) + var_lb[:n_records] = multiplier_min + var_ub[:n_records] = multiplier_max + var_lb[n_records:] = 0.0 + var_ub[n_records:] = 1e20 + + # Clarabel form: Ax + s = b, s in NonnegativeCone + I_n = speye(n_total, format="csc") + A_clar = vstack( + [A_scaled, -A_scaled, I_n, -I_n], + format="csc", + ) + b_clar = np.concatenate([cu_scaled, -cl_scaled, var_ub, -var_lb]) + + m_constraints = len(b_clar) + # pylint: disable=no-member + cones = [clarabel.NonnegativeConeT(m_constraints)] + + settings = clarabel.DefaultSettings() + # pylint: enable=no-member + settings.verbose = False + settings.max_iter = max_iter + settings.tol_gap_abs = 1e-7 + settings.tol_gap_rel = 1e-7 + settings.tol_feas = 1e-7 + + # solve + out.write("STARTING CLARABEL SOLVER...\n") + t_start = time.time() + solver = clarabel.DefaultSolver( # pylint: disable=no-member + P, q, A_clar, b_clar, cones, settings + ) + result = solver.solve() + elapsed = time.time() - t_start + + status_str = str(result.status) + out.write( + f"Solver status: {status_str}\n" + f"Iterations: {result.iterations}\n" + f"Solve time: {elapsed:.2f}s\n" + ) + + # extract solution + y_opt = np.array(result.x) + x_opt = y_opt[:n_records] + s_lo = y_opt[n_records : n_records + m] + s_hi = y_opt[n_records + m :] + + x_opt = np.clip(x_opt, multiplier_min, multiplier_max) + + info = { + "status": status_str, + "iterations": result.iterations, + "solve_time": elapsed, + "clarabel_solve_time": result.solve_time, + } + + return x_opt, s_lo, s_hi, info + + +def _print_target_diagnostics( + x_opt, B_csc, targets, labels, constraint_tol, out +): + """Print target accuracy diagnostics.""" + achieved = np.asarray(B_csc @ x_opt).ravel() + abs_errors = np.abs(achieved - targets) + rel_errors = abs_errors / np.maximum(np.abs(targets), 1.0) + + out.write(f"TARGET ACCURACY ({len(targets)} targets):\n") + out.write(f" mean |relative error|: {rel_errors.mean():.6f}\n") + out.write(f" max |relative error|: {rel_errors.max():.6f}\n") + + eps = 1e-9 + n_violated = int((rel_errors > constraint_tol + eps).sum()) + n_hit = len(targets) - n_violated + out.write( + f" targets hit: {n_hit}/{len(targets)}" + f" (tolerance: +/-{constraint_tol * 100:.1f}% + eps)\n" + ) + if n_violated > 0: + out.write(f" VIOLATED: {n_violated} targets\n") + worst_idx = np.argsort(rel_errors)[::-1] + for idx in worst_idx[: min(10, n_violated)]: + out.write( + f" {rel_errors[idx] * 100:7.3f}%" + f" | target={targets[idx]:15.0f}" + f" | achieved={achieved[idx]:15.0f}" + f" | {labels[idx]}\n" + ) + return n_violated + + +def _print_multiplier_diagnostics(x_opt, out): + """Print weight multiplier distribution diagnostics.""" + out.write("MULTIPLIER DISTRIBUTION:\n") + out.write( + f" min={x_opt.min():.6f}," + f" p5={np.percentile(x_opt, 5):.6f}," + f" median={np.median(x_opt):.6f}," + f" p95={np.percentile(x_opt, 95):.6f}," + f" max={x_opt.max():.6f}\n" + ) + out.write( + f" RMSE from 1.0:" f" {np.sqrt(np.mean((x_opt - 1.0) ** 2)):.6f}\n" + ) + + # distribution bins + bins = [ + 0.0, + 1e-6, + 0.1, + 0.5, + 0.8, + 0.9, + 0.95, + 1.0, + 1.05, + 1.1, + 1.2, + 1.5, + 2.0, + 5.0, + 10.0, + 100.0, + np.inf, + ] + tot = len(x_opt) + out.write(f" distribution (n={tot}):\n") + for i in range(len(bins) - 1): + count = int(((x_opt >= bins[i]) & (x_opt < bins[i + 1])).sum()) + if count > 0: + out.write( + f" [{bins[i]:10.4f}," + f" {bins[i + 1]:10.4f}):" + f" {count:7d}" + f" ({count / tot:7.2%})\n" + ) + + +def _print_slack_diagnostics(s_lo, s_hi, targets, labels, out): + """Print elastic slack diagnostics.""" + total_slack = s_lo + s_hi + n_active = int(np.sum(total_slack > 1e-6)) + if n_active > 0: + out.write( + f"ELASTIC SLACK active on" + f" {n_active}/{len(targets)} constraints:\n" + ) + slack_idx = np.where(total_slack > 1e-6)[0] + for idx in slack_idx[np.argsort(total_slack[slack_idx])[::-1]][:20]: + out.write( + f" slack={total_slack[idx]:12.2f}" + f" | target={targets[idx]:15.0f}" + f" | {labels[idx]}\n" + ) + else: + out.write("ALL CONSTRAINTS SATISFIED WITHOUT SLACK\n") + + +def create_area_weights_file_clarabel( + area, + write_log=True, + write_file=True, + target_dir=None, + weight_dir=None, +): + """ + Create area weights file using Clarabel constrained QP solver. + + Drop-in replacement for create_area_weights_file() in + create_area_weights.py. + + Returns 0 on success. + """ + if target_dir is None: + target_dir = STATE_TARGET_DIR + if weight_dir is None: + weight_dir = STATE_WEIGHT_DIR + # ensure output directory exists + weight_dir.mkdir(parents=True, exist_ok=True) + + # remove any existing output files + awpath = weight_dir / f"{area}_tmd_weights.csv.gz" + awpath.unlink(missing_ok=True) + logpath = weight_dir / f"{area}.log" + logpath.unlink(missing_ok=True) + + # set up output + if write_log: + out = open( # pylint: disable=consider-using-with + logpath, "w", encoding="utf-8" + ) + else: + out = sys.stdout + + if write_file: + out.write( + f"CREATING WEIGHTS FILE FOR AREA {area}" " (Clarabel solver) ...\n" + ) + else: + out.write( + f"DOING JUST CALCS FOR AREA {area}" " (Clarabel solver) ...\n" + ) + + # read optional parameters + params = _read_params(area, out, target_dir=target_dir) + constraint_tol = params.get( + "constraint_tol", + params.get("target_ratio_tolerance", AREA_CONSTRAINT_TOL), + ) + slack_penalty = params.get("slack_penalty", AREA_SLACK_PENALTY) + max_iter = params.get("max_iter", AREA_MAX_ITER) + multiplier_min = params.get("multiplier_min", AREA_MULTIPLIER_MIN) + multiplier_max = params.get("multiplier_max", AREA_MULTIPLIER_MAX) + + # load data and build constraint matrix + vdf = _load_taxcalc_data() + out.write(f"Loaded {len(vdf)} records\n") + out.write(f"National weight sum: {vdf.s006.sum():.0f}\n") + + B_csc, targets, labels, pop_share = _build_constraint_matrix( + area, vdf, out, target_dir=target_dir + ) + out.write( + f"Built constraint matrix:" + f" {B_csc.shape[0]} targets x" + f" {B_csc.shape[1]} records\n" + ) + out.write(f"Constraint tolerance: +/-{constraint_tol * 100:.1f}%\n") + out.write(f"Multiplier bounds: [{multiplier_min}, {multiplier_max}]\n") + + # drop impossible targets + B_csc, targets, labels = _drop_impossible_targets( + B_csc, targets, labels, out + ) + + # check what x=1 (population-proportional) achieves + n_records = B_csc.shape[1] + x_ones = np.ones(n_records) + achieved_x1 = np.asarray(B_csc @ x_ones).ravel() + rel_err_x1 = np.abs(achieved_x1 - targets) / np.maximum( + np.abs(targets), 1.0 + ) + out.write("BEFORE OPTIMIZATION (x=1, population-proportional):\n") + out.write(f" mean |relative error|: {rel_err_x1.mean():.6f}\n") + out.write(f" max |relative error|: {rel_err_x1.max():.6f}\n") + + # solve QP + x_opt, s_lo, s_hi, _info = _solve_area_qp( + B_csc, + targets, + labels, + n_records, + constraint_tol=constraint_tol, + slack_penalty=slack_penalty, + max_iter=max_iter, + multiplier_min=multiplier_min, + multiplier_max=multiplier_max, + out=out, + ) + + # diagnostics + _n_violated = _print_target_diagnostics( + x_opt, B_csc, targets, labels, constraint_tol, out + ) + _print_multiplier_diagnostics(x_opt, out) + _print_slack_diagnostics(s_lo, s_hi, targets, labels, out) + + if write_log: + out.close() + if not write_file: + return 0 + + # write area weights file with population extrapolation + w0 = pop_share * vdf.s006.values + wght_area = x_opt * w0 + + with open(POPFILE_PATH, "r", encoding="utf-8") as pf: + pop = yaml.safe_load(pf.read()) + + wdict = {f"WT{FIRST_YEAR}": wght_area} + cum_pop_growth = 1.0 + for year in range(FIRST_YEAR + 1, LAST_YEAR + 1): + annual_pop_growth = pop[year] / pop[year - 1] + cum_pop_growth *= annual_pop_growth + wdict[f"WT{year}"] = wght_area * cum_pop_growth + + wdf = pd.DataFrame.from_dict(wdict) + wdf.to_csv( + awpath, + index=False, + float_format="%.5f", + compression="gzip", + ) + + return 0 diff --git a/tmd/areas/quality_report.py b/tmd/areas/quality_report.py new file mode 100644 index 00000000..3ac2064b --- /dev/null +++ b/tmd/areas/quality_report.py @@ -0,0 +1,589 @@ +# pylint: disable=import-outside-toplevel,inconsistent-quotes +""" +Cross-state quality summary report. + +Parses solver logs for all states and produces a summary showing: + - Solve status and timing + - Target accuracy (hit rate, mean/max error) + - Weight distortion (RMSE, percentiles) + - Violated targets by variable + - Weight exhaustion and cross-state aggregation diagnostics + +Usage: + python -m tmd.areas.quality_report + python -m tmd.areas.quality_report --scope CA,WY +""" + +import argparse +import re +from pathlib import Path + +import numpy as np +import pandas as pd + +from tmd.areas.create_area_weights_clarabel import ( + AREA_CONSTRAINT_TOL, + STATE_WEIGHT_DIR, +) +from tmd.areas.prepare.constants import ALL_STATES +from tmd.imputation_assumptions import TAXYEAR + +_WT_COL = f"WT{TAXYEAR}" + +# Decode raw constraint descriptions into human-readable labels +_CNT_LABELS = {0: "amt", 1: "returns", 2: "nz-count"} +_FS_LABELS = {0: "all", 1: "single", 2: "MFJ", 4: "HoH"} + + +def _humanize_desc(desc: str) -> str: + """ + Turn 'c00100/cnt=1/scope=1/agi=[500000.0,1000000.0)/fs=4' + into 'c00100 returns HoH $500K-$1M'. + """ + parts = desc.split("/") + varname = parts[0] + attrs = {} + for p in parts[1:]: + if "=" in p: + k, v = p.split("=", 1) + attrs[k] = v + + cnt = int(attrs.get("cnt", -1)) + fs = int(attrs.get("fs", 0)) + agi_raw = attrs.get("agi", "") + + cnt_label = _CNT_LABELS.get(cnt, f"cnt{cnt}") + fs_label = _FS_LABELS.get(fs, f"fs{fs}") + + # Parse AGI range like [500000.0,1000000.0) + agi_label = "" + m = re.match(r"\[([^,]+),([^)]+)\)", agi_raw) + if m: + lo_s, hi_s = m.group(1), m.group(2) + lo = float(lo_s) + hi = float(hi_s) + if lo < -1e10: + agi_label = f"<${hi / 1000:.0f}K" + elif hi > 1e10: + agi_label = f"${lo / 1000:.0f}K+" + else: + agi_label = f"${lo / 1000:.0f}K-${hi / 1000:.0f}K" + + pieces = [varname, cnt_label] + if fs != 0: + pieces.append(fs_label) + if agi_label: + pieces.append(agi_label) + return " ".join(pieces) + + +def parse_log(logpath: Path) -> dict: + """Parse a single area solver log file into a summary dict.""" + if not logpath.exists(): + return {"status": "NO LOG"} + log = logpath.read_text() + + result = {"status": "UNKNOWN"} + + # Solve status + m = re.search(r"Solver status: (\S+)", log) + if m: + result["status"] = m.group(1) + if "PrimalInfeasible" in log or "FAILED" in log: + result["status"] = "FAILED" + + # Solve time + m = re.search(r"Solve time: ([\d.]+)s", log) + if m: + result["solve_time"] = float(m.group(1)) + + # Target accuracy + m = re.search(r"mean \|relative error\|: ([\d.]+)", log) + if m: + result["mean_err"] = float(m.group(1)) + m = re.search(r"max \|relative error\|: ([\d.]+)", log) + if m: + result["max_err"] = float(m.group(1)) + m = re.search(r"targets hit: (\d+)/(\d+)", log) + if m: + result["targets_hit"] = int(m.group(1)) + result["targets_total"] = int(m.group(2)) + m_viol = re.search(r"VIOLATED: (\d+) targets", log) + result["n_violated"] = int(m_viol.group(1)) if m_viol else 0 + + # Weight distortion + m = re.search( + r"min=([\d.]+), p5=([\d.]+), median=([\d.]+), " + r"p95=([\d.]+), max=([\d.]+)", + log, + ) + if m: + result["w_min"] = float(m.group(1)) + result["w_p5"] = float(m.group(2)) + result["w_median"] = float(m.group(3)) + result["w_p95"] = float(m.group(4)) + result["w_max"] = float(m.group(5)) + m = re.search(r"RMSE from 1.0: ([\d.]+)", log) + if m: + result["w_rmse"] = float(m.group(1)) + + # Weight distribution histogram + dist_bins = {} + for line in log.splitlines(): + m_bin = re.match( + r"\s+\[\s*([\d.]+),\s*([\d.]+)\):\s+(\d+)\s+\(\s*([\d.]+)%\)", + line, + ) + if m_bin: + lo, hi = float(m_bin.group(1)), float(m_bin.group(2)) + cnt, pct = int(m_bin.group(3)), float(m_bin.group(4)) + dist_bins[(lo, hi)] = {"count": cnt, "pct": pct} + result["dist_bins"] = dist_bins + + m_n = re.search(r"distribution \(n=(\d+)\)", log) + if m_n: + result["n_records"] = int(m_n.group(1)) + + # Violated target details + violated = [] + in_violated = False + for line in log.splitlines(): + if "VIOLATED:" in line and "targets" in line: + in_violated = True + continue + if in_violated: + m_det = re.match( + r"\s+([\d.]+)%\s*\|\s*target=\s*([\d.]+)\s*\|" + r"\s*achieved=\s*([\d.]+)\s*\|" + r"\s*(\S+/cnt=\d+/scope=\d+/agi=.*?/fs=\d+)", + line, + ) + if m_det: + violated.append( + { + "pct_err": float(m_det.group(1)), + "target": float(m_det.group(2)), + "achieved": float(m_det.group(3)), + "desc": m_det.group(4), + } + ) + else: + in_violated = False + result["violated_details"] = violated + + return result + + +def generate_report(areas=None, weight_dir=None): + """Generate cross-state quality summary report.""" + if areas is None: + areas = ALL_STATES + if weight_dir is None: + weight_dir = STATE_WEIGHT_DIR + + rows = [] + for st in areas: + logpath = weight_dir / f"{st.lower()}.log" + info = parse_log(logpath) + info["state"] = st + rows.append(info) + + df = pd.DataFrame(rows) + + # Summary statistics + solved = df[df["status"].isin(["Solved", "AlmostSolved"])] + failed = df[df["status"] == "FAILED"] + n_states = len(df) + n_solved = len(solved) + n_failed = len(failed) + n_violated_states = (solved["n_violated"] > 0).sum() + total_violated = solved["n_violated"].sum() + + tol_pct = AREA_CONSTRAINT_TOL * 100 + + lines = [] + lines.append("=" * 80) + lines.append("CROSS-STATE QUALITY SUMMARY REPORT") + lines.append("=" * 80) + lines.append("") + + # Overall + lines.append(f"States: {n_states}") + lines.append(f"Solved: {n_solved}") + lines.append(f"Failed: {n_failed}") + if n_failed > 0: + lines.append(f" Failed: {', '.join(failed['state'].tolist())}") + lines.append( + f"States with violated targets: {n_violated_states}/{n_solved}" + ) + if not solved.empty and "targets_total" in solved.columns: + tpt = int(solved["targets_total"].iloc[0]) + tpt_sum = int(solved["targets_total"].sum()) + else: + tpt, tpt_sum = "?", "?" + lines.append(f"Total targets: {n_solved} states \u00d7 {tpt} = {tpt_sum}") + lines.append(f"Total violated targets: {int(total_violated)}") + lines.append("") + + # Target accuracy + if not solved.empty and "mean_err" in solved.columns: + lines.append("TARGET ACCURACY:") + lines.append( + f" Per-state mean error: " + f"avg across states={solved['mean_err'].mean():.4f}, " + f"worst state={solved['mean_err'].max():.4f}" + ) + lines.append( + f" Per-state max error: " + f"avg across states={solved['max_err'].mean():.4f}, " + f"worst state={solved['max_err'].max():.4f}" + ) + if "targets_hit" in solved.columns: + total_t = solved["targets_total"].iloc[0] + hit_pcts = solved["targets_hit"] / solved["targets_total"] * 100 + lines.append( + f" Hit rate: " + f"avg={hit_pcts.mean():.1f}%, " + f"min={hit_pcts.min():.1f}% " + f"(out of {total_t} targets, " + f"tolerance: +/-{tol_pct:.1f}% + eps)" + ) + lines.append("") + + # Weight distortion + if not solved.empty and "w_rmse" in solved.columns: + lines.append("WEIGHT DISTORTION (multiplier from 1.0):") + lines.append( + f" RMSE: avg={solved['w_rmse'].mean():.3f}, " + f"max={solved['w_rmse'].max():.3f}" + ) + lines.append( + f" Min: avg={solved['w_min'].mean():.3f}, " + f"min={solved['w_min'].min():.3f}" + ) + lines.append( + f" P05: avg={solved['w_p5'].mean():.3f}, " + f"min={solved['w_p5'].min():.3f}" + ) + lines.append( + f" Median: avg={solved['w_median'].mean():.3f}, " + f"range=[{solved['w_median'].min():.3f}, " + f"{solved['w_median'].max():.3f}]" + ) + lines.append( + f" P95: avg={solved['w_p95'].mean():.3f}, " + f"max={solved['w_p95'].max():.3f}" + ) + lines.append( + f" Max: avg={solved['w_max'].mean():.1f}, " + f"max={solved['w_max'].max():.1f}" + ) + lines.append("") + + # Near-zero weight summary + if not solved.empty: + zero_pcts = [] + lt01_pcts = [] + for _, row in solved.iterrows(): + dist = row.get("dist_bins", {}) + n_rec = row.get("n_records", 0) + if not dist or n_rec == 0: + continue + n_zero = dist.get((0.0, 0.0), {}).get("count", 0) + n_lt01 = n_zero + dist.get((0.0, 0.1), {}).get("count", 0) + zero_pcts.append(100 * n_zero / n_rec) + lt01_pcts.append(100 * n_lt01 / n_rec) + if zero_pcts: + lines.append("NEAR-ZERO WEIGHT MULTIPLIERS (% of records):") + lines.append( + f" Exact zero (x=0): " + f"avg={np.mean(zero_pcts):.1f}%, " + f"max={np.max(zero_pcts):.1f}%" + ) + lines.append( + f" Below 0.1 (x<0.1): " + f"avg={np.mean(lt01_pcts):.1f}%, " + f"max={np.max(lt01_pcts):.1f}%" + ) + lines.append("") + + # Per-state table + lines.append("PER-STATE DETAIL:") + lines.append( + " Err cols = |relative error| (fraction); " + "weight cols = multiplier on national weight (1.0 = unchanged)" + ) + header = ( + f"{'St':<4} {'Status':<14} {'Hit':>5} {'Tot':>5} " + f"{'Viol':>5} {'MeanErr':>8} {'MaxErr':>8} " + f"{'wRMSE':>7} {'wP05':>7} {'wMed':>7} " + f"{'wP95':>7} {'wMax':>8} {'%zero':>6}" + ) + lines.append(header) + lines.append("-" * len(header)) + for _, row in df.iterrows(): + hit = int(row.get("targets_hit", 0)) + tot = int(row.get("targets_total", 0)) + viol = int(row.get("n_violated", 0)) + me = row.get("mean_err", 0) + mx = row.get("max_err", 0) + rmse = row.get("w_rmse", 0) + p5 = row.get("w_p5", 0) + med = row.get("w_median", 0) + p95 = row.get("w_p95", 0) + wmax = row.get("w_max", 0) + dist = row.get("dist_bins", {}) + n_rec = row.get("n_records", 0) + n_zero = 0 + if isinstance(dist, dict): + n_zero = dist.get((0.0, 0.0), {}).get("count", 0) + pct_zero = 100 * n_zero / n_rec if n_rec > 0 else 0 + lines.append( + f"{row['state']:<4} {row['status']:<14} {hit:>5} {tot:>5} " + f"{viol:>5} {me:>8.4f} {mx:>8.4f} " + f"{rmse:>7.3f} {p5:>7.3f} {med:>7.3f} " + f"{p95:>7.3f} {wmax:>8.1f} {pct_zero:>5.1f}%" + ) + lines.append("") + + # Violated targets by variable + all_violated = [] + for _, row in df.iterrows(): + for v in row.get("violated_details", []): + desc = v["desc"] + varname = desc.split("/")[0] + cnt_m = re.search(r"cnt=(\d+)", desc) + cnt_type = int(cnt_m.group(1)) if cnt_m else -1 + abs_miss = abs(v["achieved"] - v["target"]) + all_violated.append( + { + "state": row["state"], + "varname": varname, + "cnt_type": cnt_type, + "pct_err": v["pct_err"], + "target": v["target"], + "achieved": v["achieved"], + "abs_miss": abs_miss, + "desc": desc, + } + ) + if all_violated: + vdf = pd.DataFrame(all_violated) + var_counts = vdf["varname"].value_counts() + lines.append("VIOLATIONS BY VARIABLE:") + for var, cnt in var_counts.items(): + states_with = sorted(vdf[vdf["varname"] == var]["state"].unique()) + lines.append( + f" {var}: {cnt} violations across " + f"{len(states_with)} states" + ) + lines.append("") + + state_counts = vdf["state"].value_counts().head(10) + lines.append("STATES WITH MOST VIOLATIONS:") + for st, cnt in state_counts.items(): + lines.append(f" {st}: {cnt} violated") + lines.append("") + + amt_viol = vdf[vdf["cnt_type"] == 0].sort_values( + ["pct_err", "abs_miss"], ascending=[False, False] + ) + lines.append("WORST 5 AMOUNT VIOLATIONS (by % error):") + if amt_viol.empty: + lines.append(" (none — all amount targets met)") + else: + for _, r in amt_viol.head(5).iterrows(): + lines.append( + f" {r['state']:<4} {r['pct_err']:.3f}% " + f"target=${r['target']:>15,.0f} " + f"achieved=${r['achieved']:>15,.0f} " + f"miss=${r['abs_miss']:>12,.0f} " + f"{_humanize_desc(r['desc'])}" + ) + lines.append("") + + cnt_viol = vdf[vdf["cnt_type"].isin([1, 2])].sort_values( + ["pct_err", "abs_miss"], ascending=[False, False] + ) + lines.append("WORST 5 COUNT VIOLATIONS (by % error):") + if cnt_viol.empty: + lines.append(" (none — all count targets met)") + else: + for _, r in cnt_viol.head(5).iterrows(): + lines.append( + f" {r['state']:<4} {r['pct_err']:.3f}% " + f"target={r['target']:>12,.0f} " + f"achieved={r['achieved']:>12,.0f} " + f"miss={r['abs_miss']:>8,.0f} " + f"{_humanize_desc(r['desc'])}" + ) + lines.append("") + + # Weight diagnostics + lines.extend(_weight_diagnostics(areas, weight_dir)) + + report = "\n".join(lines) + return report + + +def _weight_diagnostics(areas, weight_dir=None): + """ + Combined weight diagnostics: exhaustion + national aggregation. + + Loads TMD data and state weight files once, reuses for both. + Only reads the specific columns needed (not the full allvars). + """ + if weight_dir is None: + weight_dir = STATE_WEIGHT_DIR + + from tmd.storage import STORAGE_FOLDER + + lines = [] + + # Load national data (only needed columns) + tmd_path = STORAGE_FOLDER / "output" / "tmd.csv.gz" + tmd_cols = ["s006", "e00200", "p22250", "p23250"] + tmd = pd.read_csv(tmd_path, usecols=tmd_cols) + tmd["capgains_net"] = tmd["p22250"] + tmd["p23250"] + n_records = len(tmd) + s006 = tmd["s006"].values + + agi_path = STORAGE_FOLDER / "output" / "cached_c00100.npy" + if agi_path.exists(): + tmd["c00100"] = np.load(agi_path) + + allvars_path = STORAGE_FOLDER / "output" / "cached_allvars.csv" + if allvars_path.exists(): + needed_tc = ["c18300", "iitax"] + allvars = pd.read_csv(allvars_path, usecols=needed_tc) + for col in needed_tc: + if col in allvars.columns: + tmd[col] = allvars[col].values + + # Load all state weights once + weight_sum = np.zeros(n_records) + state_weights = {} + for st in areas: + wpath = weight_dir / f"{st.lower()}_tmd_weights.csv.gz" + if not wpath.exists(): + continue + w = pd.read_csv(wpath, usecols=[_WT_COL])[_WT_COL].values + weight_sum += w + state_weights[st] = w + + n_loaded = len(state_weights) + if n_loaded == 0: + return lines + + # Weight exhaustion + usage = weight_sum / s006 + + _sub = "sum of state weights / national weight" + lines.append(f"WEIGHT EXHAUSTION ({_sub}):") + lines.append( + f" A ratio of 1.0 means the record's national" + f" weight is fully allocated across" + f" {n_loaded} states." + ) + pcts = [0, 1, 5, 10, 25, 50, 75, 90, 95, 99, 100] + quantiles = np.percentile(usage, pcts) + parts = [] + for p, q in zip(pcts, quantiles): + label = {0: "min", 50: "median", 100: "max"}.get(p, f"p{p}") + parts.append(f"{label}={q:.4f}") + lines.append(" " + ", ".join(parts)) + lines.append(f" Mean: {usage.mean():.4f}, Std: {usage.std():.4f}") + + n_over = int((usage > 1.10).sum()) + n_under = int((usage < 0.90).sum()) + lines.append( + f" Over-used (>1.10):" + f" {n_over} ({100 * n_over / n_records:.1f}%) " + f"Under-used (<0.90):" + f" {n_under} ({100 * n_under / n_records:.1f}%)" + ) + lines.append("") + + # Cross-state aggregation vs national totals + check_vars = [ + ("Returns (s006)", "s006", True), + ("AGI (c00100)", "c00100", False), + ("Wages (e00200)", "e00200", False), + ("Capital gains (capgains_net)", "capgains_net", False), + ("SALT ded (c18300)", "c18300", False), + ("Income tax (iitax)", "iitax", False), + ] + + national = {} + for label, var, is_count in check_vars: + if var not in tmd.columns: + national[var] = None + continue + if var == "s006": + national[var] = float(s006.sum()) + else: + national[var] = float((s006 * tmd[var].values).sum()) + + state_sums = {var: 0.0 for _, var, _ in check_vars} + for _st, w in state_weights.items(): + for _label, var, _is_count in check_vars: + if var not in tmd.columns: + continue + if var == "s006": + state_sums[var] += float(w.sum()) + else: + state_sums[var] += float((w * tmd[var].values).sum()) + + lines.append( + f"CROSS-STATE AGGREGATION vs NATIONAL TOTALS" + f" for SELECTED VARIABLES ({n_loaded} states):" + ) + lines.append( + f" {'Variable':<30} {'National':>16}" + f" {'Sum-of-States':>16} {'Diff%':>8}" + ) + lines.append(" " + "-" * 72) + for label, var, is_count in check_vars: + nat = national[var] + sos = state_sums[var] + if nat is None or nat == 0: + continue + diff_pct = (sos / nat - 1) * 100 + if is_count: + lines.append( + f" {label:<30} {nat:>16,.0f}" + f" {sos:>16,.0f} {diff_pct:>+7.2f}%" + ) + else: + lines.append( + f" {label:<30}" + f" ${nat / 1e9:>14.1f}B" + f" ${sos / 1e9:>14.1f}B" + f" {diff_pct:>+7.2f}%" + ) + lines.append("") + + return lines + + +def main(): + parser = argparse.ArgumentParser( + description="Cross-state quality summary report", + ) + parser.add_argument( + "--scope", + default=None, + help="Comma-separated state codes (default: all states)", + ) + args = parser.parse_args() + + areas = None + if args.scope: + areas = [s.strip().upper() for s in args.scope.split(",")] + + report = generate_report(areas) + print(report) + + +if __name__ == "__main__": + main() diff --git a/tmd/areas/solve_weights.py b/tmd/areas/solve_weights.py new file mode 100644 index 00000000..9de27f87 --- /dev/null +++ b/tmd/areas/solve_weights.py @@ -0,0 +1,112 @@ +# pylint: disable=import-outside-toplevel +""" +Solve for state weights using Clarabel QP optimizer. + +Reads per-state target CSV files (produced by prepare_targets.py) +and runs the Clarabel constrained QP solver to find weight +multipliers that hit area-specific targets within tolerance. + +Usage: + # All states, 8 parallel workers: + python -m tmd.areas.solve_weights --scope states --workers 8 + + # Specific states: + python -m tmd.areas.solve_weights --scope MN,CA,TX --workers 4 + + # Force recompute even if up-to-date: + python -m tmd.areas.solve_weights --scope states --workers 8 --force + + # Single state, no parallelism: + python -m tmd.areas.solve_weights --scope MN +""" + +import argparse +import time + +from tmd.areas.create_area_weights_clarabel import ( + STATE_TARGET_DIR, + STATE_WEIGHT_DIR, +) + + +def solve_state_weights( + scope="states", + num_workers=1, + force=True, +): + """ + Run the Clarabel solver for the specified areas. + + Parameters + ---------- + scope : str + 'states' or comma-separated state codes (e.g., 'MN,CA,TX'). + num_workers : int + Number of parallel worker processes. + force : bool + Recompute all areas even if weight files are up-to-date. + """ + from tmd.areas.batch_weights import run_batch + + specific = _parse_scope(scope) + if specific: + area_filter = ",".join(a.lower() for a in specific) + else: + area_filter = "states" + + t0 = time.time() + print("Solving state weights...") + run_batch( + num_workers=num_workers, + area_filter=area_filter, + force=force, + target_dir=STATE_TARGET_DIR, + weight_dir=STATE_WEIGHT_DIR, + ) + elapsed = time.time() - t0 + print(f"Total solve time: {elapsed:.1f}s") + + +def _parse_scope(scope): + """Parse scope string into a list of state codes or None for all.""" + _EXCLUDE = {"US", "PR", "OA"} + scope_lower = scope.lower().strip() + if scope_lower in ("states", "all"): + return None + codes = [c.strip().upper() for c in scope.split(",") if c.strip()] + return [c for c in codes if len(c) == 2 and c not in _EXCLUDE] + + +def main(): + """CLI entry point.""" + parser = argparse.ArgumentParser( + description="Solve for state weights using Clarabel QP optimizer", + ) + parser.add_argument( + "--scope", + default="states", + help="'states' or comma-separated state codes (e.g., 'MN,CA,TX')", + ) + parser.add_argument( + "--workers", + type=int, + default=1, + help="Number of parallel solver workers (default: 1)", + ) + parser.add_argument( + "--force", + action="store_true", + default=True, + help="Recompute all areas even if up-to-date (default: True)", + ) + args = parser.parse_args() + + solve_state_weights( + scope=args.scope, + num_workers=args.workers, + force=args.force, + ) + + +if __name__ == "__main__": + main() From 0291e5fd8afa16e8ec0b9d1e1f13e5f6f5446846 Mon Sep 17 00:00:00 2001 From: Don Boyd Date: Thu, 19 Mar 2026 11:08:53 -0400 Subject: [PATCH 02/10] Show target count per area in batch solver progress message --- tmd/areas/batch_weights.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tmd/areas/batch_weights.py b/tmd/areas/batch_weights.py index 9732da3f..b023e881 100644 --- a/tmd/areas/batch_weights.py +++ b/tmd/areas/batch_weights.py @@ -259,8 +259,22 @@ def run_batch( print("All areas up-to-date. Use --force to recompute.") return + # Count targets from first area's CSV + first_tpath = target_dir / f"{areas[0]}_targets.csv" + n_targets = ( + sum( + 1 + for line in first_tpath.read_text().splitlines() + if line.strip() and not line.startswith("#") + ) + - 1 + ) # subtract header n = len(areas) - print(f"Processing {n} areas with {num_workers} workers...") + print( + f"Processing {n} areas" + f" (up to {n_targets} targets each)" + f" with {num_workers} workers..." + ) print( "(Areas shown in completion order, which varies with" " parallel workers.)" From bdc14d5993ea90ae29b65cb0216d05ce7f8f088c Mon Sep 17 00:00:00 2001 From: Don Boyd Date: Thu, 19 Mar 2026 12:22:49 -0400 Subject: [PATCH 03/10] Add exhaustion limiting and parameter sweep capability - Add weight_penalty parameter to QP solver for controlling multiplier-vs-constraint tradeoff - Add --max-exhaustion flag to solve_weights for iterative two-pass exhaustion limiting with per-record multiplier caps - Enhance quality report with exhaustion record profiles (top 5 most exhausted records with taxpayer characteristics) - Add sweep_params.py for grid search over multiplier_max and weight_penalty combinations Co-Authored-By: Claude Opus 4.6 (1M context) --- tmd/areas/batch_weights.py | 12 + tmd/areas/create_area_weights_clarabel.py | 16 +- tmd/areas/quality_report.py | 55 ++++- tmd/areas/solve_weights.py | 208 +++++++++++++++++- tmd/areas/sweep_params.py | 256 ++++++++++++++++++++++ 5 files changed, 529 insertions(+), 18 deletions(-) create mode 100644 tmd/areas/sweep_params.py diff --git a/tmd/areas/batch_weights.py b/tmd/areas/batch_weights.py index b023e881..aa3730b4 100644 --- a/tmd/areas/batch_weights.py +++ b/tmd/areas/batch_weights.py @@ -101,6 +101,18 @@ def _solve_one_area(area): ) n_records = B_csc.shape[1] + + # Check for per-record multiplier caps (from exhaustion limiting) + caps_path = wgt_dir / f"{area}_record_caps.npy" + if caps_path.exists(): + record_caps = np.load(caps_path) + multiplier_max = np.minimum(multiplier_max, record_caps) + n_capped = int((record_caps < AREA_MULTIPLIER_MAX).sum()) + out.write( + f"USING PER-RECORD MULTIPLIER CAPS" + f" ({n_capped} records capped)\n" + ) + x_opt, s_lo, s_hi, info = _solve_area_qp( B_csc, targets, diff --git a/tmd/areas/create_area_weights_clarabel.py b/tmd/areas/create_area_weights_clarabel.py index 40a1d4c5..15878fed 100644 --- a/tmd/areas/create_area_weights_clarabel.py +++ b/tmd/areas/create_area_weights_clarabel.py @@ -261,11 +261,19 @@ def _solve_area_qp( # pylint: disable=unused-argument max_iter=AREA_MAX_ITER, multiplier_min=AREA_MULTIPLIER_MIN, multiplier_max=AREA_MULTIPLIER_MAX, + weight_penalty=1.0, out=None, ): """ Solve the area reweighting QP using Clarabel. + Parameters + ---------- + weight_penalty : float + Penalty weight on (x-1)^2 relative to constraint + slack. Higher values keep multipliers closer to 1.0 + at the cost of more target violations. + Returns (x_opt, s_lo, s_hi, info_dict). """ if out is None: @@ -280,15 +288,15 @@ def _solve_area_qp( # pylint: disable=unused-argument cl = targets - tol_band cu = targets + tol_band - # diagonal Hessian: 2 for x, 2*M for slacks + # diagonal Hessian: 2*alpha for x, 2*M for slacks hess_diag = np.empty(n_total) - hess_diag[:n_records] = 2.0 + hess_diag[:n_records] = 2.0 * weight_penalty hess_diag[n_records:] = 2.0 * slack_penalty P = spdiags(hess_diag, format="csc") - # linear term: -2 for x (from (x-1)^2 = x^2 - 2x + 1) + # linear term: -2*alpha for x q = np.zeros(n_total) - q[:n_records] = -2.0 + q[:n_records] = -2.0 * weight_penalty # extended constraint matrix: [B | I_m | -I_m] I_m = speye(m, format="csc") diff --git a/tmd/areas/quality_report.py b/tmd/areas/quality_report.py index 3ac2064b..3d2c2ae6 100644 --- a/tmd/areas/quality_report.py +++ b/tmd/areas/quality_report.py @@ -440,9 +440,20 @@ def _weight_diagnostics(areas, weight_dir=None): lines = [] - # Load national data (only needed columns) + # Load national data tmd_path = STORAGE_FOLDER / "output" / "tmd.csv.gz" - tmd_cols = ["s006", "e00200", "p22250", "p23250"] + tmd_cols = [ + "s006", + "MARS", + "XTOT", + "data_source", + "e00200", + "e00300", + "e00600", + "e26270", + "p22250", + "p23250", + ] tmd = pd.read_csv(tmd_path, usecols=tmd_cols) tmd["capgains_net"] = tmd["p22250"] + tmd["p23250"] n_records = len(tmd) @@ -502,6 +513,46 @@ def _weight_diagnostics(areas, weight_dir=None): f"Under-used (<0.90):" f" {n_under} ({100 * n_under / n_records:.1f}%)" ) + for thresh in [2, 5, 10]: + ct = int((usage > thresh).sum()) + if ct > 0: + lines.append(f" Exhaustion > {thresh}x: {ct} records") + lines.append("") + + # Most exhausted records — profile and top states + _mars = {1: "Single", 2: "MFJ", 3: "MFS", 4: "HoH", 5: "Wid"} + _ds = {0: "CPS", 1: "PUF"} + top_idx = np.argsort(usage)[::-1][:5] + lines.append("MOST EXHAUSTED RECORDS (top 5):") + for rank, idx in enumerate(top_idx, 1): + r = tmd.iloc[idx] + exh = usage[idx] + mars = _mars.get(int(r.MARS), f"fs{int(r.MARS)}") + ds = _ds.get(int(r.data_source), "?") + agi = tmd["c00100"].iloc[idx] if "c00100" in tmd else 0 + # Top 3 states by weight for this record + st_wts = [] + for st, w in state_weights.items(): + if w[idx] > 0: + st_wts.append((st, w[idx])) + st_wts.sort(key=lambda x: -x[1]) + top3 = ", ".join(f"{st}={wt:.1f}" for st, wt in st_wts[:3]) + lines.append( + f" {rank}. rec {idx}: exh={exh:.1f}x," + f" s006={r.s006:.1f}," + f" {ds} {mars}," + f" AGI=${agi:,.0f}" + ) + lines.append( + f" wages=${r.e00200:,.0f}," + f" int=${r.e00300:,.0f}," + f" div=${r.e00600:,.0f}," + f" ptshp=${r.e26270:,.0f}" + ) + lines.append( + f" top states: {top3}" + f" ({len(st_wts)} nonzero of {n_loaded})" + ) lines.append("") # Cross-state aggregation vs national totals diff --git a/tmd/areas/solve_weights.py b/tmd/areas/solve_weights.py index 9de27f87..17e0bb28 100644 --- a/tmd/areas/solve_weights.py +++ b/tmd/areas/solve_weights.py @@ -6,33 +6,44 @@ and runs the Clarabel constrained QP solver to find weight multipliers that hit area-specific targets within tolerance. +Optional exhaustion limiting (--max-exhaustion) runs a two-pass +solve: first unconstrained, then with per-record multiplier caps +to keep cross-state weight exhaustion within bounds. + Usage: # All states, 8 parallel workers: python -m tmd.areas.solve_weights --scope states --workers 8 + # With exhaustion cap of 5x: + python -m tmd.areas.solve_weights --scope states --workers 8 \ + --max-exhaustion 5 + # Specific states: python -m tmd.areas.solve_weights --scope MN,CA,TX --workers 4 - - # Force recompute even if up-to-date: - python -m tmd.areas.solve_weights --scope states --workers 8 --force - - # Single state, no parallelism: - python -m tmd.areas.solve_weights --scope MN """ import argparse import time +import numpy as np +import pandas as pd + from tmd.areas.create_area_weights_clarabel import ( + AREA_MULTIPLIER_MAX, STATE_TARGET_DIR, STATE_WEIGHT_DIR, ) +from tmd.imputation_assumptions import TAXYEAR + +_WT_COL = f"WT{TAXYEAR}" +_MAX_EXHAUST_ITERATIONS = 5 def solve_state_weights( scope="states", num_workers=1, force=True, + max_exhaustion=None, ): """ Run the Clarabel solver for the specified areas. @@ -40,11 +51,15 @@ def solve_state_weights( Parameters ---------- scope : str - 'states' or comma-separated state codes (e.g., 'MN,CA,TX'). + 'states' or comma-separated state codes. num_workers : int Number of parallel worker processes. force : bool Recompute all areas even if weight files are up-to-date. + max_exhaustion : float or None + If set, limit per-record cross-state weight exhaustion + to this multiple of the national weight. Runs iterative + two-pass solve. """ from tmd.areas.batch_weights import run_batch @@ -55,7 +70,9 @@ def solve_state_weights( area_filter = "states" t0 = time.time() - print("Solving state weights...") + + # --- Pass 1: unconstrained solve --- + print("Pass 1: solving state weights...") run_batch( num_workers=num_workers, area_filter=area_filter, @@ -63,12 +80,169 @@ def solve_state_weights( target_dir=STATE_TARGET_DIR, weight_dir=STATE_WEIGHT_DIR, ) + + if max_exhaustion is None: + elapsed = time.time() - t0 + print(f"Total solve time: {elapsed:.1f}s") + return + + # --- Exhaustion-limited iterative passes --- + for iteration in range(1, _MAX_EXHAUST_ITERATIONS + 1): + exhaustion, state_weights = _compute_exhaustion(STATE_WEIGHT_DIR) + n_over = int((exhaustion > max_exhaustion).sum()) + max_exh = exhaustion.max() + print( + f"\nExhaustion check (pass {iteration}):" + f" max={max_exh:.2f}," + f" {n_over} records > {max_exhaustion}x" + ) + if n_over == 0: + print("All records within exhaustion limit.") + break + + # Compute and write per-record caps + affected = _write_exhaustion_caps( + exhaustion, + state_weights, + max_exhaustion, + STATE_WEIGHT_DIR, + STATE_TARGET_DIR, + ) + if not affected: + break + + # Re-solve affected states + af = ",".join(affected) + print( + f"Pass {iteration + 1}: re-solving" + f" {len(affected)} states with caps..." + ) + run_batch( + num_workers=num_workers, + area_filter=af, + force=True, + target_dir=STATE_TARGET_DIR, + weight_dir=STATE_WEIGHT_DIR, + ) + + # Clean up cap files for this iteration + _cleanup_caps(STATE_WEIGHT_DIR, affected) + else: + exhaustion, _ = _compute_exhaustion(STATE_WEIGHT_DIR) + n_still = int((exhaustion > max_exhaustion).sum()) + if n_still > 0: + print( + f"Warning: {n_still} records still exceed" + f" {max_exhaustion}x after" + f" {_MAX_EXHAUST_ITERATIONS} iterations" + f" (max={exhaustion.max():.2f}x)" + ) + elapsed = time.time() - t0 print(f"Total solve time: {elapsed:.1f}s") +def _compute_exhaustion(weight_dir): + """ + Compute per-record exhaustion across all state weight files. + + Returns (exhaustion_array, state_weights_dict). + """ + from tmd.storage import STORAGE_FOLDER + + s006 = pd.read_csv( + STORAGE_FOLDER / "output" / "tmd.csv.gz", + usecols=["s006"], + )["s006"].values + n = len(s006) + + weight_sum = np.zeros(n) + state_weights = {} + for wpath in sorted(weight_dir.glob("*_tmd_weights.csv.gz")): + area = wpath.name.split("_")[0] + w = pd.read_csv(wpath, usecols=[_WT_COL])[_WT_COL].values + weight_sum += w + state_weights[area] = w + + exhaustion = weight_sum / s006 + return exhaustion, state_weights + + +def _write_exhaustion_caps( + exhaustion, + state_weights, + max_exhaustion, + weight_dir, + target_dir, +): + """ + Compute per-record multiplier caps and write cap files. + + For over-exhausted records, scale each state's multiplier + cap proportionally to current usage so total exhaustion + equals max_exhaustion. + + Returns list of affected area codes. + """ + from tmd.storage import STORAGE_FOLDER + + s006 = pd.read_csv( + STORAGE_FOLDER / "output" / "tmd.csv.gz", + usecols=["s006"], + )["s006"].values + nat_pop = pd.read_csv( + STORAGE_FOLDER / "output" / "tmd.csv.gz", + usecols=["s006", "XTOT"], + ) + nat_pop = (nat_pop["s006"] * nat_pop["XTOT"]).sum() + + over_mask = exhaustion > max_exhaustion + if not over_mask.any(): + return [] + + # Scale factor per record: how much to shrink + scale = np.ones_like(exhaustion) + scale[over_mask] = max_exhaustion / exhaustion[over_mask] + + affected_areas = set() + for area, weights in state_weights.items(): + # Get pop_share for this area + tpath = target_dir / f"{area}_targets.csv" + if not tpath.exists(): + continue + targets_df = pd.read_csv(tpath, comment="#", nrows=1) + xtot_target = targets_df.iloc[0]["target"] + pop_share = xtot_target / nat_pop + + # Current multiplier for each record + w0 = pop_share * s006 + with np.errstate(divide="ignore", invalid="ignore"): + current_x = np.where(w0 > 0, weights / w0, 0.0) + + # New cap = current_x * scale (only tighten, never loosen) + new_caps = current_x * scale + # Don't exceed global max + new_caps = np.minimum(new_caps, AREA_MULTIPLIER_MAX) + + # Only write if some records are actually capped + n_capped = int((new_caps < AREA_MULTIPLIER_MAX - 1e-6).sum()) + if n_capped > 0: + caps_path = weight_dir / f"{area}_record_caps.npy" + np.save(caps_path, new_caps) + affected_areas.add(area) + + return sorted(affected_areas) + + +def _cleanup_caps(weight_dir, areas): + """Remove cap files after a solve pass.""" + for area in areas: + caps_path = weight_dir / f"{area}_record_caps.npy" + caps_path.unlink(missing_ok=True) + + def _parse_scope(scope): - """Parse scope string into a list of state codes or None for all.""" + """Parse scope string into list of state codes or None.""" _EXCLUDE = {"US", "PR", "OA"} scope_lower = scope.lower().strip() if scope_lower in ("states", "all"): @@ -80,12 +254,12 @@ def _parse_scope(scope): def main(): """CLI entry point.""" parser = argparse.ArgumentParser( - description="Solve for state weights using Clarabel QP optimizer", + description=("Solve for state weights using Clarabel QP optimizer"), ) parser.add_argument( "--scope", default="states", - help="'states' or comma-separated state codes (e.g., 'MN,CA,TX')", + help=("'states' or comma-separated state codes" " (e.g., 'MN,CA,TX')"), ) parser.add_argument( "--workers", @@ -97,7 +271,16 @@ def main(): "--force", action="store_true", default=True, - help="Recompute all areas even if up-to-date (default: True)", + help="Recompute all areas even if up-to-date", + ) + parser.add_argument( + "--max-exhaustion", + type=float, + default=None, + help=( + "Max per-record cross-state weight exhaustion" + " (e.g., 5.0). Runs iterative solve to enforce." + ), ) args = parser.parse_args() @@ -105,6 +288,7 @@ def main(): scope=args.scope, num_workers=args.workers, force=args.force, + max_exhaustion=args.max_exhaustion, ) diff --git a/tmd/areas/sweep_params.py b/tmd/areas/sweep_params.py new file mode 100644 index 00000000..386d12d0 --- /dev/null +++ b/tmd/areas/sweep_params.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 +""" +Parameter sweep for state weight solver. + +Tests combinations of multiplier_max and weight_penalty, +solving all 51 states for each combo and reporting: + - Target violations + - Weight exhaustion + - Solve time + +Usage: + python -m tmd.areas.sweep_params +""" + +import io +import sys +import time +from concurrent.futures import ProcessPoolExecutor, as_completed + +import numpy as np +import yaml + +from tmd.areas.create_area_weights_clarabel import ( + AREA_CONSTRAINT_TOL, + AREA_MAX_ITER, + AREA_MULTIPLIER_MIN, + AREA_SLACK_PENALTY, + POPFILE_PATH, + STATE_TARGET_DIR, + _build_constraint_matrix, + _drop_impossible_targets, + _load_taxcalc_data, + _solve_area_qp, +) +from tmd.areas.prepare.constants import ALL_STATES +from tmd.imputation_assumptions import TAXYEAR + +_WT_COL = f"WT{TAXYEAR}" + +# --- Parameter grid --- +GRID_MULTIPLIER_MAX = [10, 15, 25, 100] +GRID_WEIGHT_PENALTY = [1.0, 10.0, 100.0] + +NUM_WORKERS = 8 + +# Module-level cache +_VDF = None +_POP = None + + +def _init(): + """Load data once.""" + global _VDF, _POP # pylint: disable=global-statement + if _VDF is not None: + return + _VDF = _load_taxcalc_data() + with open(POPFILE_PATH, "r", encoding="utf-8") as pf: + _POP = yaml.safe_load(pf.read()) + + +def _solve_one(args): + """Solve one state with given params. Returns stats dict.""" + area, mult_max, wt_pen = args + _init() + vdf = _VDF + out = io.StringIO() + + B_csc, targets, labels, pop_share = _build_constraint_matrix( + area, vdf, out, target_dir=STATE_TARGET_DIR + ) + B_csc, targets, labels = _drop_impossible_targets( + B_csc, targets, labels, out + ) + + n_records = B_csc.shape[1] + t0 = time.time() + x_opt, _s_lo, _s_hi, info = _solve_area_qp( + B_csc, + targets, + labels, + n_records, + constraint_tol=AREA_CONSTRAINT_TOL, + slack_penalty=AREA_SLACK_PENALTY, + max_iter=AREA_MAX_ITER, + multiplier_min=AREA_MULTIPLIER_MIN, + multiplier_max=mult_max, + weight_penalty=wt_pen, + out=out, + ) + elapsed = time.time() - t0 + + # Compute stats + achieved = np.asarray(B_csc @ x_opt).ravel() + rel_errors = np.abs(achieved - targets) / np.maximum(np.abs(targets), 1.0) + eps = 1e-9 + n_violated = int((rel_errors > AREA_CONSTRAINT_TOL + eps).sum()) + max_viol = float(rel_errors.max() * 100) + + # Weight stats + w0 = pop_share * vdf.s006.values + final_weights = x_opt * w0 + + return { + "area": area, + "mult_max": mult_max, + "wt_pen": wt_pen, + "n_targets": len(targets), + "n_violated": n_violated, + "max_viol_pct": max_viol, + "x_max": float(x_opt.max()), + "x_rmse": float(np.sqrt(np.mean((x_opt - 1.0) ** 2))), + "pct_zero": float((x_opt < 1e-6).mean() * 100), + "status": info["status"], + "elapsed": elapsed, + "pop_share": pop_share, + "final_weights": final_weights, + } + + +def run_sweep(): + """Run parameter sweep and print results.""" + print("Loading TMD data...") + _init() + + areas = [s.lower() for s in ALL_STATES] + combos = [ + (mm, wp) for mm in GRID_MULTIPLIER_MAX for wp in GRID_WEIGHT_PENALTY + ] + + s006 = _VDF.s006.values + n_records = len(s006) + + print( + f"Sweep: {len(combos)} parameter combos" + f" x {len(areas)} states" + f" = {len(combos) * len(areas)} solves" + ) + print( + f"Grid: mult_max={GRID_MULTIPLIER_MAX}," + f" weight_penalty={GRID_WEIGHT_PENALTY}" + ) + print(f"Workers: {NUM_WORKERS}") + print() + + results = [] + + for combo_idx, (mm, wp) in enumerate(combos): + label = f"mult_max={mm:>3}, wt_pen={wp:>5.1f}" + sys.stdout.write(f"[{combo_idx + 1}/{len(combos)}] {label}...") + sys.stdout.flush() + + t0 = time.time() + tasks = [(area, mm, wp) for area in areas] + + combo_results = [] + with ProcessPoolExecutor( + max_workers=NUM_WORKERS, + initializer=_init, + ) as executor: + futures = { + executor.submit(_solve_one, task): task for task in tasks + } + for future in as_completed(futures): + combo_results.append(future.result()) + + elapsed = time.time() - t0 + + # Aggregate stats + total_violated = sum(r["n_violated"] for r in combo_results) + max_viol = max(r["max_viol_pct"] for r in combo_results) + avg_rmse = np.mean([r["x_rmse"] for r in combo_results]) + avg_zero = np.mean([r["pct_zero"] for r in combo_results]) + + # Compute exhaustion + weight_sum = np.zeros(n_records) + for r in combo_results: + weight_sum += r["final_weights"] + exhaustion = weight_sum / s006 + max_exh = float(exhaustion.max()) + p99_exh = float(np.percentile(exhaustion, 99)) + n_over5 = int((exhaustion > 5).sum()) + n_over10 = int((exhaustion > 10).sum()) + + summary = { + "mult_max": mm, + "wt_pen": wp, + "total_violated": total_violated, + "max_viol_pct": max_viol, + "avg_rmse": avg_rmse, + "avg_pct_zero": avg_zero, + "max_exhaustion": max_exh, + "p99_exhaustion": p99_exh, + "n_over5x": n_over5, + "n_over10x": n_over10, + "elapsed": elapsed, + } + results.append(summary) + + sys.stdout.write( + f" {elapsed:.0f}s |" + f" viol={total_violated:>4}" + f" maxV={max_viol:.2f}%" + f" wRMSE={avg_rmse:.3f}" + f" %zero={avg_zero:.1f}%" + f" maxExh={max_exh:.1f}" + f" >5x={n_over5}" + f" >10x={n_over10}\n" + ) + + # Summary table + print("\n" + "=" * 100) + print("PARAMETER SWEEP RESULTS") + print("=" * 100) + cols = [ + "mMax", + "wPen", + "Viol", + "MaxV%", + "wRMSE", + "%zero", + "MaxExh", + "p99Exh", + "O5x", + "O10x", + "Time", + ] + widths = [5, 6, 5, 6, 6, 6, 7, 7, 5, 5, 6] + hdr = " ".join(f"{c:>{w}}" for c, w in zip(cols, widths)) + print(hdr) + print("-" * len(hdr)) + for r in results: + mm = r["mult_max"] + wp = r["wt_pen"] + tv = r["total_violated"] + mv = r["max_viol_pct"] + rm = r["avg_rmse"] + pz = r["avg_pct_zero"] + me = r["max_exhaustion"] + pe = r["p99_exhaustion"] + o5 = r["n_over5x"] + o10 = r["n_over10x"] + et = r["elapsed"] + print( + f"{mm:>5} {wp:>6.1f}" + f" {tv:>5} {mv:>6.2f}" + f" {rm:>6.3f} {pz:>6.1f}" + f" {me:>7.1f} {pe:>7.3f}" + f" {o5:>5} {o10:>5}" + f" {et:>5.0f}s" + ) + print() + print("Baseline: mult_max=100, wt_pen=1.0") + + +if __name__ == "__main__": + run_sweep() From 5fe41e79701d8307bbd554f33248194f3eeb80b1 Mon Sep 17 00:00:00 2001 From: Don Boyd Date: Thu, 19 Mar 2026 13:52:31 -0400 Subject: [PATCH 04/10] Set multiplier_max=25 and add area weighting lessons Lower AREA_MULTIPLIER_MAX from 100 to 25 based on parameter sweep: virtually identical target accuracy (35 vs 33 violations) but 34% lower max exhaustion (16.6x vs 25.2x). Single-pass, no complexity. Add AREA_WEIGHTING_LESSONS.md documenting parameter tuning findings, weight exhaustion mechanics, dual variable analysis, SALT targeting, and guidance for future Congressional district work. Update README.md with solver usage, quality report, and link to lessons document. Co-Authored-By: Claude Opus 4.6 (1M context) --- tmd/areas/AREA_WEIGHTING_LESSONS.md | 170 ++++++++++++++++++++++ tmd/areas/README.md | 47 +++++- tmd/areas/create_area_weights_clarabel.py | 2 +- 3 files changed, 217 insertions(+), 2 deletions(-) create mode 100644 tmd/areas/AREA_WEIGHTING_LESSONS.md diff --git a/tmd/areas/AREA_WEIGHTING_LESSONS.md b/tmd/areas/AREA_WEIGHTING_LESSONS.md new file mode 100644 index 00000000..ba4837fe --- /dev/null +++ b/tmd/areas/AREA_WEIGHTING_LESSONS.md @@ -0,0 +1,170 @@ +# Area Weighting: Lessons Learned + +Practical guidance from developing the state weight optimization +pipeline. Intended for future maintainers and anyone extending +this to new geographies (e.g., Congressional districts). + +## The optimization problem + +For each sub-national area, we find weight multipliers `x_i` for +every record such that weighted sums match area targets within a +tolerance band, while keeping multipliers close to 1.0 +(population-proportional). + +``` +minimize sum((x_i - 1)^2) [stay close to proportional] +subject to target*(1-tol) <= Bx <= target*(1+tol) [hit targets] + 0 <= x_i <= multiplier_max [bounds] +``` + +Solved independently per area using Clarabel (constrained QP with +elastic slack for infeasibility). + +## Key parameters and what they do + +| Parameter | Default | Effect | +|-----------|---------|--------| +| `AREA_CONSTRAINT_TOL` | 0.005 (0.5%) | Target tolerance band. Matches national reweighting. | +| `AREA_MULTIPLIER_MAX` | 25.0 | Per-record upper bound on weight multiplier. Most important lever for controlling exhaustion. | +| `AREA_SLACK_PENALTY` | 1e6 | Penalty on constraint slack. Very high = hard constraints. | +| `weight_penalty` | 1.0 | Penalty on `(x-1)^2`. Higher values keep multipliers closer to 1.0. | + +## Weight exhaustion + +**Definition**: For each record, exhaustion = (sum of all area +weights) / (national weight). A value of 1.0 means the record's +national weight is fully allocated across areas. Values above 1.0 +mean the record is "oversubscribed" — used more in total than its +national weight warrants. + +**Why it matters**: High exhaustion means a small number of records +are doing heavy lifting across many states, creating fragile +solutions where one record's characteristics drive multiple states' +estimates. + +**What drives it**: Rare high-income PUF records with small national +weights (s006 ~ 10-50) get pulled by many states to hit their +high-AGI-bin targets. These are typically wealthy MFJ households +with large investment income (interest, dividends, capital gains, +partnership income). + +### Exhaustion statistics (2022 SOI, 178 targets/state) + +| Percentile | Exhaustion | +|------------|-----------| +| Median | 1.007 | +| p99 | 2.0 | +| p99.9 | 4.4 | +| Max | 25.2 (at mult_max=100) | +| Max | 16.6 (at mult_max=25) | + +Only ~150 records exceed 5x. The problem is concentrated in +the extreme tail. + +## Parameter sweep results (2023 tax year, 51 states) + +Tested `multiplier_max` x `weight_penalty` grid. Key finding: +**`weight_penalty` has no effect** on exhaustion, wRMSE, or %zero +— it only increases target violations. The solver reaches the +same weight structure regardless; higher penalty just makes it +fail to meet more targets. + +**`multiplier_max` is the only effective lever:** + +| mult_max | Violations | MaxViol% | wRMSE | %zero | MaxExh | >10x | +|----------|-----------|---------|-------|-------|--------|------| +| 100 | 33 | 0.50% | 0.594 | 7.3% | 25.2 | 16 | +| **25** | **35** | **0.50%** | **0.609** | **7.8%** | **16.6** | **15** | +| 15 | 210 | 100% | 0.601 | 9.9% | 11.9 | 3 | +| 10 | 402 | 100% | 0.625 | 12.2% | 8.8 | 0 | + +**`mult_max=25` is the sweet spot**: only 2 extra violations +vs baseline, max violation still at the 0.50% tolerance boundary, +but max exhaustion drops 34% (25.2 to 16.6). Below 25, targets +become infeasible (100% violations appear). + +## Two-pass exhaustion limiting (not recommended for production) + +We tested an iterative approach: solve unconstrained, compute +exhaustion, set per-record caps, re-solve. With `max_exhaustion=5`: + +- Pass 1: 33 violations (normal) +- Pass 2: **8,979 violations**, max **100%** — catastrophic + +The proportional cap scaling was too aggressive. Records at 25x +got scaled to 20% of their multiplier, making targets infeasible. +The approach is fragile and hard to tune. + +**Recommendation**: Use `multiplier_max` (single pass) rather +than iterative exhaustion capping. Simpler, more robust, and +no slower. + +## Dual variable analysis (constraint cost) + +Clarabel's dual variables reveal which constraints are expensive +to satisfy — a target can be perfectly hit but still cause massive +weight distortion. This is invisible from violations alone. + +**Finding**: The $1M+ AGI bin filing-status count targets (single, +MFJ, HoH) are essentially the **only expensive constraints**. Their +dual costs are 6-8 orders of magnitude larger than all other targets. +Extended targets (capital gains, pensions, charitable, etc.) have +near-zero dual cost — they're "free" to add. + +**Action taken**: Excluded filing-status counts from the $1M+ bin. +Kept total return count as an anchor. This eliminated virtually all +constraint cost while maintaining full target coverage. + +## SALT targeting + +Use **actual SALT** (`c18300`, the post-$10K-cap amount) rather than +uncapped potential SALT (`e18400`/`e18500`). This is apples-to-apples: +Tax-Calculator's actual SALT under current law, shared using observed +SOI actual SALT by state. + +| Metric | e18400/e18500 | c18300 | +|--------|--------------|--------| +| Violations | 85 | 75 | +| wRMSE avg | 0.479 | 0.289 | +| SALT aggregation error | -6.23% | -0.18% | + +## OA (Other Areas) share rescaling + +SOI's "Other Areas" category (~0.5% of returns, covers territories +and overseas filers) is excluded from state targeting. Raw SOI +shares (state/US) are rescaled so that the 51 state shares sum +to 1.0 for each variable-AGI-bin combination. + +## Guidance for Congressional Districts + +CDs have not been implemented yet. Expected differences from states: + +- **435 areas** (vs 51) — grid search infeasible; use state-derived + parameter settings +- **9 AGI bins** (vs 10) — no $1M+ separate bin in SOI CD data +- **Smaller populations** — more CDs will hit multiplier ceilings; + may need different `multiplier_max` +- **Crosswalk complexity** — SOI uses 117th Congress boundaries for + both 2021 and 2022 data; need geocorr crosswalk to map to 118th + Congress boundaries +- **Exhaustion will be worse** — 435 areas competing for the same + records means much higher potential exhaustion; `multiplier_max` + may need to be lower + +## Running the parameter sweep + +The sweep utility tests combinations of solver parameters on all +states and reports target accuracy, weight distortion, and exhaustion: + +```bash +python -m tmd.areas.sweep_params +``` + +Edit `GRID_MULTIPLIER_MAX` and `GRID_WEIGHT_PENALTY` in +`tmd/areas/sweep_params.py` to change the search grid. Each +combination solves all 51 states (~3-4 minutes with 8 workers). + +**Should be rerun when**: +- The tax year changes (different TMD data, different SOI shares) +- The target recipe changes materially +- Extending to a new area type (CDs) diff --git a/tmd/areas/README.md b/tmd/areas/README.md index 3b49816c..7a489a27 100644 --- a/tmd/areas/README.md +++ b/tmd/areas/README.md @@ -51,14 +51,59 @@ charitable contributions, SALT by source (Census), EITC, CTC. Filing-status count targets are excluded from the $1M+ AGI bin to avoid excessive weight distortion on small cells. +## Solving for State Weights + +```bash +# All 51 states, 8 parallel workers: +python -m tmd.areas.solve_weights --scope states --workers 8 + +# Specific states: +python -m tmd.areas.solve_weights --scope MN,CA,TX --workers 4 +``` + +Uses the Clarabel constrained QP solver to find per-record weight +multipliers that hit each state's targets within 0.5% tolerance. + +Output: weight files in `tmd/areas/weights/states/` and solver +logs alongside them. + +## Quality Report + +```bash +python -m tmd.areas.quality_report +python -m tmd.areas.quality_report --scope CA,NY +``` + +Cross-state summary: solve status, target accuracy, weight +distortion, weight exhaustion, and national aggregation checks. + ## Pipeline Modules +**Target preparation** (PR 1): + | Module | Purpose | |--------|---------| | `prepare_targets.py` | CLI entry point | | `prepare/soi_state_data.py` | SOI state CSV ingestion | -| `prepare/target_sharing.py` | TMD × SOI share computation | +| `prepare/target_sharing.py` | TMD x SOI share computation | | `prepare/target_file_writer.py` | Recipe expansion, CSV output | | `prepare/extended_targets.py` | Census/SOI extended targets | | `prepare/constants.py` | AGI bins, mappings, metadata | | `prepare/census_population.py` | State population data | + +**Weight solving** (PR 2): + +| Module | Purpose | +|--------|---------| +| `solve_weights.py` | CLI entry point | +| `create_area_weights_clarabel.py` | Clarabel QP solver | +| `batch_weights.py` | Parallel batch runner | +| `quality_report.py` | Cross-state diagnostics | +| `sweep_params.py` | Parameter grid search utility | + +## Lessons Learned + +See [AREA_WEIGHTING_LESSONS.md](AREA_WEIGHTING_LESSONS.md) for +practical guidance on parameter tuning, weight exhaustion, SALT +targeting, dual variable analysis, and recommendations for +extending to Congressional districts. diff --git a/tmd/areas/create_area_weights_clarabel.py b/tmd/areas/create_area_weights_clarabel.py index 15878fed..1641237f 100644 --- a/tmd/areas/create_area_weights_clarabel.py +++ b/tmd/areas/create_area_weights_clarabel.py @@ -67,7 +67,7 @@ AREA_SLACK_PENALTY = 1e6 AREA_MAX_ITER = 2000 AREA_MULTIPLIER_MIN = 0.0 -AREA_MULTIPLIER_MAX = 100.0 +AREA_MULTIPLIER_MAX = 25.0 # Default target/weight directories for states STATE_TARGET_DIR = AREAS_FOLDER / "targets" / "states" From 90a4987ae00c7067b2321744f8d6731e22db9ab7 Mon Sep 17 00:00:00 2001 From: Don Boyd Date: Thu, 19 Mar 2026 14:02:23 -0400 Subject: [PATCH 05/10] Show RECID instead of row index in exhaustion report Co-Authored-By: Claude Opus 4.6 (1M context) --- tmd/areas/quality_report.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tmd/areas/quality_report.py b/tmd/areas/quality_report.py index 3d2c2ae6..f873543e 100644 --- a/tmd/areas/quality_report.py +++ b/tmd/areas/quality_report.py @@ -443,6 +443,7 @@ def _weight_diagnostics(areas, weight_dir=None): # Load national data tmd_path = STORAGE_FOLDER / "output" / "tmd.csv.gz" tmd_cols = [ + "RECID", "s006", "MARS", "XTOT", @@ -537,8 +538,9 @@ def _weight_diagnostics(areas, weight_dir=None): st_wts.append((st, w[idx])) st_wts.sort(key=lambda x: -x[1]) top3 = ", ".join(f"{st}={wt:.1f}" for st, wt in st_wts[:3]) + recid = int(r.RECID) lines.append( - f" {rank}. rec {idx}: exh={exh:.1f}x," + f" {rank}. RECID {recid}: exh={exh:.1f}x," f" s006={r.s006:.1f}," f" {ds} {mars}," f" AGI=${agi:,.0f}" From c2ce4ae45dde3f5d5fe084034f6cb94638afc7c6 Mon Sep 17 00:00:00 2001 From: Don Boyd Date: Thu, 19 Mar 2026 14:13:13 -0400 Subject: [PATCH 06/10] Add bystander check for untargeted variable distortion Quality report now checks ~19 untargeted variables for cross-state aggregation distortion, sorted by severity, with >2% flagged. Key bystanders: student loan interest (-10.5%), AMT (-10.3%), tax-exempt interest (+7.4%), qualified dividends (-4.1%). Add corresponding section to AREA_WEIGHTING_LESSONS.md explaining what drives bystander distortion and when to worry about it. Co-Authored-By: Claude Opus 4.6 (1M context) --- tmd/areas/AREA_WEIGHTING_LESSONS.md | 43 +++++++++ tmd/areas/quality_report.py | 138 +++++++++++++++++++++++++++- 2 files changed, 177 insertions(+), 4 deletions(-) diff --git a/tmd/areas/AREA_WEIGHTING_LESSONS.md b/tmd/areas/AREA_WEIGHTING_LESSONS.md index ba4837fe..e0d891dc 100644 --- a/tmd/areas/AREA_WEIGHTING_LESSONS.md +++ b/tmd/areas/AREA_WEIGHTING_LESSONS.md @@ -128,6 +128,49 @@ SOI actual SALT by state. | wRMSE avg | 0.479 | 0.289 | | SALT aggregation error | -6.23% | -0.18% | +## Bystander variables (untargeted collateral distortion) + +Area weighting optimizes ~178 targets per state. Variables not +in the target set can be distorted as a side effect — "innocent +bystanders" pulled by weight adjustments aimed at other variables. + +The quality report includes a bystander check showing cross-state +aggregation error (sum-of-states vs national) for untargeted +variables. Variables with >2% distortion are flagged. + +### Observed bystanders (2023, mult_max=25) + +| Variable | Diff% | Why | +|----------|-------|-----| +| Student loan int (e19200) | -10.5% | Concentrated on middle-income records that get under-weighted | +| AMT (c09600) | -10.3% | Very small base ($1.1B); volatile under any reweighting | +| Tax-exempt int (e00400) | +7.4% | On same wealthy records that get over-weighted for AGI targets | +| Qualified dividends (e00650) | -4.1% | Total dividends (e00600) targeted but not qualified portion | +| Unemployment comp (e02300) | -4.0% | Low-income variable squeezed when high-income records upweighted | +| Medical expenses (e17500) | -2.9% | Itemized deduction bystander | +| Total credits (c07100) | -2.2% | EITC/CTC targeted but not total credits | + +### Well-behaved bystanders (<1%) + +Income tax, payroll tax, standard deduction, itemized total, +Sch C, IRA distributions, taxable pensions, Sch E, total persons, +children under 17. These are either targeted directly or highly +correlated with targeted variables. + +### When to worry + +Bystander distortion matters when the variable is **policy-relevant +at the state level** and **not closely correlated** with any targeted +variable. Student loan interest at -10.5% would matter for a state- +level student debt analysis; AMT at -10.3% on a $1.1B base is less +consequential. + +If a bystander becomes important for a specific analysis, the fix is +to add it (or a correlated variable) as a target. The dual variable +analysis showed that extended targets are essentially "free" — near- +zero constraint cost — so adding targets is low-risk as long as they +don't conflict with existing targets in small AGI bins. + ## OA (Other Areas) share rescaling SOI's "Other Areas" category (~0.5% of returns, covers territories diff --git a/tmd/areas/quality_report.py b/tmd/areas/quality_report.py index f873543e..626ebf1f 100644 --- a/tmd/areas/quality_report.py +++ b/tmd/areas/quality_report.py @@ -450,8 +450,19 @@ def _weight_diagnostics(areas, weight_dir=None): "data_source", "e00200", "e00300", + "e00400", "e00600", + "e00650", + "e00900", + "e01400", + "e01700", + "e02000", + "e02300", + "e02400", + "e17500", + "e19200", "e26270", + "n24", "p22250", "p23250", ] @@ -466,10 +477,22 @@ def _weight_diagnostics(areas, weight_dir=None): allvars_path = STORAGE_FOLDER / "output" / "cached_allvars.csv" if allvars_path.exists(): - needed_tc = ["c18300", "iitax"] - allvars = pd.read_csv(allvars_path, usecols=needed_tc) - for col in needed_tc: - if col in allvars.columns: + needed_tc = [ + "c04470", + "c07100", + "c09600", + "c18300", + "c19200", + "c19700", + "iitax", + "payrolltax", + "standard", + ] + avail = pd.read_csv(allvars_path, nrows=0).columns + load_tc = [c for c in needed_tc if c in avail] + if load_tc: + allvars = pd.read_csv(allvars_path, usecols=load_tc) + for col in load_tc: tmd[col] = allvars[col].values # Load all state weights once @@ -616,6 +639,113 @@ def _weight_diagnostics(areas, weight_dir=None): ) lines.append("") + # Bystander check: untargeted variables + lines.extend(_bystander_check(tmd, s006, state_weights, n_loaded)) + + return lines + + +def _bystander_check(tmd, s006, state_weights, n_loaded): + """ + Check untargeted variables for cross-state aggregation + distortion. These are 'innocent bystanders' that may be + jerked around by weight adjustments aimed at targeted + variables. + """ + lines = [] + + # Untargeted variables to check, grouped by category + # Format: (label, varname, is_count) + bystander_vars = [ + # Tax liability / credits + ("Income tax (iitax)", "iitax", False), + ("Payroll tax", "payrolltax", False), + ("AMT (c09600)", "c09600", False), + ("Total credits (c07100)", "c07100", False), + # Deductions + ("Medical expenses (e17500)", "e17500", False), + ("Student loan int (e19200)", "e19200", False), + ("Itemized ded (c04470)", "c04470", False), + ("Standard deduction", "standard", False), + ("Mortgage int (c19200)", "c19200", False), + ("Charitable (c19700)", "c19700", False), + # Income not directly targeted + ("Tax-exempt int (e00400)", "e00400", False), + ("Qual dividends (e00650)", "e00650", False), + ("Sch C income (e00900)", "e00900", False), + ("IRA distrib (e01400)", "e01400", False), + ("Taxable pensions (e01700)", "e01700", False), + ("Sch E net (e02000)", "e02000", False), + ("Unemployment (e02300)", "e02300", False), + # Demographics + ("Total persons (XTOT)", "XTOT", True), + ("Children <17 (n24)", "n24", True), + ] + + # Compute national and sum-of-states for each + results = [] + for label, var, is_count in bystander_vars: + if var not in tmd.columns: + continue + if var == "XTOT": + nat = float((s006 * tmd[var].values).sum()) + elif is_count: + nat = float((s006 * tmd[var].values).sum()) + else: + nat = float((s006 * tmd[var].values).sum()) + if abs(nat) < 1: + continue + sos = 0.0 + for _st, w in state_weights.items(): + sos += float((w * tmd[var].values).sum()) + diff_pct = (sos / nat - 1) * 100 + results.append((label, var, is_count, nat, sos, diff_pct)) + + # Sort by absolute distortion + results.sort(key=lambda x: -abs(x[5])) + + lines.append( + "BYSTANDER CHECK: UNTARGETED VARIABLES" f" ({n_loaded} states):" + ) + lines.append( + " Variables NOT directly targeted — distortion" + " from weight adjustments." + ) + lines.append( + f" {'Variable':<30} {'National':>16}" + f" {'Sum-of-States':>16} {'Diff%':>8}" + ) + lines.append(" " + "-" * 72) + for label, _var, is_count, nat, sos, diff_pct in results: + flag = " ***" if abs(diff_pct) > 2 else "" + if is_count: + lines.append( + f" {label:<30} {nat:>16,.0f}" + f" {sos:>16,.0f}" + f" {diff_pct:>+7.2f}%{flag}" + ) + else: + lines.append( + f" {label:<30}" + f" ${nat / 1e9:>14.1f}B" + f" ${sos / 1e9:>14.1f}B" + f" {diff_pct:>+7.2f}%{flag}" + ) + + n_flagged = sum(1 for *_, d in results if abs(d) > 2) + lines.append("") + if n_flagged: + lines.append( + f" *** = {n_flagged} variables with" + f" >2% aggregation distortion" + ) + else: + lines.append( + " All untargeted variables within" + + " 2% aggregation tolerance." + ) + lines.append("") + return lines From 81c114c5412525405a7dbc9df136ac862e71eee4 Mon Sep 17 00:00:00 2001 From: Don Boyd Date: Thu, 19 Mar 2026 14:48:32 -0400 Subject: [PATCH 07/10] Fix black formatting for CI (black 26.3.0 compatibility) Co-Authored-By: Claude Opus 4.6 (1M context) --- tmd/areas/quality_report.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tmd/areas/quality_report.py b/tmd/areas/quality_report.py index 626ebf1f..037352d9 100644 --- a/tmd/areas/quality_report.py +++ b/tmd/areas/quality_report.py @@ -741,8 +741,7 @@ def _bystander_check(tmd, s006, state_weights, n_loaded): ) else: lines.append( - " All untargeted variables within" - + " 2% aggregation tolerance." + " All untargeted variables within" + " 2% aggregation tolerance." ) lines.append("") From 0194fbf93e1acc80cd220850b83cfa8ed5c47c0b Mon Sep 17 00:00:00 2001 From: Don Boyd Date: Thu, 19 Mar 2026 18:19:48 -0400 Subject: [PATCH 08/10] Replace old scipy solver with Clarabel QP solver Remove create_area_weights_clarabel.py and replace the old create_area_weights.py (scipy L-BFGS-B + JAX) with the Clarabel constrained QP solver. Update all imports and function references. - Renamed create_area_weights_file_clarabel() to create_area_weights_file() - Removed valid_area() dependency (areas validated by target file existence) - Updated imports in solve_weights, batch_weights, quality_report, sweep_params, make_all, and test_solve_weights Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_solve_weights.py | 8 +- tmd/areas/batch_weights.py | 20 +- tmd/areas/create_area_weights.py | 1121 ++++++++++----------- tmd/areas/create_area_weights_clarabel.py | 609 ----------- tmd/areas/make_all.py | 6 +- tmd/areas/quality_report.py | 2 +- tmd/areas/solve_weights.py | 2 +- tmd/areas/sweep_params.py | 2 +- 8 files changed, 525 insertions(+), 1245 deletions(-) delete mode 100644 tmd/areas/create_area_weights_clarabel.py diff --git a/tests/test_solve_weights.py b/tests/test_solve_weights.py index 2b52af17..a30165b3 100644 --- a/tests/test_solve_weights.py +++ b/tests/test_solve_weights.py @@ -18,13 +18,13 @@ from tmd.areas import AREAS_FOLDER from tmd.areas.batch_weights import _filter_areas -from tmd.areas.create_area_weights_clarabel import ( +from tmd.areas.create_area_weights import ( AREA_CONSTRAINT_TOL, _build_constraint_matrix, _drop_impossible_targets, _load_taxcalc_data, _solve_area_qp, - create_area_weights_file_clarabel, + create_area_weights_file, ) from tmd.areas.quality_report import ( _humanize_desc, @@ -47,14 +47,14 @@ def test_clarabel_solver_xx(): with tempfile.TemporaryDirectory() as tmpdir: weight_dir = Path(tmpdir) - rc = create_area_weights_file_clarabel( + rc = create_area_weights_file( "xx", write_log=True, write_file=True, target_dir=target_dir, weight_dir=weight_dir, ) - assert rc == 0, "create_area_weights_file_clarabel returned non-zero" + assert rc == 0, "create_area_weights_file returned non-zero" # Verify weights file was created wpath = weight_dir / "xx_tmd_weights.csv.gz" diff --git a/tmd/areas/batch_weights.py b/tmd/areas/batch_weights.py index aa3730b4..9a5c1f45 100644 --- a/tmd/areas/batch_weights.py +++ b/tmd/areas/batch_weights.py @@ -23,8 +23,6 @@ import pandas as pd import yaml -from tmd.areas.create_area_weights import valid_area - # Module-level cache for TMD data (one per worker process) _WORKER_VDF = None _WORKER_POP = None @@ -42,7 +40,7 @@ def _init_worker(target_dir=None, weight_dir=None): _WORKER_WEIGHT_DIR = Path(weight_dir) if _WORKER_VDF is not None: return - from tmd.areas.create_area_weights_clarabel import ( + from tmd.areas.create_area_weights import ( POPFILE_PATH, _load_taxcalc_data, ) @@ -59,7 +57,7 @@ def _solve_one_area(area): Returns (area, elapsed, n_targets, n_violated, status, max_viol_pct). """ _init_worker() - from tmd.areas.create_area_weights_clarabel import ( + from tmd.areas.create_area_weights import ( AREA_CONSTRAINT_TOL, AREA_MAX_ITER, AREA_MULTIPLIER_MAX, @@ -182,19 +180,11 @@ def _solve_one_area(area): def _list_target_areas(target_dir=None): """Return sorted list of area codes with target files.""" - from tmd.areas.create_area_weights_clarabel import STATE_TARGET_DIR + from tmd.areas.create_area_weights import STATE_TARGET_DIR tfolder = target_dir or STATE_TARGET_DIR tpaths = sorted(tfolder.glob("*_targets.csv")) - areas = [] - for tpath in tpaths: - area = tpath.name.split("_")[0] - old_stderr = sys.stderr - sys.stderr = io.StringIO() - ok = valid_area(area) - sys.stderr = old_stderr - if ok: - areas.append(area) + areas = [tpath.name.split("_")[0] for tpath in tpaths] return areas @@ -234,7 +224,7 @@ def run_batch( weight_dir : Path, optional Directory for weight output. """ - from tmd.areas.create_area_weights_clarabel import ( + from tmd.areas.create_area_weights import ( STATE_TARGET_DIR, STATE_WEIGHT_DIR, ) diff --git a/tmd/areas/create_area_weights.py b/tmd/areas/create_area_weights.py index eb046bc6..f50cb35e 100644 --- a/tmd/areas/create_area_weights.py +++ b/tmd/areas/create_area_weights.py @@ -1,705 +1,606 @@ """ -Construct AREA_tmd_weights.csv.gz, a Tax-Calculator-style weights file -for FIRST_YEAR through LAST_YEAR for the specified sub-national AREA. - -AREA prefix for state areas are the two lower-case character postal codes. -AREA prefix for congressional districts are the state prefix followed by -two digits (with a leading zero) identifying the district. States with -only one congressional district have 00 as the two digits to align names -with IRS data. +Clarabel-based constrained QP area weight optimization. + +Finds weight multipliers (x) for each record such that area-specific +weighted sums match area targets within a tolerance band, while +minimizing deviation from population-proportional weights. + +Formulation: + minimize sum((x_i - 1)^2) + subject to t_j*(1-eps) <= (Bx)_j <= t_j*(1+eps) [target bounds] + x_min <= x_i <= x_max [multiplier bounds] + +where: + x_i = area_weight_i / (pop_share * national_weight_i) + pop_share = area_population / national_population + B[j,i] = (pop_share * national_weight_i) * A[j,i] + + x_i = 1 means record i gets its population-proportional share. + The optimizer adjusts x_i to hit area-specific targets. + +Elastic slack variables handle infeasibility gracefully: + minimize sum((x_i - 1)^2) + M * sum(s^2) + subject to lb_j <= (Bx)_j + s_lo - s_hi <= ub_j, s >= 0 + +Follows the same QP construction as tmd/utils/reweight.py. """ import sys -import re import time -import yaml + +import clarabel import numpy as np import pandas as pd -from scipy.sparse import csr_matrix -from scipy.optimize import minimize, Bounds -import jax -import jax.numpy as jnp -from jax.experimental.sparse import BCOO -from tmd.storage import STORAGE_FOLDER -from tmd.imputation_assumptions import TAXYEAR, POPULATION_FILE +import yaml +from scipy.sparse import ( + csc_matrix, + diags as spdiags, + eye as speye, + hstack, + vstack, +) + from tmd.areas import AREAS_FOLDER +from tmd.imputation_assumptions import POPULATION_FILE, TAXYEAR +from tmd.storage import STORAGE_FOLDER FIRST_YEAR = TAXYEAR LAST_YEAR = 2034 INFILE_PATH = STORAGE_FOLDER / "output" / "tmd.csv.gz" POPFILE_PATH = STORAGE_FOLDER / "input" / POPULATION_FILE - -# Tax-Calcultor calculated variable cache files: TAXCALC_AGI_CACHE = STORAGE_FOLDER / "output" / "cached_c00100.npy" - -PARAMS = {} - -# default target parameters: -TARGET_RATIO_TOLERANCE = 0.0040 # what is considered hitting the target -DUMP_ALL_TARGET_DEVIATIONS = False # set to True only for diagnostic work - -# default regularization parameters: -DELTA_INIT_VALUE = 1.0e-9 -DELTA_MAX_LOOPS = 1 - -# default optimization parameters: -OPTIMIZE_FTOL = 1e-9 -OPTIMIZE_GTOL = 1e-9 -OPTIMIZE_MAXITER = 5000 -OPTIMIZE_IPRINT = 0 # 20 is a good diagnostic value; set to 0 for no output -OPTIMIZE_RESULTS = False # set to True to see complete optimization results - - -def valid_area(area: str): - """ - Check validity of area string returning a boolean value. - """ - # Census on which Congressional districts are based: - # : cd_census_year = 2010 implies districts are for the 117th Congress - # : cd_census_year = 2020 implies districts are for the 118th Congress - cd_census_year = 2020 - # data in the state_info dictionary is taken from the following document: - # 2020 Census Apportionment Results, April 26, 2021, - # Table C1. Number of Seats in - # U.S. House of Representatives by State: 1910 to 2020 - # https://www.census.gov/data/tables/2020/dec/2020-apportionment-data.html - state_info = { - # number of Congressional districts per state indexed by cd_census_year - "AL": {2020: 7, 2010: 7}, - "AK": {2020: 1, 2010: 1}, - "AZ": {2020: 9, 2010: 9}, - "AR": {2020: 4, 2010: 4}, - "CA": {2020: 52, 2010: 53}, - "CO": {2020: 8, 2010: 7}, - "CT": {2020: 5, 2010: 5}, - "DE": {2020: 1, 2010: 1}, - "DC": {2020: 0, 2010: 0}, - "FL": {2020: 28, 2010: 27}, - "GA": {2020: 14, 2010: 14}, - "HI": {2020: 2, 2010: 2}, - "ID": {2020: 2, 2010: 2}, - "IL": {2020: 17, 2010: 18}, - "IN": {2020: 9, 2010: 9}, - "IA": {2020: 4, 2010: 4}, - "KS": {2020: 4, 2010: 4}, - "KY": {2020: 6, 2010: 6}, - "LA": {2020: 6, 2010: 6}, - "ME": {2020: 2, 2010: 2}, - "MD": {2020: 8, 2010: 8}, - "MA": {2020: 9, 2010: 9}, - "MI": {2020: 13, 2010: 14}, - "MN": {2020: 8, 2010: 8}, - "MS": {2020: 4, 2010: 4}, - "MO": {2020: 8, 2010: 8}, - "MT": {2020: 2, 2010: 1}, - "NE": {2020: 3, 2010: 3}, - "NV": {2020: 4, 2010: 4}, - "NH": {2020: 2, 2010: 2}, - "NJ": {2020: 12, 2010: 12}, - "NM": {2020: 3, 2010: 3}, - "NY": {2020: 26, 2010: 27}, - "NC": {2020: 14, 2010: 13}, - "ND": {2020: 1, 2010: 1}, - "OH": {2020: 15, 2010: 16}, - "OK": {2020: 5, 2010: 5}, - "OR": {2020: 6, 2010: 5}, - "PA": {2020: 17, 2010: 18}, - "RI": {2020: 2, 2010: 2}, - "SC": {2020: 7, 2010: 7}, - "SD": {2020: 1, 2010: 1}, - "TN": {2020: 9, 2010: 9}, - "TX": {2020: 38, 2010: 36}, - "UT": {2020: 4, 2010: 4}, - "VT": {2020: 1, 2010: 1}, - "VA": {2020: 11, 2010: 11}, - "WA": {2020: 10, 2010: 10}, - "WV": {2020: 2, 2010: 3}, - "WI": {2020: 8, 2010: 8}, - "WY": {2020: 1, 2010: 1}, - } - # check state_info validity - assert len(state_info) == 50 + 1 - total = {2010: 0, 2020: 0} - for _, seats in state_info.items(): - total[2010] += seats[2010] - total[2020] += seats[2020] - assert total[2010] == 435 - assert total[2020] == 435 - compare_new_vs_old = False - if compare_new_vs_old: - text = "state,2010cds,2020cds" - for state, seats in state_info.items(): - if seats[2020] != seats[2010]: - print(f"{text}= {state} {seats[2010]:2d} {seats[2020]:2d}") - sys.exit(1) - # conduct series of validity checks on specified area string - # ... check that specified area string has expected length - len_area_str = len(area) - if not 2 <= len_area_str <= 5: - sys.stderr.write(f": area '{area}' length is not in [2,5] range\n") - return False - # ... check first two characters of area string - s_c = area[0:2] - if not re.match(r"[a-z][a-z]", s_c): - emsg = "begin with two lower-case letters" - sys.stderr.write(f": area '{area}' must {emsg}\n") - return False - is_faux_area = re.match(r"[x-z][a-z]", s_c) is not None - if not is_faux_area: - if s_c.upper() not in state_info: - sys.stderr.write(f": state '{s_c}' is unknown\n") - return False - # ... check state area assumption letter which is optional - if len_area_str == 3: - assump_char = area[2:3] - if not re.match(r"[A-Z]", assump_char): - emsg = "assumption character that is not an upper-case letter" - sys.stderr.write(f": area '{area}' has {emsg}\n") - return False - if len_area_str <= 3: - return True - # ... check Congressional district area string - if not re.match(r"\d\d", area[2:4]): - emsg = "have two numbers after the state code" - sys.stderr.write(f": area '{area}' must {emsg}\n") - return False - if is_faux_area: - max_cdnum = 99 - else: - max_cdnum = state_info[s_c.upper()][cd_census_year] - cdnum = int(area[2:4]) - if max_cdnum <= 1: - if cdnum != 0: - sys.stderr.write( - f": use area '{s_c}00' for this one-district state\n" - ) - return False - else: # if max_cdnum >= 2 - if cdnum > max_cdnum: - sys.stderr.write(f": cd number '{cdnum}' exceeds {max_cdnum}\n") - return False - # ... check district area assumption character which is optional - if len_area_str == 5: - assump_char = area[4:5] - if not re.match(r"[A-Z]", assump_char): - emsg = "assumption character that is not an upper-case letter" - sys.stderr.write(f": area '{area}' has {emsg}\n") - return False - return True - - -def all_taxcalc_variables(): - """ - Return all read and needed calc Tax-Calculator variables in pd.DataFrame. - """ +CACHED_ALLVARS_PATH = STORAGE_FOLDER / "output" / "cached_allvars.csv" + +# Tax-Calculator output variables to load from cached_allvars for targeting +CACHED_TC_OUTPUTS = [ + "c18300", + "c04470", + "c02500", + "c19200", + "c19700", + "eitc", + "ctc_total", +] + +# Default solver parameters +AREA_CONSTRAINT_TOL = 0.005 +AREA_SLACK_PENALTY = 1e6 +AREA_MAX_ITER = 2000 +AREA_MULTIPLIER_MIN = 0.0 +AREA_MULTIPLIER_MAX = 25.0 + +# Default target/weight directories for states +STATE_TARGET_DIR = AREAS_FOLDER / "targets" / "states" +STATE_WEIGHT_DIR = AREAS_FOLDER / "weights" / "states" + + +def _load_taxcalc_data(): + """Load TMD data with cached AGI and selected Tax-Calculator outputs.""" vdf = pd.read_csv(INFILE_PATH) vdf["c00100"] = np.load(TAXCALC_AGI_CACHE) + if CACHED_ALLVARS_PATH.exists(): + allvars = pd.read_csv(CACHED_ALLVARS_PATH, usecols=CACHED_TC_OUTPUTS) + for col in CACHED_TC_OUTPUTS: + if col in allvars.columns: + vdf[col] = allvars[col].values + # Synthetic combined variable for net capital gains targeting + if "p22250" in vdf.columns and "p23250" in vdf.columns: + vdf["capgains_net"] = vdf["p22250"] + vdf["p23250"] assert np.all(vdf.s006 > 0), "Not all weights are positive" return vdf -def prepared_data(area: str, vardf: pd.DataFrame): +def _read_params(area, out, target_dir=None): + """Read optional area-specific parameters YAML file.""" + if target_dir is None: + target_dir = STATE_TARGET_DIR + params = {} + pfile = f"{area}_params.yaml" + params_path = target_dir / pfile + if params_path.exists(): + with open(params_path, "r", encoding="utf-8") as f: + params = yaml.safe_load(f.read()) + exp_params = [ + "target_ratio_tolerance", + "dump_all_target_deviations", + "delta_init_value", + "delta_max_loops", + "iprint", + "constraint_tol", + "slack_penalty", + "max_iter", + "multiplier_min", + "multiplier_max", + ] + act_params = list(params.keys()) + all_ok = len(set(act_params)) == len(act_params) + for param in act_params: + if param not in exp_params: + all_ok = False + out.write( + f"WARNING: {pfile} parameter" f" {param} is not expected\n" + ) + if not all_ok: + out.write(f"IGNORING CONTENTS OF {pfile}\n") + params = {} + elif params: + out.write(f"USING CUSTOMIZED PARAMETERS IN {pfile}\n") + return params + + +def _build_constraint_matrix(area, vardf, out, target_dir=None): """ - Construct numpy 2-D target matrix and 1-D target array for - specified area using specified vardf. Also, compute initial - weights scaling factor for specified area. Return all three - as a tuple. + Build constraint matrix B and target array from area targets CSV. + + Returns (B_csc, targets, target_labels, pop_share) where: + - B_csc is a sparse matrix (n_targets x n_records) + - targets is 1-D array of target values + - target_labels is a list of descriptive strings + - pop_share is the area's population share of national population """ + if target_dir is None: + target_dir = STATE_TARGET_DIR national_population = (vardf.s006 * vardf.XTOT).sum() numobs = len(vardf) - targets_file = AREAS_FOLDER / "targets" / f"{area}_targets.csv" + targets_file = target_dir / f"{area}_targets.csv" tdf = pd.read_csv(targets_file, comment="#") - tm_tuple = () - ta_list = [] - row_num = 1 - initial_weights_scale = None - for row in tdf.itertuples(index=False): - row_num += 1 - line = f"{area}:L{row_num}" - # construct target amount for this row - unscaled_target = row.target - if unscaled_target == 0: - unscaled_target = 1.0 - scale = 1.0 / unscaled_target - scaled_target = unscaled_target * scale - ta_list.append(scaled_target) - # confirm that row_num 2 contains the area population target - if row_num == 2: - bool_list = [ - row.varname == "XTOT", - row.count == 0, - row.scope == 0, - row.agilo < -8e99, - row.agihi > 8e99, - row.fstatus == 0, - ] - assert all( - bool_list - ), f"{line} does not contain the area population target" - initial_weights_scale = row.target / national_population + + targets_list = [] + labels_list = [] + columns = [] + pop_share = None + + for row_idx, row in enumerate(tdf.itertuples(index=False)): + line = f"{area}:target{row_idx + 1}" + + # extract target value + target_val = row.target + targets_list.append(target_val) + + # build label + label = ( + f"{row.varname}" + f"/cnt={row.count}" + f"/scope={row.scope}" + f"/agi=[{row.agilo},{row.agihi})" + f"/fs={row.fstatus}" + ) + labels_list.append(label) + + # first row must be XTOT population target + if row_idx == 0: + assert row.varname == "XTOT", f"{line}: first target must be XTOT" + assert row.count == 0 and row.scope == 0 + assert row.agilo < -8e99 and row.agihi > 8e99 + assert row.fstatus == 0 + pop_share = row.target / national_population + out.write( + f"pop_share = {row.target:.0f}" + f" / {national_population:.0f}" + f" = {pop_share:.6f}\n" + ) + # construct variable array for this target - assert ( - row.count >= 0 and row.count <= 4 - ), f"count value {row.count} not in [0,4] range on {line}" - if row.count == 0: # tabulate $ variable amount - unmasked_varray = vardf[row.varname].astype(float) - elif row.count == 1: # count units with any variable amount - unmasked_varray = (vardf[row.varname] > -np.inf).astype(float) - elif row.count == 2: # count only units with non-zero variable amount - unmasked_varray = (vardf[row.varname] != 0).astype(float) - elif row.count == 3: # count only units with positive variable amount - unmasked_varray = (vardf[row.varname] > 0).astype(float) - else: # count only units with negative variable amount (row.count==4) - unmasked_varray = (vardf[row.varname] < 0).astype(float) - mask = np.ones(numobs, dtype=int) - assert ( - row.scope >= 0 and row.scope <= 2 - ), f"scope value {row.scope} not in [0,2] range on {line}" + assert 0 <= row.count <= 4, f"count {row.count} not in [0,4] on {line}" + if row.count == 0: + var_array = vardf[row.varname].astype(float).values + elif row.count == 1: + var_array = np.ones(numobs, dtype=float) + elif row.count == 2: + var_array = (vardf[row.varname] != 0).astype(float).values + elif row.count == 3: + var_array = (vardf[row.varname] > 0).astype(float).values + else: + var_array = (vardf[row.varname] < 0).astype(float).values + + # construct mask + mask = np.ones(numobs, dtype=float) + assert 0 <= row.scope <= 2, f"scope {row.scope} not in [0,2] on {line}" if row.scope == 1: - mask *= vardf.data_source == 1 # PUF records + mask *= (vardf.data_source == 1).values.astype(float) elif row.scope == 2: - mask *= vardf.data_source == 0 # CPS records + mask *= (vardf.data_source == 0).values.astype(float) + in_agi_bin = (vardf.c00100 >= row.agilo) & (vardf.c00100 < row.agihi) - mask *= in_agi_bin + mask *= in_agi_bin.values.astype(float) + assert ( - row.fstatus >= 0 and row.fstatus <= 5 - ), f"fstatus value {row.fstatus} not in [0,5] range on {line}" + 0 <= row.fstatus <= 5 + ), f"fstatus {row.fstatus} not in [0,5] on {line}" if row.fstatus > 0: - mask *= vardf.MARS == row.fstatus - scaled_masked_varray = mask * unmasked_varray * scale - tm_tuple = tm_tuple + (scaled_masked_varray,) - # construct target matrix and target array and return as tuple - scale_factor = 1.0 # as high as 1e9 works just fine - target_matrix = np.vstack(tm_tuple).T * scale_factor - target_array = np.array(ta_list) * scale_factor - return ( - target_matrix, - target_array, - initial_weights_scale, - ) + mask *= (vardf.MARS == row.fstatus).values.astype(float) + # A[j,i] = mask * var_array (data values for this constraint) + columns.append(mask * var_array) -def target_misses(wght, target_matrix, target_array): - """ - Return number of target misses for the specified weight array and a - string containing size of each actual/expect target miss as a tuple. - """ - actual = np.dot(wght, target_matrix) - tratio = actual / target_array - tol = PARAMS.get("target_ratio_tolerance", TARGET_RATIO_TOLERANCE) - lob = 1.0 - tol - hib = 1.0 + tol - num = ((tratio < lob) | (tratio >= hib)).sum() - mstr = "" - if num > 0: - for tnum, ratio in enumerate(tratio): - if ratio < lob or ratio >= hib: - mstr += ( - f" ::::TARGET{(tnum + 1):03d}:ACT/EXP,lob,hib=" - f" {ratio:.6f} {lob:.6f} {hib:.6f}\n" - ) - return (num, mstr) + assert pop_share is not None, "XTOT target not found" + # A matrix: n_records x n_targets (each column is a constraint) + A_dense = np.column_stack(columns) -def target_rmse(wght, target_matrix, target_array, out, delta=None): - """ - Return RMSE of the target deviations given specified arguments. - """ - act = np.dot(wght, target_matrix) - act_minus_exp = act - target_array - ratio = act / target_array - min_ratio = np.min(ratio) - max_ratio = np.max(ratio) - dump = PARAMS.get("dump_all_target_deviations", DUMP_ALL_TARGET_DEVIATIONS) - if dump: - for tnum, ratio_ in enumerate(ratio): + # B = diag(w0) @ A where w0 = pop_share * s006 + w0 = pop_share * vardf.s006.values + B_dense = w0[:, np.newaxis] * A_dense + + # convert to sparse (n_targets x n_records) for Clarabel + B_csc = csc_matrix(B_dense.T) + + targets_arr = np.array(targets_list) + return B_csc, targets_arr, labels_list, pop_share + + +def _drop_impossible_targets(B_csc, targets, labels, out): + """Drop targets where all constraint matrix values are zero.""" + col_sums = np.abs(np.asarray(B_csc.sum(axis=1))).ravel() + all_zero = col_sums == 0 + if all_zero.any(): + n_drop = int(all_zero.sum()) + for i in np.where(all_zero)[0]: out.write( - f"TARGET{(tnum + 1):03d}:ACT-EXP,ACT/EXP= " - f"{act_minus_exp[tnum]:16.9e}, {ratio_:.3f}\n" + f"DROPPING impossible target" f" (all zeros): {labels[i]}\n" ) - # show distribution of target ratios - tol = PARAMS.get("target_ratio_tolerance", TARGET_RATIO_TOLERANCE) - bins = [ - -np.inf, - 0.0, - 0.4, - 0.8, - 0.9, - 0.99, - 1.0 - tol, - 1.0 + tol, - 1.01, - 1.1, - 1.2, - 1.6, - 2.0, - 3.0, - 4.0, - 5.0, - np.inf, - ] - tot = ratio.size - out.write(f"DISTRIBUTION OF TARGET ACT/EXP RATIOS (n={tot}):\n") - if delta is not None: - out.write(f" with REGULARIZATION_DELTA= {delta:e}\n") - header = ( - "low bin ratio high bin ratio" - " bin # cum # bin % cum %\n" - ) - out.write(header) - cutout = pd.cut(ratio, bins, right=False, precision=6) - count = pd.Series(cutout).value_counts().sort_index().to_dict() - cum = 0 - for interval, num in count.items(): - cum += num - if cum == 0: - continue - line = ( - f">={interval.left:13.6f}, <{interval.right:13.6f}:" - f" {num:6d} {cum:6d} {num / tot:7.2%} {cum / tot:7.2%}\n" + keep = ~all_zero + B_dense = B_csc.toarray() + B_csc = csc_matrix(B_dense[keep, :]) + targets = targets[keep] + labels = [lab for lab, k in zip(labels, keep) if k] + out.write( + f"Dropped {n_drop} impossible targets," + f" {len(targets)} remaining\n" ) - out.write(line) - if cum == tot: - break - # write minimum and maximum ratio values - line = f"MINIMUM VALUE OF TARGET ACT/EXP RATIO = {min_ratio:.3f}\n" - out.write(line) - line = f"MAXIMUM VALUE OF TARGET ACT/EXP RATIO = {max_ratio:.3f}\n" - out.write(line) - # return RMSE of ACT-EXP targets - return np.sqrt(np.mean(np.square(act_minus_exp))) - - -def objective_function(x, *args): + return B_csc, targets, labels + + +def _solve_area_qp( # pylint: disable=unused-argument + B_csc, + targets, + labels, + n_records, + constraint_tol=AREA_CONSTRAINT_TOL, + slack_penalty=AREA_SLACK_PENALTY, + max_iter=AREA_MAX_ITER, + multiplier_min=AREA_MULTIPLIER_MIN, + multiplier_max=AREA_MULTIPLIER_MAX, + weight_penalty=1.0, + out=None, +): """ - Objective function for minimization. - Search for NOTE in this file for methodological details. - https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf#page=320 + Solve the area reweighting QP using Clarabel. + + Parameters + ---------- + weight_penalty : float + Penalty weight on (x-1)^2 relative to constraint + slack. Higher values keep multipliers closer to 1.0 + at the cost of more target violations. + + Returns (x_opt, s_lo, s_hi, info_dict). """ - A, b, delta = args # A is a jax sparse matrix - ssq_target_deviations = jnp.sum(jnp.square(A @ x - b)) - ssq_weight_deviations = jnp.sum(jnp.square(x - 1.0)) - return ssq_target_deviations + delta * ssq_weight_deviations + if out is None: + out = sys.stdout + + m = len(targets) + n_total = n_records + 2 * m # x + s_lo + s_hi + + # constraint bounds: t*(1-eps) <= Bx <= t*(1+eps) + abs_targets = np.abs(targets) + tol_band = abs_targets * constraint_tol + cl = targets - tol_band + cu = targets + tol_band + + # diagonal Hessian: 2*alpha for x, 2*M for slacks + hess_diag = np.empty(n_total) + hess_diag[:n_records] = 2.0 * weight_penalty + hess_diag[n_records:] = 2.0 * slack_penalty + P = spdiags(hess_diag, format="csc") + + # linear term: -2*alpha for x + q = np.zeros(n_total) + q[:n_records] = -2.0 * weight_penalty + + # extended constraint matrix: [B | I_m | -I_m] + I_m = speye(m, format="csc") + A_full = hstack([B_csc, I_m, -I_m], format="csc") + + # constraint scaling for numerical conditioning + target_scale = np.maximum(np.abs(targets), 1.0) + D_inv = spdiags(1.0 / target_scale) + A_scaled = (D_inv @ A_full).tocsc() + cl_scaled = cl / target_scale + cu_scaled = cu / target_scale + + # variable bounds + var_lb = np.empty(n_total) + var_ub = np.empty(n_total) + var_lb[:n_records] = multiplier_min + var_ub[:n_records] = multiplier_max + var_lb[n_records:] = 0.0 + var_ub[n_records:] = 1e20 + + # Clarabel form: Ax + s = b, s in NonnegativeCone + I_n = speye(n_total, format="csc") + A_clar = vstack( + [A_scaled, -A_scaled, I_n, -I_n], + format="csc", + ) + b_clar = np.concatenate([cu_scaled, -cl_scaled, var_ub, -var_lb]) + + m_constraints = len(b_clar) + # pylint: disable=no-member + cones = [clarabel.NonnegativeConeT(m_constraints)] + + settings = clarabel.DefaultSettings() + # pylint: enable=no-member + settings.verbose = False + settings.max_iter = max_iter + settings.tol_gap_abs = 1e-7 + settings.tol_gap_rel = 1e-7 + settings.tol_feas = 1e-7 + + # solve + out.write("STARTING CLARABEL SOLVER...\n") + t_start = time.time() + solver = clarabel.DefaultSolver( # pylint: disable=no-member + P, q, A_clar, b_clar, cones, settings + ) + result = solver.solve() + elapsed = time.time() - t_start + status_str = str(result.status) + out.write( + f"Solver status: {status_str}\n" + f"Iterations: {result.iterations}\n" + f"Solve time: {elapsed:.2f}s\n" + ) -JIT_FVAL_AND_GRAD = jax.jit(jax.value_and_grad(objective_function)) + # extract solution + y_opt = np.array(result.x) + x_opt = y_opt[:n_records] + s_lo = y_opt[n_records : n_records + m] + s_hi = y_opt[n_records + m :] + x_opt = np.clip(x_opt, multiplier_min, multiplier_max) -def weight_ratio_distribution(ratio, delta, out): - """ - Print distribution of post-optimized to pre-optimized weight ratios. - """ + info = { + "status": status_str, + "iterations": result.iterations, + "solve_time": elapsed, + "clarabel_solve_time": result.solve_time, + } + + return x_opt, s_lo, s_hi, info + + +def _print_target_diagnostics( + x_opt, B_csc, targets, labels, constraint_tol, out +): + """Print target accuracy diagnostics.""" + achieved = np.asarray(B_csc @ x_opt).ravel() + abs_errors = np.abs(achieved - targets) + rel_errors = abs_errors / np.maximum(np.abs(targets), 1.0) + + out.write(f"TARGET ACCURACY ({len(targets)} targets):\n") + out.write(f" mean |relative error|: {rel_errors.mean():.6f}\n") + out.write(f" max |relative error|: {rel_errors.max():.6f}\n") + + eps = 1e-9 + n_violated = int((rel_errors > constraint_tol + eps).sum()) + n_hit = len(targets) - n_violated + out.write( + f" targets hit: {n_hit}/{len(targets)}" + f" (tolerance: +/-{constraint_tol * 100:.1f}% + eps)\n" + ) + if n_violated > 0: + out.write(f" VIOLATED: {n_violated} targets\n") + worst_idx = np.argsort(rel_errors)[::-1] + for idx in worst_idx[: min(10, n_violated)]: + out.write( + f" {rel_errors[idx] * 100:7.3f}%" + f" | target={targets[idx]:15.0f}" + f" | achieved={achieved[idx]:15.0f}" + f" | {labels[idx]}\n" + ) + return n_violated + + +def _print_multiplier_diagnostics(x_opt, out): + """Print weight multiplier distribution diagnostics.""" + out.write("MULTIPLIER DISTRIBUTION:\n") + out.write( + f" min={x_opt.min():.6f}," + f" p5={np.percentile(x_opt, 5):.6f}," + f" median={np.median(x_opt):.6f}," + f" p95={np.percentile(x_opt, 95):.6f}," + f" max={x_opt.max():.6f}\n" + ) + out.write( + f" RMSE from 1.0:" f" {np.sqrt(np.mean((x_opt - 1.0) ** 2)):.6f}\n" + ) + + # distribution bins bins = [ 0.0, 1e-6, 0.1, - 0.2, 0.5, 0.8, - 0.85, 0.9, 0.95, 1.0, 1.05, 1.1, - 1.15, 1.2, + 1.5, 2.0, 5.0, - 1e1, - 1e2, - 1e3, - 1e4, - 1e5, + 10.0, + 100.0, np.inf, ] - tot = ratio.size - out.write(f"DISTRIBUTION OF AREA/US WEIGHT RATIO (n={tot}):\n") - out.write(f" with REGULARIZATION_DELTA= {delta:e}\n") - header = ( - "low bin ratio high bin ratio" - " bin # cum # bin % cum %\n" - ) - out.write(header) - cutout = pd.cut(ratio, bins, right=False, precision=6) - count = pd.Series(cutout).value_counts().sort_index().to_dict() - cum = 0 - for interval, num in count.items(): - cum += num - if cum == 0: - continue - line = ( - f">={interval.left:13.6f}, <{interval.right:13.6f}:" - f" {num:6d} {cum:6d} {num / tot:7.2%} {cum / tot:7.2%}\n" - ) - out.write(line) - if cum == tot: - break - # write RMSE of area/us weight ratio deviations from one - rmse = np.sqrt(np.mean(np.square(ratio - 1.0))) - line = f"RMSE OF AREA/US WEIGHT RATIO DEVIATIONS FROM ONE = {rmse:e}\n" - out.write(line) + tot = len(x_opt) + out.write(f" distribution (n={tot}):\n") + for i in range(len(bins) - 1): + count = int(((x_opt >= bins[i]) & (x_opt < bins[i + 1])).sum()) + if count > 0: + out.write( + f" [{bins[i]:10.4f}," + f" {bins[i + 1]:10.4f}):" + f" {count:7d}" + f" ({count / tot:7.2%})\n" + ) -# -- High-level logic of the script: +def _print_slack_diagnostics(s_lo, s_hi, targets, labels, out): + """Print elastic slack diagnostics.""" + total_slack = s_lo + s_hi + n_active = int(np.sum(total_slack > 1e-6)) + if n_active > 0: + out.write( + f"ELASTIC SLACK active on" + f" {n_active}/{len(targets)} constraints:\n" + ) + slack_idx = np.where(total_slack > 1e-6)[0] + for idx in slack_idx[np.argsort(total_slack[slack_idx])[::-1]][:20]: + out.write( + f" slack={total_slack[idx]:12.2f}" + f" | target={targets[idx]:15.0f}" + f" | {labels[idx]}\n" + ) + else: + out.write("ALL CONSTRAINTS SATISFIED WITHOUT SLACK\n") def create_area_weights_file( - area: str, - write_log: bool = True, - write_file: bool = True, + area, + write_log=True, + write_file=True, + target_dir=None, + weight_dir=None, ): """ - Create Tax-Calculator-style weights file for FIRST_YEAR through LAST_YEAR - for specified area using information in area targets CSV file. - Write log file if write_log=True, otherwise log is written to stdout. - Write weights file if write_file=True, otherwise just do calculations. + Create area weights file using Clarabel constrained QP solver. + + Returns 0 on success. """ - # remove any existing log or weights files - awpath = AREAS_FOLDER / "weights" / f"{area}_tmd_weights.csv.gz" + if target_dir is None: + target_dir = STATE_TARGET_DIR + if weight_dir is None: + weight_dir = STATE_WEIGHT_DIR + # ensure output directory exists + weight_dir.mkdir(parents=True, exist_ok=True) + + # remove any existing output files + awpath = weight_dir / f"{area}_tmd_weights.csv.gz" awpath.unlink(missing_ok=True) - logpath = AREAS_FOLDER / "weights" / f"{area}.log" + logpath = weight_dir / f"{area}.log" logpath.unlink(missing_ok=True) - # specify log output device + # set up output if write_log: out = open( # pylint: disable=consider-using-with - logpath, - "w", - encoding="utf-8", + logpath, "w", encoding="utf-8" ) else: out = sys.stdout + if write_file: - out.write(f"CREATING WEIGHTS FILE FOR AREA {area} ...\n") + out.write( + f"CREATING WEIGHTS FILE FOR AREA {area}" " (Clarabel solver) ...\n" + ) else: - out.write(f"DOING JUST WEIGHTS FILE CALCS FOR AREA {area} ...\n") + out.write( + f"DOING JUST CALCS FOR AREA {area}" " (Clarabel solver) ...\n" + ) - # configure jax library - jax.config.update("jax_platform_name", "cpu") # ignore GPU/TPU if present - jax.config.update("jax_enable_x64", True) # use double precision floats + # read optional parameters + params = _read_params(area, out, target_dir=target_dir) + constraint_tol = params.get( + "constraint_tol", + params.get("target_ratio_tolerance", AREA_CONSTRAINT_TOL), + ) + slack_penalty = params.get("slack_penalty", AREA_SLACK_PENALTY) + max_iter = params.get("max_iter", AREA_MAX_ITER) + multiplier_min = params.get("multiplier_min", AREA_MULTIPLIER_MIN) + multiplier_max = params.get("multiplier_max", AREA_MULTIPLIER_MAX) + + # load data and build constraint matrix + vdf = _load_taxcalc_data() + out.write(f"Loaded {len(vdf)} records\n") + out.write(f"National weight sum: {vdf.s006.sum():.0f}\n") + + B_csc, targets, labels, pop_share = _build_constraint_matrix( + area, vdf, out, target_dir=target_dir + ) + out.write( + f"Built constraint matrix:" + f" {B_csc.shape[0]} targets x" + f" {B_csc.shape[1]} records\n" + ) + out.write(f"Constraint tolerance: +/-{constraint_tol * 100:.1f}%\n") + out.write(f"Multiplier bounds: [{multiplier_min}, {multiplier_max}]\n") - # read optional parameters file - global PARAMS # pylint: disable=global-statement - PARAMS = {} - pfile = f"{area}_params.yaml" - params_file = AREAS_FOLDER / "targets" / pfile - if params_file.exists(): - with open(params_file, "r", encoding="utf-8") as paramfile: - PARAMS = yaml.safe_load(paramfile.read()) - exp_params = [ - "target_ratio_tolerance", - "iprint", - "dump_all_target_deviations", - "delta_init_value", - "delta_max_loops", - ] - if len(PARAMS) > len(exp_params): - nump = len(exp_params) - out.write( - f"ERROR: {pfile} must contain no more than {nump} parameters\n" - f"IGNORING CONTENTS OF {pfile}\n" - ) - PARAMS = {} - else: - act_params = list(PARAMS.keys()) - all_ok = True - if len(set(act_params)) != len(act_params): - all_ok = False - out.write(f"ERROR: {pfile} contains duplicate parameter\n") - for param in act_params: - if param not in exp_params: - all_ok = False - out.write( - f"ERROR: {pfile} parameter {param} is not expected\n" - ) - if not all_ok: - out.write(f"IGNORING CONTENTS OF {pfile}\n") - PARAMS = {} - if PARAMS: - out.write(f"USING CUSTOMIZED PARAMETERS IN {pfile}\n") + # drop impossible targets + B_csc, targets, labels = _drop_impossible_targets( + B_csc, targets, labels, out + ) - # construct variable matrix and target array and weights_scale - vdf = all_taxcalc_variables() - target_matrix, target_array, weights_scale = prepared_data(area, vdf) - wght_us = np.array(vdf.s006 * weights_scale) - out.write("INITIAL WEIGHTS STATISTICS:\n") - out.write(f"sum of national weights = {vdf.s006.sum():e}\n") - out.write(f"area weights_scale = {weights_scale:e}\n") - num_weights = len(wght_us) - num_targets = len(target_array) - out.write(f"USING {area}_targets.csv FILE WITH {num_targets} TARGETS\n") - tolstr = "ASSUMING TARGET_RATIO_TOLERANCE" - tol = PARAMS.get("target_ratio_tolerance", TARGET_RATIO_TOLERANCE) - out.write(f"{tolstr} = {tol:.6f}\n") - rmse = target_rmse(wght_us, target_matrix, target_array, out) - out.write(f"US_PROPORTIONALLY_SCALED_TARGET_RMSE= {rmse:.9e}\n") - density = np.count_nonzero(target_matrix) / target_matrix.size - out.write(f"target_matrix sparsity ratio = {(1.0 - density):.3f}\n") - - # optimize weight ratios by minimizing the sum of squared deviations - # of area-to-us weight ratios from one such that the optimized ratios - # hit all of the area targets - # - # NOTE: This is a bi-criterion minimization problem that can be - # solved using regularization methods. For background, - # consult Stephen Boyd and Lieven Vandenberghe, Convex - # Optimization, Cambridge University Press, 2004, in - # particular equation (6.9) on page 306 (see LINK below). - # Our problem is exactly the same as (6.9) except that - # we measure x deviations from one rather than from zero. - # LINK: https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf#page=320 - # - A_dense = (target_matrix * wght_us[:, np.newaxis]).T - A = BCOO.from_scipy_sparse(csr_matrix(A_dense)) # A is JAX sparse matrix - b = target_array - delta = PARAMS.get("delta_init_values", DELTA_INIT_VALUE) - max_loop = PARAMS.get("delta_max_loops", DELTA_MAX_LOOPS) - if max_loop > 1: - delta_loop_decrement = delta / (max_loop - 1) - else: - delta_loop_decrement = delta - out.write( - "OPTIMIZE WEIGHT RATIOS POSSIBLY IN A REGULARIZATION LOOP\n" - f" where initial REGULARIZATION DELTA value is {delta:e}\n" + # check what x=1 (population-proportional) achieves + n_records = B_csc.shape[1] + x_ones = np.ones(n_records) + achieved_x1 = np.asarray(B_csc @ x_ones).ravel() + rel_err_x1 = np.abs(achieved_x1 - targets) / np.maximum( + np.abs(targets), 1.0 ) - if max_loop > 1: - out.write(f" and there are at most {max_loop} REGULARIZATION LOOPS\n") - else: - out.write(" and there is only one REGULARIZATION LOOP\n") - out.write(f" and where target_matrix.shape= {target_matrix.shape}\n") - # ... specify possibly customized value of iprint for diagnostic callback - if write_log: - iprint = OPTIMIZE_IPRINT - else: - iprint = PARAMS.get("iprint", OPTIMIZE_IPRINT) - # ... define callback for diagnostic output to replace deprecated iprint - _iter_count = 0 - - def _diagnostic_callback(intermediate_result): - nonlocal _iter_count - _iter_count += 1 - if iprint > 0 and _iter_count % iprint == 0: - fval = intermediate_result.fun - out.write(f" iter {_iter_count}: fval={fval:.6e}\n") - - # ... reduce value of regularization delta if not all targets are hit - loop = 1 - delta = DELTA_INIT_VALUE - wght0 = np.ones(num_weights) - while loop <= max_loop: - time0 = time.time() - _iter_count = 0 - res = minimize( - fun=JIT_FVAL_AND_GRAD, # objective function and its gradient - x0=wght0, # initial guess for weight ratios - jac=True, # use gradient from JIT_FVAL_AND_GRAD function - args=(A, b, delta), # fixed arguments of objective function - method="L-BFGS-B", # use L-BFGS-B algorithm - bounds=Bounds(0.0, np.inf), # consider only non-negative weights - options={ - "maxiter": OPTIMIZE_MAXITER, - "ftol": OPTIMIZE_FTOL, - "gtol": OPTIMIZE_GTOL, - }, - callback=_diagnostic_callback if iprint > 0 else None, - ) - time1 = time.time() - wght_area = res.x * wght_us - misses, minfo = target_misses(wght_area, target_matrix, target_array) - if write_log: - out.write( - f" ::loop,delta,misses: {loop}" f" {delta:e} {misses}\n" - ) - else: - out.write( - f" ::loop,delta,misses,exectime(secs): {loop}" - f" {delta:e} {misses} {(time1 - time0):.1f}\n" - ) - if misses == 0 or res.success is False: - break # out of regularization delta loop - # show magnitude of target misses - out.write(minfo) - # prepare for next regularization delta loop - loop += 1 - delta -= delta_loop_decrement - if delta < 1e-20: - delta = 0.0 - # ... show regularization/optimization results - if write_log: - res_summary = ( - f">>> final delta loop" - f" iterations={res.nit} success={res.success}\n" - f">>> message: {res.message}\n" - f">>> L-BFGS-B optimized objective function value: {res.fun:.9e}\n" - ) - else: - res_summary = ( - f">>> final delta loop exectime= {(time1 - time0):.1f} secs" - f" iterations={res.nit} success={res.success}\n" - f">>> message: {res.message}\n" - f">>> L-BFGS-B optimized objective function value: {res.fun:.9e}\n" - ) - out.write(res_summary) - if OPTIMIZE_RESULTS: - out.write(">>> full final delta loop optimization results:\n") - for key in res.keys(): - out.write(f" {key}: {res.get(key)}\n") - wght_area = res.x * wght_us - misses, minfo = target_misses(wght_area, target_matrix, target_array) - out.write(f"AREA-OPTIMIZED_TARGET_MISSES= {misses}\n") - if misses > 0: - out.write(minfo) - rmse = target_rmse(wght_area, target_matrix, target_array, out, delta) - out.write(f"AREA-OPTIMIZED_TARGET_RMSE= {rmse:.9e}\n") - weight_ratio_distribution(res.x, delta, out) + out.write("BEFORE OPTIMIZATION (x=1, population-proportional):\n") + out.write(f" mean |relative error|: {rel_err_x1.mean():.6f}\n") + out.write(f" max |relative error|: {rel_err_x1.max():.6f}\n") + + # solve QP + x_opt, s_lo, s_hi, _info = _solve_area_qp( + B_csc, + targets, + labels, + n_records, + constraint_tol=constraint_tol, + slack_penalty=slack_penalty, + max_iter=max_iter, + multiplier_min=multiplier_min, + multiplier_max=multiplier_max, + out=out, + ) + + # diagnostics + _n_violated = _print_target_diagnostics( + x_opt, B_csc, targets, labels, constraint_tol, out + ) + _print_multiplier_diagnostics(x_opt, out) + _print_slack_diagnostics(s_lo, s_hi, targets, labels, out) if write_log: out.close() if not write_file: return 0 - # write area weights file extrapolating using national population forecast - # ... get population forecast - with open(POPFILE_PATH, "r", encoding="utf-8") as pfile: - pop = yaml.safe_load(pfile.read()) - # ... set FIRST_YEAR weights - weights = wght_area - # ... construct dictionary of scaled-up weights by year - wdict = {f"WT{FIRST_YEAR}": weights} + # write area weights file with population extrapolation + w0 = pop_share * vdf.s006.values + wght_area = x_opt * w0 + + with open(POPFILE_PATH, "r", encoding="utf-8") as pf: + pop = yaml.safe_load(pf.read()) + + wdict = {f"WT{FIRST_YEAR}": wght_area} cum_pop_growth = 1.0 for year in range(FIRST_YEAR + 1, LAST_YEAR + 1): annual_pop_growth = pop[year] / pop[year - 1] cum_pop_growth *= annual_pop_growth - wght = weights.copy() * cum_pop_growth - wdict[f"WT{year}"] = wght - # ... write weights to CSV-formatted file + wdict[f"WT{year}"] = wght_area * cum_pop_growth + wdf = pd.DataFrame.from_dict(wdict) - wdf.to_csv(awpath, index=False, float_format="%.5f", compression="gzip") + wdf.to_csv( + awpath, + index=False, + float_format="%.5f", + compression="gzip", + ) return 0 - - -if __name__ == "__main__": - if len(sys.argv) != 2: - sys.stderr.write( - "ERROR: exactly one command-line argument is required\n" - ) - sys.exit(1) - area_code = sys.argv[1] - if not valid_area(area_code): - sys.stderr.write(f"ERROR: {area_code} is not valid\n") - sys.exit(1) - tfile = f"{area_code}_targets.csv" - target_file = AREAS_FOLDER / "targets" / tfile - if not target_file.exists(): - sys.stderr.write( - f"ERROR: {tfile} file not in tmd/areas/targets folder\n" - ) - sys.exit(1) - RCODE = create_area_weights_file( - area_code, - write_log=False, - write_file=True, - ) - sys.exit(RCODE) diff --git a/tmd/areas/create_area_weights_clarabel.py b/tmd/areas/create_area_weights_clarabel.py deleted file mode 100644 index 1641237f..00000000 --- a/tmd/areas/create_area_weights_clarabel.py +++ /dev/null @@ -1,609 +0,0 @@ -""" -Clarabel-based constrained QP area weight optimization. - -Finds weight multipliers (x) for each record such that area-specific -weighted sums match area targets within a tolerance band, while -minimizing deviation from population-proportional weights. - -Formulation: - minimize sum((x_i - 1)^2) - subject to t_j*(1-eps) <= (Bx)_j <= t_j*(1+eps) [target bounds] - x_min <= x_i <= x_max [multiplier bounds] - -where: - x_i = area_weight_i / (pop_share * national_weight_i) - pop_share = area_population / national_population - B[j,i] = (pop_share * national_weight_i) * A[j,i] - - x_i = 1 means record i gets its population-proportional share. - The optimizer adjusts x_i to hit area-specific targets. - -Elastic slack variables handle infeasibility gracefully: - minimize sum((x_i - 1)^2) + M * sum(s^2) - subject to lb_j <= (Bx)_j + s_lo - s_hi <= ub_j, s >= 0 - -Follows the same QP construction as tmd/utils/reweight.py. -""" - -import sys -import time - -import clarabel -import numpy as np -import pandas as pd -import yaml -from scipy.sparse import ( - csc_matrix, - diags as spdiags, - eye as speye, - hstack, - vstack, -) - -from tmd.areas import AREAS_FOLDER -from tmd.imputation_assumptions import POPULATION_FILE, TAXYEAR -from tmd.storage import STORAGE_FOLDER - -FIRST_YEAR = TAXYEAR -LAST_YEAR = 2034 -INFILE_PATH = STORAGE_FOLDER / "output" / "tmd.csv.gz" -POPFILE_PATH = STORAGE_FOLDER / "input" / POPULATION_FILE -TAXCALC_AGI_CACHE = STORAGE_FOLDER / "output" / "cached_c00100.npy" -CACHED_ALLVARS_PATH = STORAGE_FOLDER / "output" / "cached_allvars.csv" - -# Tax-Calculator output variables to load from cached_allvars for targeting -CACHED_TC_OUTPUTS = [ - "c18300", - "c04470", - "c02500", - "c19200", - "c19700", - "eitc", - "ctc_total", -] - -# Default solver parameters -AREA_CONSTRAINT_TOL = 0.005 -AREA_SLACK_PENALTY = 1e6 -AREA_MAX_ITER = 2000 -AREA_MULTIPLIER_MIN = 0.0 -AREA_MULTIPLIER_MAX = 25.0 - -# Default target/weight directories for states -STATE_TARGET_DIR = AREAS_FOLDER / "targets" / "states" -STATE_WEIGHT_DIR = AREAS_FOLDER / "weights" / "states" - - -def _load_taxcalc_data(): - """Load TMD data with cached AGI and selected Tax-Calculator outputs.""" - vdf = pd.read_csv(INFILE_PATH) - vdf["c00100"] = np.load(TAXCALC_AGI_CACHE) - if CACHED_ALLVARS_PATH.exists(): - allvars = pd.read_csv(CACHED_ALLVARS_PATH, usecols=CACHED_TC_OUTPUTS) - for col in CACHED_TC_OUTPUTS: - if col in allvars.columns: - vdf[col] = allvars[col].values - # Synthetic combined variable for net capital gains targeting - if "p22250" in vdf.columns and "p23250" in vdf.columns: - vdf["capgains_net"] = vdf["p22250"] + vdf["p23250"] - assert np.all(vdf.s006 > 0), "Not all weights are positive" - return vdf - - -def _read_params(area, out, target_dir=None): - """Read optional area-specific parameters YAML file.""" - if target_dir is None: - target_dir = STATE_TARGET_DIR - params = {} - pfile = f"{area}_params.yaml" - params_path = target_dir / pfile - if params_path.exists(): - with open(params_path, "r", encoding="utf-8") as f: - params = yaml.safe_load(f.read()) - exp_params = [ - "target_ratio_tolerance", - "dump_all_target_deviations", - "delta_init_value", - "delta_max_loops", - "iprint", - "constraint_tol", - "slack_penalty", - "max_iter", - "multiplier_min", - "multiplier_max", - ] - act_params = list(params.keys()) - all_ok = len(set(act_params)) == len(act_params) - for param in act_params: - if param not in exp_params: - all_ok = False - out.write( - f"WARNING: {pfile} parameter" f" {param} is not expected\n" - ) - if not all_ok: - out.write(f"IGNORING CONTENTS OF {pfile}\n") - params = {} - elif params: - out.write(f"USING CUSTOMIZED PARAMETERS IN {pfile}\n") - return params - - -def _build_constraint_matrix(area, vardf, out, target_dir=None): - """ - Build constraint matrix B and target array from area targets CSV. - - Returns (B_csc, targets, target_labels, pop_share) where: - - B_csc is a sparse matrix (n_targets x n_records) - - targets is 1-D array of target values - - target_labels is a list of descriptive strings - - pop_share is the area's population share of national population - """ - if target_dir is None: - target_dir = STATE_TARGET_DIR - national_population = (vardf.s006 * vardf.XTOT).sum() - numobs = len(vardf) - targets_file = target_dir / f"{area}_targets.csv" - tdf = pd.read_csv(targets_file, comment="#") - - targets_list = [] - labels_list = [] - columns = [] - pop_share = None - - for row_idx, row in enumerate(tdf.itertuples(index=False)): - line = f"{area}:target{row_idx + 1}" - - # extract target value - target_val = row.target - targets_list.append(target_val) - - # build label - label = ( - f"{row.varname}" - f"/cnt={row.count}" - f"/scope={row.scope}" - f"/agi=[{row.agilo},{row.agihi})" - f"/fs={row.fstatus}" - ) - labels_list.append(label) - - # first row must be XTOT population target - if row_idx == 0: - assert row.varname == "XTOT", f"{line}: first target must be XTOT" - assert row.count == 0 and row.scope == 0 - assert row.agilo < -8e99 and row.agihi > 8e99 - assert row.fstatus == 0 - pop_share = row.target / national_population - out.write( - f"pop_share = {row.target:.0f}" - f" / {national_population:.0f}" - f" = {pop_share:.6f}\n" - ) - - # construct variable array for this target - assert 0 <= row.count <= 4, f"count {row.count} not in [0,4] on {line}" - if row.count == 0: - var_array = vardf[row.varname].astype(float).values - elif row.count == 1: - var_array = np.ones(numobs, dtype=float) - elif row.count == 2: - var_array = (vardf[row.varname] != 0).astype(float).values - elif row.count == 3: - var_array = (vardf[row.varname] > 0).astype(float).values - else: - var_array = (vardf[row.varname] < 0).astype(float).values - - # construct mask - mask = np.ones(numobs, dtype=float) - assert 0 <= row.scope <= 2, f"scope {row.scope} not in [0,2] on {line}" - if row.scope == 1: - mask *= (vardf.data_source == 1).values.astype(float) - elif row.scope == 2: - mask *= (vardf.data_source == 0).values.astype(float) - - in_agi_bin = (vardf.c00100 >= row.agilo) & (vardf.c00100 < row.agihi) - mask *= in_agi_bin.values.astype(float) - - assert ( - 0 <= row.fstatus <= 5 - ), f"fstatus {row.fstatus} not in [0,5] on {line}" - if row.fstatus > 0: - mask *= (vardf.MARS == row.fstatus).values.astype(float) - - # A[j,i] = mask * var_array (data values for this constraint) - columns.append(mask * var_array) - - assert pop_share is not None, "XTOT target not found" - - # A matrix: n_records x n_targets (each column is a constraint) - A_dense = np.column_stack(columns) - - # B = diag(w0) @ A where w0 = pop_share * s006 - w0 = pop_share * vardf.s006.values - B_dense = w0[:, np.newaxis] * A_dense - - # convert to sparse (n_targets x n_records) for Clarabel - B_csc = csc_matrix(B_dense.T) - - targets_arr = np.array(targets_list) - return B_csc, targets_arr, labels_list, pop_share - - -def _drop_impossible_targets(B_csc, targets, labels, out): - """Drop targets where all constraint matrix values are zero.""" - col_sums = np.abs(np.asarray(B_csc.sum(axis=1))).ravel() - all_zero = col_sums == 0 - if all_zero.any(): - n_drop = int(all_zero.sum()) - for i in np.where(all_zero)[0]: - out.write( - f"DROPPING impossible target" f" (all zeros): {labels[i]}\n" - ) - keep = ~all_zero - B_dense = B_csc.toarray() - B_csc = csc_matrix(B_dense[keep, :]) - targets = targets[keep] - labels = [lab for lab, k in zip(labels, keep) if k] - out.write( - f"Dropped {n_drop} impossible targets," - f" {len(targets)} remaining\n" - ) - return B_csc, targets, labels - - -def _solve_area_qp( # pylint: disable=unused-argument - B_csc, - targets, - labels, - n_records, - constraint_tol=AREA_CONSTRAINT_TOL, - slack_penalty=AREA_SLACK_PENALTY, - max_iter=AREA_MAX_ITER, - multiplier_min=AREA_MULTIPLIER_MIN, - multiplier_max=AREA_MULTIPLIER_MAX, - weight_penalty=1.0, - out=None, -): - """ - Solve the area reweighting QP using Clarabel. - - Parameters - ---------- - weight_penalty : float - Penalty weight on (x-1)^2 relative to constraint - slack. Higher values keep multipliers closer to 1.0 - at the cost of more target violations. - - Returns (x_opt, s_lo, s_hi, info_dict). - """ - if out is None: - out = sys.stdout - - m = len(targets) - n_total = n_records + 2 * m # x + s_lo + s_hi - - # constraint bounds: t*(1-eps) <= Bx <= t*(1+eps) - abs_targets = np.abs(targets) - tol_band = abs_targets * constraint_tol - cl = targets - tol_band - cu = targets + tol_band - - # diagonal Hessian: 2*alpha for x, 2*M for slacks - hess_diag = np.empty(n_total) - hess_diag[:n_records] = 2.0 * weight_penalty - hess_diag[n_records:] = 2.0 * slack_penalty - P = spdiags(hess_diag, format="csc") - - # linear term: -2*alpha for x - q = np.zeros(n_total) - q[:n_records] = -2.0 * weight_penalty - - # extended constraint matrix: [B | I_m | -I_m] - I_m = speye(m, format="csc") - A_full = hstack([B_csc, I_m, -I_m], format="csc") - - # constraint scaling for numerical conditioning - target_scale = np.maximum(np.abs(targets), 1.0) - D_inv = spdiags(1.0 / target_scale) - A_scaled = (D_inv @ A_full).tocsc() - cl_scaled = cl / target_scale - cu_scaled = cu / target_scale - - # variable bounds - var_lb = np.empty(n_total) - var_ub = np.empty(n_total) - var_lb[:n_records] = multiplier_min - var_ub[:n_records] = multiplier_max - var_lb[n_records:] = 0.0 - var_ub[n_records:] = 1e20 - - # Clarabel form: Ax + s = b, s in NonnegativeCone - I_n = speye(n_total, format="csc") - A_clar = vstack( - [A_scaled, -A_scaled, I_n, -I_n], - format="csc", - ) - b_clar = np.concatenate([cu_scaled, -cl_scaled, var_ub, -var_lb]) - - m_constraints = len(b_clar) - # pylint: disable=no-member - cones = [clarabel.NonnegativeConeT(m_constraints)] - - settings = clarabel.DefaultSettings() - # pylint: enable=no-member - settings.verbose = False - settings.max_iter = max_iter - settings.tol_gap_abs = 1e-7 - settings.tol_gap_rel = 1e-7 - settings.tol_feas = 1e-7 - - # solve - out.write("STARTING CLARABEL SOLVER...\n") - t_start = time.time() - solver = clarabel.DefaultSolver( # pylint: disable=no-member - P, q, A_clar, b_clar, cones, settings - ) - result = solver.solve() - elapsed = time.time() - t_start - - status_str = str(result.status) - out.write( - f"Solver status: {status_str}\n" - f"Iterations: {result.iterations}\n" - f"Solve time: {elapsed:.2f}s\n" - ) - - # extract solution - y_opt = np.array(result.x) - x_opt = y_opt[:n_records] - s_lo = y_opt[n_records : n_records + m] - s_hi = y_opt[n_records + m :] - - x_opt = np.clip(x_opt, multiplier_min, multiplier_max) - - info = { - "status": status_str, - "iterations": result.iterations, - "solve_time": elapsed, - "clarabel_solve_time": result.solve_time, - } - - return x_opt, s_lo, s_hi, info - - -def _print_target_diagnostics( - x_opt, B_csc, targets, labels, constraint_tol, out -): - """Print target accuracy diagnostics.""" - achieved = np.asarray(B_csc @ x_opt).ravel() - abs_errors = np.abs(achieved - targets) - rel_errors = abs_errors / np.maximum(np.abs(targets), 1.0) - - out.write(f"TARGET ACCURACY ({len(targets)} targets):\n") - out.write(f" mean |relative error|: {rel_errors.mean():.6f}\n") - out.write(f" max |relative error|: {rel_errors.max():.6f}\n") - - eps = 1e-9 - n_violated = int((rel_errors > constraint_tol + eps).sum()) - n_hit = len(targets) - n_violated - out.write( - f" targets hit: {n_hit}/{len(targets)}" - f" (tolerance: +/-{constraint_tol * 100:.1f}% + eps)\n" - ) - if n_violated > 0: - out.write(f" VIOLATED: {n_violated} targets\n") - worst_idx = np.argsort(rel_errors)[::-1] - for idx in worst_idx[: min(10, n_violated)]: - out.write( - f" {rel_errors[idx] * 100:7.3f}%" - f" | target={targets[idx]:15.0f}" - f" | achieved={achieved[idx]:15.0f}" - f" | {labels[idx]}\n" - ) - return n_violated - - -def _print_multiplier_diagnostics(x_opt, out): - """Print weight multiplier distribution diagnostics.""" - out.write("MULTIPLIER DISTRIBUTION:\n") - out.write( - f" min={x_opt.min():.6f}," - f" p5={np.percentile(x_opt, 5):.6f}," - f" median={np.median(x_opt):.6f}," - f" p95={np.percentile(x_opt, 95):.6f}," - f" max={x_opt.max():.6f}\n" - ) - out.write( - f" RMSE from 1.0:" f" {np.sqrt(np.mean((x_opt - 1.0) ** 2)):.6f}\n" - ) - - # distribution bins - bins = [ - 0.0, - 1e-6, - 0.1, - 0.5, - 0.8, - 0.9, - 0.95, - 1.0, - 1.05, - 1.1, - 1.2, - 1.5, - 2.0, - 5.0, - 10.0, - 100.0, - np.inf, - ] - tot = len(x_opt) - out.write(f" distribution (n={tot}):\n") - for i in range(len(bins) - 1): - count = int(((x_opt >= bins[i]) & (x_opt < bins[i + 1])).sum()) - if count > 0: - out.write( - f" [{bins[i]:10.4f}," - f" {bins[i + 1]:10.4f}):" - f" {count:7d}" - f" ({count / tot:7.2%})\n" - ) - - -def _print_slack_diagnostics(s_lo, s_hi, targets, labels, out): - """Print elastic slack diagnostics.""" - total_slack = s_lo + s_hi - n_active = int(np.sum(total_slack > 1e-6)) - if n_active > 0: - out.write( - f"ELASTIC SLACK active on" - f" {n_active}/{len(targets)} constraints:\n" - ) - slack_idx = np.where(total_slack > 1e-6)[0] - for idx in slack_idx[np.argsort(total_slack[slack_idx])[::-1]][:20]: - out.write( - f" slack={total_slack[idx]:12.2f}" - f" | target={targets[idx]:15.0f}" - f" | {labels[idx]}\n" - ) - else: - out.write("ALL CONSTRAINTS SATISFIED WITHOUT SLACK\n") - - -def create_area_weights_file_clarabel( - area, - write_log=True, - write_file=True, - target_dir=None, - weight_dir=None, -): - """ - Create area weights file using Clarabel constrained QP solver. - - Drop-in replacement for create_area_weights_file() in - create_area_weights.py. - - Returns 0 on success. - """ - if target_dir is None: - target_dir = STATE_TARGET_DIR - if weight_dir is None: - weight_dir = STATE_WEIGHT_DIR - # ensure output directory exists - weight_dir.mkdir(parents=True, exist_ok=True) - - # remove any existing output files - awpath = weight_dir / f"{area}_tmd_weights.csv.gz" - awpath.unlink(missing_ok=True) - logpath = weight_dir / f"{area}.log" - logpath.unlink(missing_ok=True) - - # set up output - if write_log: - out = open( # pylint: disable=consider-using-with - logpath, "w", encoding="utf-8" - ) - else: - out = sys.stdout - - if write_file: - out.write( - f"CREATING WEIGHTS FILE FOR AREA {area}" " (Clarabel solver) ...\n" - ) - else: - out.write( - f"DOING JUST CALCS FOR AREA {area}" " (Clarabel solver) ...\n" - ) - - # read optional parameters - params = _read_params(area, out, target_dir=target_dir) - constraint_tol = params.get( - "constraint_tol", - params.get("target_ratio_tolerance", AREA_CONSTRAINT_TOL), - ) - slack_penalty = params.get("slack_penalty", AREA_SLACK_PENALTY) - max_iter = params.get("max_iter", AREA_MAX_ITER) - multiplier_min = params.get("multiplier_min", AREA_MULTIPLIER_MIN) - multiplier_max = params.get("multiplier_max", AREA_MULTIPLIER_MAX) - - # load data and build constraint matrix - vdf = _load_taxcalc_data() - out.write(f"Loaded {len(vdf)} records\n") - out.write(f"National weight sum: {vdf.s006.sum():.0f}\n") - - B_csc, targets, labels, pop_share = _build_constraint_matrix( - area, vdf, out, target_dir=target_dir - ) - out.write( - f"Built constraint matrix:" - f" {B_csc.shape[0]} targets x" - f" {B_csc.shape[1]} records\n" - ) - out.write(f"Constraint tolerance: +/-{constraint_tol * 100:.1f}%\n") - out.write(f"Multiplier bounds: [{multiplier_min}, {multiplier_max}]\n") - - # drop impossible targets - B_csc, targets, labels = _drop_impossible_targets( - B_csc, targets, labels, out - ) - - # check what x=1 (population-proportional) achieves - n_records = B_csc.shape[1] - x_ones = np.ones(n_records) - achieved_x1 = np.asarray(B_csc @ x_ones).ravel() - rel_err_x1 = np.abs(achieved_x1 - targets) / np.maximum( - np.abs(targets), 1.0 - ) - out.write("BEFORE OPTIMIZATION (x=1, population-proportional):\n") - out.write(f" mean |relative error|: {rel_err_x1.mean():.6f}\n") - out.write(f" max |relative error|: {rel_err_x1.max():.6f}\n") - - # solve QP - x_opt, s_lo, s_hi, _info = _solve_area_qp( - B_csc, - targets, - labels, - n_records, - constraint_tol=constraint_tol, - slack_penalty=slack_penalty, - max_iter=max_iter, - multiplier_min=multiplier_min, - multiplier_max=multiplier_max, - out=out, - ) - - # diagnostics - _n_violated = _print_target_diagnostics( - x_opt, B_csc, targets, labels, constraint_tol, out - ) - _print_multiplier_diagnostics(x_opt, out) - _print_slack_diagnostics(s_lo, s_hi, targets, labels, out) - - if write_log: - out.close() - if not write_file: - return 0 - - # write area weights file with population extrapolation - w0 = pop_share * vdf.s006.values - wght_area = x_opt * w0 - - with open(POPFILE_PATH, "r", encoding="utf-8") as pf: - pop = yaml.safe_load(pf.read()) - - wdict = {f"WT{FIRST_YEAR}": wght_area} - cum_pop_growth = 1.0 - for year in range(FIRST_YEAR + 1, LAST_YEAR + 1): - annual_pop_growth = pop[year] / pop[year - 1] - cum_pop_growth *= annual_pop_growth - wdict[f"WT{year}"] = wght_area * cum_pop_growth - - wdf = pd.DataFrame.from_dict(wdict) - wdf.to_csv( - awpath, - index=False, - float_format="%.5f", - compression="gzip", - ) - - return 0 diff --git a/tmd/areas/make_all.py b/tmd/areas/make_all.py index 362b4ab9..a6fe4d7b 100644 --- a/tmd/areas/make_all.py +++ b/tmd/areas/make_all.py @@ -9,7 +9,6 @@ import time from multiprocessing import Pool from tmd.areas.create_area_weights import ( - valid_area, create_area_weights_file, ) from tmd.areas import AREAS_FOLDER @@ -58,9 +57,8 @@ def to_do_areas(): tpaths = sorted(list(tfolder.glob("*_targets.csv"))) for tpath in tpaths: area = tpath.name.split("_")[0] - if not valid_area(area): - print(f"Skipping invalid area name {area}") - continue # skip this area + if area.startswith("."): + continue # skip hidden files wpath = AREAS_FOLDER / "weights" / f"{area}_tmd_weights.csv.gz" if wpath.exists(): wtime = wpath.stat().st_mtime diff --git a/tmd/areas/quality_report.py b/tmd/areas/quality_report.py index 037352d9..1531cb2e 100644 --- a/tmd/areas/quality_report.py +++ b/tmd/areas/quality_report.py @@ -21,7 +21,7 @@ import numpy as np import pandas as pd -from tmd.areas.create_area_weights_clarabel import ( +from tmd.areas.create_area_weights import ( AREA_CONSTRAINT_TOL, STATE_WEIGHT_DIR, ) diff --git a/tmd/areas/solve_weights.py b/tmd/areas/solve_weights.py index 17e0bb28..ce8af3b0 100644 --- a/tmd/areas/solve_weights.py +++ b/tmd/areas/solve_weights.py @@ -28,7 +28,7 @@ import numpy as np import pandas as pd -from tmd.areas.create_area_weights_clarabel import ( +from tmd.areas.create_area_weights import ( AREA_MULTIPLIER_MAX, STATE_TARGET_DIR, STATE_WEIGHT_DIR, diff --git a/tmd/areas/sweep_params.py b/tmd/areas/sweep_params.py index 386d12d0..6098828e 100644 --- a/tmd/areas/sweep_params.py +++ b/tmd/areas/sweep_params.py @@ -20,7 +20,7 @@ import numpy as np import yaml -from tmd.areas.create_area_weights_clarabel import ( +from tmd.areas.create_area_weights import ( AREA_CONSTRAINT_TOL, AREA_MAX_ITER, AREA_MULTIPLIER_MIN, From d09954aaaf7c2c4453e20195ef5e332da2f34a2e Mon Sep 17 00:00:00 2001 From: Don Boyd Date: Thu, 19 Mar 2026 18:31:22 -0400 Subject: [PATCH 09/10] Remove old scipy solver test and expected results test_area_weights.py tested the old scipy L-BFGS-B solver which is now replaced by Clarabel. The Clarabel solver is tested by test_solve_weights.py (test_clarabel_solver_xx and 10 other tests). Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/expected_area_wgt_2021_data.yaml | 19 ------ tests/expected_area_wgt_2022_data.yaml | 19 ------ tests/test_area_weights.py | 84 -------------------------- 3 files changed, 122 deletions(-) delete mode 100644 tests/expected_area_wgt_2021_data.yaml delete mode 100644 tests/expected_area_wgt_2022_data.yaml delete mode 100644 tests/test_area_weights.py diff --git a/tests/expected_area_wgt_2021_data.yaml b/tests/expected_area_wgt_2021_data.yaml deleted file mode 100644 index 6a1edbc1..00000000 --- a/tests/expected_area_wgt_2021_data.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# variable statistics that are targeted in the area weights optimization: -popall: 33 # # million -e00300: 20 # $ billion -e00900: 30 # $ billion -e00200: 1000 # $ billion -e02000: 30 # $ billion -e02400: 60 # $ billion -c00100: 1200 # $ billion -agihic: 10 # # thousand -e00400: 2 # $ billion -e00600: 8 # $ billion -e00650: 7 # $ billion -e01700: 12 # $ billion -e02300: 10 # $ billion -e17500: 5 # $ billion -e18400: 10 # $ billion -e18500: 10 # $ billion -# variable statistics that are not targeted in the area weights optimization: -# ... none available in the faux xx area ... diff --git a/tests/expected_area_wgt_2022_data.yaml b/tests/expected_area_wgt_2022_data.yaml deleted file mode 100644 index 6a1edbc1..00000000 --- a/tests/expected_area_wgt_2022_data.yaml +++ /dev/null @@ -1,19 +0,0 @@ -# variable statistics that are targeted in the area weights optimization: -popall: 33 # # million -e00300: 20 # $ billion -e00900: 30 # $ billion -e00200: 1000 # $ billion -e02000: 30 # $ billion -e02400: 60 # $ billion -c00100: 1200 # $ billion -agihic: 10 # # thousand -e00400: 2 # $ billion -e00600: 8 # $ billion -e00650: 7 # $ billion -e01700: 12 # $ billion -e02300: 10 # $ billion -e17500: 5 # $ billion -e18400: 10 # $ billion -e18500: 10 # $ billion -# variable statistics that are not targeted in the area weights optimization: -# ... none available in the faux xx area ... diff --git a/tests/test_area_weights.py b/tests/test_area_weights.py deleted file mode 100644 index 58cba41d..00000000 --- a/tests/test_area_weights.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -Tests of tmd/areas/create_area_weights.py script. -""" - -import yaml -import taxcalc as tc -from tmd.storage import STORAGE_FOLDER -from tmd.imputation_assumptions import TAXYEAR, CREDIT_CLAIMING -from tmd.areas import AREAS_FOLDER -from tmd.areas.create_area_weights import create_area_weights_file -from tests.conftest import create_tmd_records - -YEAR = TAXYEAR - - -def test_area_xx(tests_folder): - """ - Optimize national weights for faux xx area using the faux xx area targets - and compare actual Tax-Calculator results with expected results when - using area weights along with national input data and growfactors. - """ - rc = create_area_weights_file("xx", write_log=False, write_file=True) - assert rc == 0, "create_areas_weights_file has non-zero return code" - # compare actual vs expected results for faux area xx - # ... instantiate Tax-Calculator object for area - pol = tc.Policy() - pol.implement_reform(CREDIT_CLAIMING) - rec = create_tmd_records( - data_path=STORAGE_FOLDER / "output" / "tmd.csv.gz", - weights_path=AREAS_FOLDER / "weights" / "xx_tmd_weights.csv.gz", - growfactors_path=STORAGE_FOLDER / "output" / "tmd_growfactors.csv", - ) - sim = tc.Calculator(policy=pol, records=rec) - # ... calculate tax variables for YEAR - sim.advance_to_year(YEAR) - sim.calc_all() - vdf = sim.dataframe([], all_vars=True) - # ... calculate actual results and store in act dictionary - wght = vdf.s006 * (vdf.data_source == 1) # PUF weights - act = { - "popall": (vdf.s006 * vdf.XTOT).sum() * 1e-6, - "e00300": (wght * vdf.e00300).sum() * 1e-9, - "e00900": (wght * vdf.e00900).sum() * 1e-9, - "e00200": (wght * vdf.e00200).sum() * 1e-9, - "e02000": (wght * vdf.e02000).sum() * 1e-9, - "e02400": (wght * vdf.e02400).sum() * 1e-9, - "c00100": (wght * vdf.c00100).sum() * 1e-9, - "agihic": (wght * (vdf.c00100 >= 1e6)).sum() * 1e-3, - "e00400": (wght * vdf.e00400).sum() * 1e-9, - "e00600": (wght * vdf.e00600).sum() * 1e-9, - "e00650": (wght * vdf.e00650).sum() * 1e-9, - "e01700": (wght * vdf.e01700).sum() * 1e-9, - "e02300": (wght * vdf.e02300).sum() * 1e-9, - "e17500": (wght * vdf.e17500).sum() * 1e-9, - "e18400": (wght * vdf.e18400).sum() * 1e-9, - "e18500": (wght * vdf.e18500).sum() * 1e-9, - } - # ... read expected results into exp dictionary - exp_path = tests_folder / f"expected_area_wgt_{TAXYEAR}_data.yaml" - with open(exp_path, "r", encoding="utf-8") as efile: - exp = yaml.safe_load(efile.read()) - # compare actual with expected results - default_rtol = 0.005 - rtol = { - "c00100": 0.008, - "e00200": 0.008, - } - if set(act.keys()) != set(exp.keys()): - print("sorted(act.keys())=", sorted(act.keys())) - print("sorted(exp.keys())=", sorted(exp.keys())) - raise ValueError("act.keys() != exp.keys()") - emsg = "" - for res in exp.keys(): - reldiff = act[res] / exp[res] - 1 - reltol = rtol.get(res, default_rtol) - ok = abs(reldiff) < reltol - if not ok: - emsg += ( - f"FAIL:res,act,exp,rdiff,rtol= {res} {act[res]:.5f}" - f" {exp[res]:.5f} {reldiff:.4f} {reltol:.4f}\n" - ) - if emsg: - print(emsg) - raise ValueError("ACT vs EXP diffs in test_area_weights") From d16a9a927e78d33794d61b353536dd9bfc10317d Mon Sep 17 00:00:00 2001 From: Don Boyd Date: Thu, 19 Mar 2026 18:43:15 -0400 Subject: [PATCH 10/10] Remove Clarabel from test function names Rename test_clarabel_solver_xx to test_solver_xx since the module name no longer contains "clarabel". Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_solve_weights.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_solve_weights.py b/tests/test_solve_weights.py index a30165b3..6ec1ca27 100644 --- a/tests/test_solve_weights.py +++ b/tests/test_solve_weights.py @@ -1,8 +1,8 @@ """ -Tests for the Clarabel-based state weight solver pipeline. +Tests for the state weight solver pipeline. Tests: - - Clarabel solver on faux xx area + - Solver on faux xx area - Quality report log parser - CLI scope parsing - Batch area filtering @@ -38,9 +38,9 @@ # --- Solver test on faux xx area --- -def test_clarabel_solver_xx(): +def test_solver_xx(): """ - Solve faux xx area with Clarabel and verify targets are hit. + Solve faux xx area and verify targets are hit. """ # xx targets are in the flat targets/ directory target_dir = AREAS_FOLDER / "targets" @@ -78,9 +78,9 @@ def test_clarabel_solver_xx(): assert "TARGET ACCURACY" in log_text -def test_clarabel_solver_xx_targets_hit(): +def test_solver_xx_targets_hit(): """ - Verify the Clarabel solver actually hits the xx area targets + Verify the solver actually hits the xx area targets within the constraint tolerance. """ target_dir = AREAS_FOLDER / "targets"