From 598741f06fb7ed65b59af9ed69ccea94189c7cee Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 17 Mar 2026 20:26:02 +0100 Subject: [PATCH 1/2] refactor: Consolidate computation_modules with CountryConfig from library Fixes #138 - Replace 6 duplicate UK/US module functions with single config-driven functions - Replace UK_MODULE_DISPATCH + US_MODULE_DISPATCH with single MODULE_DISPATCH - Add get_dispatch_for_country() to filter dispatch by MODULE_REGISTRY - Use compute_decile_impacts(), compute_budget_summary(), compute_program_statistics() from policyengine library instead of inline implementations - run_modules() now accepts country_id and resolves config + dispatch internally - Remove model_rebuild() calls (fixed in library __init__.py) - Remove hardcoded program/variable lists (now in US_CONFIG/UK_CONFIG) - Update tests to match new single-dispatch architecture (29/29 pass) Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/api/analysis.py | 10 +- .../api/computation_modules.py | 490 +++++------------- tests/test_computation_modules.py | 292 ++++++----- 3 files changed, 300 insertions(+), 492 deletions(-) diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index c83c950..d36eeba 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -897,10 +897,10 @@ def build_dynamic(dynamic_id): pe_reform_sim.ensure() # Run computation modules - from policyengine_api.api.computation_modules import UK_MODULE_DISPATCH, run_modules + from policyengine_api.api.computation_modules import run_modules run_modules( - dispatch=UK_MODULE_DISPATCH, + country_id="uk", modules=modules, pe_baseline_sim=pe_baseline_sim, pe_reform_sim=pe_reform_sim, @@ -908,7 +908,6 @@ def build_dynamic(dynamic_id): reform_sim_id=reform_sim.id, report_id=report.id, session=session, - country_id="uk", ) # Mark completed @@ -1082,10 +1081,10 @@ def build_dynamic(dynamic_id): pe_reform_sim.ensure() # Run computation modules - from policyengine_api.api.computation_modules import US_MODULE_DISPATCH, run_modules + from policyengine_api.api.computation_modules import run_modules run_modules( - dispatch=US_MODULE_DISPATCH, + country_id="us", modules=modules, pe_baseline_sim=pe_baseline_sim, pe_reform_sim=pe_reform_sim, @@ -1093,7 +1092,6 @@ def build_dynamic(dynamic_id): reform_sim_id=reform_sim.id, report_id=report.id, session=session, - country_id="us", ) # Mark completed diff --git a/src/policyengine_api/api/computation_modules.py b/src/policyengine_api/api/computation_modules.py index 4c6c52d..60931b3 100644 --- a/src/policyengine_api/api/computation_modules.py +++ b/src/policyengine_api/api/computation_modules.py @@ -3,10 +3,10 @@ Each function computes a single module's results and writes DB records. They share a common signature pattern: (pe_baseline_sim, pe_reform_sim, baseline_sim_id, reform_sim_id, - report_id, session, **kwargs) -> None + report_id, session, config) -> None -run_modules() passes country_id as a kwarg. Modules that need it (e.g. -compute_decile_module) accept it explicitly; others accept **_kwargs. +run_modules() resolves the country's dispatch table and passes a +CountryConfig from the policyengine library to each module function. Used by _run_local_economy_comparison_uk/us to run modules selectively. """ @@ -17,6 +17,7 @@ from sqlmodel import Session +from policyengine_api.api.module_registry import MODULE_REGISTRY from policyengine_api.models import ( BudgetSummary, CongressionalDistrictImpact, @@ -30,14 +31,30 @@ ) # --------------------------------------------------------------------------- -# Shared modules (UK + US) +# Country configs — imported from the policyengine library # --------------------------------------------------------------------------- +# Lazy-loaded to avoid importing heavy policyengine modules at import time. +# Callers should use get_country_config() instead. +_COUNTRY_CONFIGS: dict | None = None -DECILE_INCOME_VARIABLE: dict[str, str] = { - "us": "household_net_income", - "uk": "equiv_household_net_income", -} + +def get_country_config(country_id: str): + """Return the CountryConfig for the given country_id ('uk' or 'us').""" + global _COUNTRY_CONFIGS + if _COUNTRY_CONFIGS is None: + from policyengine.outputs.country_config import UK_CONFIG, US_CONFIG + + _COUNTRY_CONFIGS = {"us": US_CONFIG, "uk": UK_CONFIG} + config = _COUNTRY_CONFIGS.get(country_id) + if config is None: + raise ValueError(f"No CountryConfig for country '{country_id}'") + return config + + +# --------------------------------------------------------------------------- +# Config-driven modules (shared UK + US) +# --------------------------------------------------------------------------- def compute_decile_module( @@ -47,26 +64,17 @@ def compute_decile_module( reform_sim_id: UUID, report_id: UUID, session: Session, - country_id: str = "", + config, ) -> None: - """Compute income decile impacts (1-10).""" - from policyengine.outputs import DecileImpact as PEDecileImpact - - if country_id not in DECILE_INCOME_VARIABLE: - raise ValueError( - f"No decile income variable configured for country '{country_id}'" - ) + """Compute income decile impacts (1-10) using config.income_variable.""" + from policyengine.outputs.decile_impact import compute_decile_impacts - income_variable = DECILE_INCOME_VARIABLE[country_id] - - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, - income_variable=income_variable, - ) - di.run() + results = compute_decile_impacts( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable=config.income_variable, + ) + for di in results.outputs: record = DecileImpact( baseline_simulation_id=baseline_sim_id, reform_simulation_id=reform_sim_id, @@ -93,7 +101,7 @@ def compute_intra_decile_module( reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: """Compute intra-decile income change distribution (5 bands).""" from policyengine.outputs.intra_decile_impact import ( @@ -103,8 +111,7 @@ def compute_intra_decile_module( results = pe_compute_intra_decile( baseline_simulation=pe_baseline_sim, reform_simulation=pe_reform_sim, - income_variable="household_net_income", - entity="household", + income_variable=config.income_variable, ) for r in results.outputs: record = IntraDecileImpact( @@ -121,61 +128,35 @@ def compute_intra_decile_module( session.add(record) -# --------------------------------------------------------------------------- -# UK-specific modules -# --------------------------------------------------------------------------- - - -def compute_program_statistics_module_uk( +def compute_program_statistics_module( pe_baseline_sim, pe_reform_sim, baseline_sim_id: UUID, reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: - """Compute UK programme statistics.""" - from policyengine.core import Simulation as PESimulation - from policyengine.tax_benefit_models.uk.outputs import ( - ProgrammeStatistics as PEProgrammeStats, - ) + """Compute program statistics using config.programs.""" + from policyengine.outputs.program_statistics import compute_program_statistics - PEProgrammeStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) - programmes = { - "income_tax": {"entity": "person", "is_tax": True}, - "national_insurance": {"entity": "person", "is_tax": True}, - "vat": {"entity": "household", "is_tax": True}, - "council_tax": {"entity": "household", "is_tax": True}, - "universal_credit": {"entity": "person", "is_tax": False}, - "child_benefit": {"entity": "person", "is_tax": False}, - "pension_credit": {"entity": "person", "is_tax": False}, - "income_support": {"entity": "person", "is_tax": False}, - "working_tax_credit": {"entity": "person", "is_tax": False}, - "child_tax_credit": {"entity": "person", "is_tax": False}, - } - for prog_name, prog_info in programmes.items(): - try: - ps = PEProgrammeStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - programme_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - except KeyError: - import logfire - - logfire.warning(f"Program variable not found, skipping: {prog_name}") - continue + results = compute_program_statistics( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + programs=config.programs, + ) + for ps in results.outputs: + # Detect the name field (US: program_name, UK: programme_name) + prog_name = getattr(ps, "program_name", None) or getattr( + ps, "programme_name", None + ) record = ProgramStatistics( baseline_simulation_id=baseline_sim_id, reform_simulation_id=reform_sim_id, report_id=report_id, program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], + entity=ps.entity, + is_tax=ps.is_tax, baseline_total=ps.baseline_total, reform_total=ps.reform_total, change=ps.change, @@ -187,32 +168,49 @@ def compute_program_statistics_module_uk( session.add(record) -def compute_poverty_module_uk( +def compute_poverty_module( pe_baseline_sim, pe_reform_sim, baseline_sim_id: UUID, reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: - """Compute UK poverty rates (overall, by age, by gender).""" - from policyengine.outputs.poverty import ( - calculate_uk_poverty_by_age, - calculate_uk_poverty_by_gender, - calculate_uk_poverty_rates, - ) + """Compute poverty rates using config.country_id to select calculators.""" + if config.country_id == "uk": + from policyengine.outputs.poverty import ( + calculate_uk_poverty_by_age, + calculate_uk_poverty_by_gender, + calculate_uk_poverty_rates, + ) + + calculators = [ + calculate_uk_poverty_rates, + calculate_uk_poverty_by_age, + calculate_uk_poverty_by_gender, + ] + else: + from policyengine.outputs.poverty import ( + calculate_us_poverty_by_age, + calculate_us_poverty_by_gender, + calculate_us_poverty_by_race, + calculate_us_poverty_rates, + ) + + calculators = [ + calculate_us_poverty_rates, + calculate_us_poverty_by_age, + calculate_us_poverty_by_gender, + calculate_us_poverty_by_race, + ] sim_pairs = [ (pe_baseline_sim, baseline_sim_id), (pe_reform_sim, reform_sim_id), ] - for calculator in [ - calculate_uk_poverty_rates, - calculate_uk_poverty_by_age, - calculate_uk_poverty_by_gender, - ]: + for calculator in calculators: for pe_sim, db_sim_id in sim_pairs: results = calculator(pe_sim) for pov in results.outputs: @@ -229,23 +227,26 @@ def compute_poverty_module_uk( session.add(record) -def compute_inequality_module_uk( +def compute_inequality_module( pe_baseline_sim, pe_reform_sim, baseline_sim_id: UUID, reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: - """Compute UK inequality metrics.""" - from policyengine.outputs.inequality import calculate_uk_inequality + """Compute inequality metrics using config.country_id to select calculator.""" + if config.country_id == "uk": + from policyengine.outputs.inequality import calculate_uk_inequality as calc_fn + else: + from policyengine.outputs.inequality import calculate_us_inequality as calc_fn for pe_sim, db_sim_id in [ (pe_baseline_sim, baseline_sim_id), (pe_reform_sim, reform_sim_id), ]: - ineq = calculate_uk_inequality(pe_sim) + ineq = calc_fn(pe_sim) ineq.run() record = Inequality( simulation_id=db_sim_id, @@ -260,51 +261,33 @@ def compute_inequality_module_uk( session.add(record) -def compute_budget_summary_module_uk( +def compute_budget_summary_module( pe_baseline_sim, pe_reform_sim, baseline_sim_id: UUID, reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: - """Compute UK budget summary aggregates.""" - from policyengine.core import Simulation as PESimulation - from policyengine.outputs.aggregate import Aggregate as PEAggregate - from policyengine.outputs.aggregate import AggregateType as PEAggregateType - - PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) - - uk_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - } - for var_name, entity in uk_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - reform_agg.run() + """Compute budget summary using config.budget_variables.""" + from policyengine.outputs.budget_summary import compute_budget_summary + + results = compute_budget_summary( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + variables=config.budget_variables, + ) + for item in results.outputs: record = BudgetSummary( baseline_simulation_id=baseline_sim_id, reform_simulation_id=reform_sim_id, report_id=report_id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), + variable_name=item.variable_name, + entity=item.entity, + baseline_total=item.baseline_total, + reform_total=item.reform_total, + change=item.change, ) session.add(record) @@ -328,6 +311,11 @@ def compute_budget_summary_module_uk( session.add(record) +# --------------------------------------------------------------------------- +# Geographic / country-specific modules (unchanged structure) +# --------------------------------------------------------------------------- + + def compute_constituency_module( pe_baseline_sim, pe_reform_sim, @@ -335,7 +323,7 @@ def compute_constituency_module( reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: """Compute UK parliamentary constituency impact.""" from policyengine.outputs.constituency_impact import ( @@ -391,7 +379,7 @@ def compute_local_authority_module( reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: """Compute UK local authority impact.""" from policyengine.outputs.local_authority_impact import ( @@ -447,7 +435,7 @@ def compute_wealth_decile_module( reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: """Compute UK wealth decile impact and intra-wealth-decile breakdown.""" from policyengine.core import Simulation as PESimulation @@ -513,212 +501,6 @@ def compute_wealth_decile_module( ) -# --------------------------------------------------------------------------- -# US-specific modules -# --------------------------------------------------------------------------- - - -def compute_program_statistics_module_us( - pe_baseline_sim, - pe_reform_sim, - baseline_sim_id: UUID, - reform_sim_id: UUID, - report_id: UUID, - session: Session, - **_kwargs, -) -> None: - """Compute US program statistics.""" - from policyengine.core import Simulation as PESimulation - from policyengine.tax_benefit_models.us.outputs import ( - ProgramStatistics as PEProgramStats, - ) - - PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) - programs = { - "income_tax": {"entity": "tax_unit", "is_tax": True}, - "employee_payroll_tax": {"entity": "person", "is_tax": True}, - "snap": {"entity": "spm_unit", "is_tax": False}, - "tanf": {"entity": "spm_unit", "is_tax": False}, - "ssi": {"entity": "spm_unit", "is_tax": False}, - "social_security": {"entity": "person", "is_tax": False}, - } - for prog_name, prog_info in programs.items(): - try: - ps = PEProgramStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - except KeyError: - import logfire - - logfire.warning(f"Program variable not found, skipping: {prog_name}") - continue - record = ProgramStatistics( - baseline_simulation_id=baseline_sim_id, - reform_simulation_id=reform_sim_id, - report_id=report_id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, - ) - session.add(record) - - -def compute_poverty_module_us( - pe_baseline_sim, - pe_reform_sim, - baseline_sim_id: UUID, - reform_sim_id: UUID, - report_id: UUID, - session: Session, - **_kwargs, -) -> None: - """Compute US poverty rates (overall, by age, gender, race).""" - from policyengine.outputs.poverty import ( - calculate_us_poverty_by_age, - calculate_us_poverty_by_gender, - calculate_us_poverty_by_race, - calculate_us_poverty_rates, - ) - - sim_pairs = [ - (pe_baseline_sim, baseline_sim_id), - (pe_reform_sim, reform_sim_id), - ] - - for calculator in [ - calculate_us_poverty_rates, - calculate_us_poverty_by_age, - calculate_us_poverty_by_gender, - calculate_us_poverty_by_race, - ]: - for pe_sim, db_sim_id in sim_pairs: - results = calculator(pe_sim) - for pov in results.outputs: - record = Poverty( - simulation_id=db_sim_id, - report_id=report_id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(record) - - -def compute_inequality_module_us( - pe_baseline_sim, - pe_reform_sim, - baseline_sim_id: UUID, - reform_sim_id: UUID, - report_id: UUID, - session: Session, - **_kwargs, -) -> None: - """Compute US inequality metrics.""" - from policyengine.outputs.inequality import calculate_us_inequality - - for pe_sim, db_sim_id in [ - (pe_baseline_sim, baseline_sim_id), - (pe_reform_sim, reform_sim_id), - ]: - ineq = calculate_us_inequality(pe_sim) - ineq.run() - record = Inequality( - simulation_id=db_sim_id, - report_id=report_id, - income_variable=ineq.income_variable, - entity=ineq.entity, - gini=ineq.gini, - top_10_share=ineq.top_10_share, - top_1_share=ineq.top_1_share, - bottom_50_share=ineq.bottom_50_share, - ) - session.add(record) - - -def compute_budget_summary_module_us( - pe_baseline_sim, - pe_reform_sim, - baseline_sim_id: UUID, - reform_sim_id: UUID, - report_id: UUID, - session: Session, - **_kwargs, -) -> None: - """Compute US budget summary aggregates.""" - from policyengine.core import Simulation as PESimulation - from policyengine.outputs.aggregate import Aggregate as PEAggregate - from policyengine.outputs.aggregate import AggregateType as PEAggregateType - - PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) - - us_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - "household_state_income_tax": "tax_unit", - } - for var_name, entity in us_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - reform_agg.run() - record = BudgetSummary( - baseline_simulation_id=baseline_sim_id, - reform_simulation_id=reform_sim_id, - report_id=report_id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), - ) - session.add(record) - - # Household count: raw sum of weights - baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household["household_weight"].values.sum() - ) - reform_hh_count = float( - pe_reform_sim.output_dataset.data.household["household_weight"].values.sum() - ) - record = BudgetSummary( - baseline_simulation_id=baseline_sim_id, - reform_simulation_id=reform_sim_id, - report_id=report_id, - variable_name="household_count_total", - entity="household", - baseline_total=baseline_hh_count, - reform_total=reform_hh_count, - change=reform_hh_count - baseline_hh_count, - ) - session.add(record) - - def compute_congressional_district_module( pe_baseline_sim, pe_reform_sim, @@ -726,7 +508,7 @@ def compute_congressional_district_module( reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: """Compute US congressional district impact.""" from policyengine.outputs.congressional_district_impact import ( @@ -761,37 +543,31 @@ def compute_congressional_district_module( # --------------------------------------------------------------------------- -# Dispatch tables: module name -> computation function +# Single dispatch table: module name -> computation function # --------------------------------------------------------------------------- -# Type alias for module computation functions -ModuleFunction = type(compute_decile_module) - -UK_MODULE_DISPATCH: dict[str, ModuleFunction] = { +MODULE_DISPATCH: dict[str, type(compute_decile_module)] = { "decile": compute_decile_module, - "program_statistics": compute_program_statistics_module_uk, - "poverty": compute_poverty_module_uk, - "inequality": compute_inequality_module_uk, - "budget_summary": compute_budget_summary_module_uk, "intra_decile": compute_intra_decile_module, + "program_statistics": compute_program_statistics_module, + "poverty": compute_poverty_module, + "inequality": compute_inequality_module, + "budget_summary": compute_budget_summary_module, "constituency": compute_constituency_module, "local_authority": compute_local_authority_module, "wealth_decile": compute_wealth_decile_module, -} - -US_MODULE_DISPATCH: dict[str, ModuleFunction] = { - "decile": compute_decile_module, - "program_statistics": compute_program_statistics_module_us, - "poverty": compute_poverty_module_us, - "inequality": compute_inequality_module_us, - "budget_summary": compute_budget_summary_module_us, - "intra_decile": compute_intra_decile_module, "congressional_district": compute_congressional_district_module, } +def get_dispatch_for_country(country_id: str) -> dict: + """Return the subset of MODULE_DISPATCH applicable to a country.""" + available = {m.name for m in MODULE_REGISTRY.values() if country_id in m.countries} + return {k: v for k, v in MODULE_DISPATCH.items() if k in available} + + def run_modules( - dispatch: dict[str, ModuleFunction], + country_id: str, modules: list[str] | None, pe_baseline_sim, pe_reform_sim, @@ -799,9 +575,13 @@ def run_modules( reform_sim_id: UUID, report_id: UUID, session: Session, - country_id: str = "", ) -> None: - """Run the requested modules (or all if modules is None).""" + """Run the requested modules (or all applicable) for a country. + + Resolves the country's dispatch table and CountryConfig automatically. + """ + dispatch = get_dispatch_for_country(country_id) + config = get_country_config(country_id) to_run = modules if modules is not None else list(dispatch.keys()) for mod_name in to_run: fn = dispatch.get(mod_name) @@ -813,5 +593,5 @@ def run_modules( reform_sim_id, report_id, session, - country_id=country_id, + config, ) diff --git a/tests/test_computation_modules.py b/tests/test_computation_modules.py index b316296..d946043 100644 --- a/tests/test_computation_modules.py +++ b/tests/test_computation_modules.py @@ -1,156 +1,144 @@ """Tests for the composable computation module dispatch system.""" import inspect -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from uuid import uuid4 from policyengine_api.api import computation_modules as cm from policyengine_api.api.computation_modules import ( - UK_MODULE_DISPATCH, - US_MODULE_DISPATCH, + MODULE_DISPATCH, + get_dispatch_for_country, run_modules, ) from policyengine_api.api.module_registry import MODULE_REGISTRY -class TestDispatchTables: - """Tests for UK_MODULE_DISPATCH and US_MODULE_DISPATCH.""" +class TestModuleDispatch: + """Tests for the unified MODULE_DISPATCH table.""" - def test_uk_dispatch_keys_match_registry(self): - """Every UK dispatch key should be a valid module in the registry.""" - for key in UK_MODULE_DISPATCH: - assert key in MODULE_REGISTRY, f"UK dispatch key {key!r} not in registry" - - def test_us_dispatch_keys_match_registry(self): - """Every US dispatch key should be a valid module in the registry.""" - for key in US_MODULE_DISPATCH: - assert key in MODULE_REGISTRY, f"US dispatch key {key!r} not in registry" + def test_dispatch_keys_match_registry(self): + """Every dispatch key should be a valid module in the registry.""" + for key in MODULE_DISPATCH: + assert key in MODULE_REGISTRY, f"Dispatch key {key!r} not in registry" - def test_uk_dispatch_covers_uk_modules(self): - """UK dispatch should have an entry for every UK-applicable module.""" - uk_module_names = { - name for name, mod in MODULE_REGISTRY.items() if "uk" in mod.countries - } - assert set(UK_MODULE_DISPATCH.keys()) == uk_module_names - - def test_us_dispatch_covers_us_modules(self): - """US dispatch should have an entry for every US-applicable module.""" - us_module_names = { - name for name, mod in MODULE_REGISTRY.items() if "us" in mod.countries - } - assert set(US_MODULE_DISPATCH.keys()) == us_module_names + def test_dispatch_covers_all_registry_modules(self): + """Dispatch should have an entry for every module in the registry.""" + assert set(MODULE_DISPATCH.keys()) == set(MODULE_REGISTRY.keys()) def test_all_dispatch_values_are_callable(self): - for fn in UK_MODULE_DISPATCH.values(): - assert callable(fn) - for fn in US_MODULE_DISPATCH.values(): + for fn in MODULE_DISPATCH.values(): assert callable(fn) - def test_uk_dispatch_has_9_entries(self): - assert len(UK_MODULE_DISPATCH) == 9 - - def test_us_dispatch_has_7_entries(self): - assert len(US_MODULE_DISPATCH) == 7 - - -class TestSharedModuleFunctions: - """Tests that shared modules reference the same function objects.""" - - def test_decile_function_shared_between_uk_and_us(self): - assert UK_MODULE_DISPATCH["decile"] is US_MODULE_DISPATCH["decile"] - assert UK_MODULE_DISPATCH["decile"] is cm.compute_decile_module - - def test_intra_decile_function_shared_between_uk_and_us(self): - assert UK_MODULE_DISPATCH["intra_decile"] is US_MODULE_DISPATCH["intra_decile"] - assert UK_MODULE_DISPATCH["intra_decile"] is cm.compute_intra_decile_module - - -class TestCountrySpecificFunctions: - """Tests that UK/US specific modules use the correct country-specific functions.""" - - def test_uk_program_statistics(self): - assert ( - UK_MODULE_DISPATCH["program_statistics"] - is cm.compute_program_statistics_module_uk - ) - - def test_us_program_statistics(self): - assert ( - US_MODULE_DISPATCH["program_statistics"] - is cm.compute_program_statistics_module_us - ) + def test_dispatch_has_10_entries(self): + assert len(MODULE_DISPATCH) == 10 - def test_uk_poverty(self): - assert UK_MODULE_DISPATCH["poverty"] is cm.compute_poverty_module_uk - def test_us_poverty(self): - assert US_MODULE_DISPATCH["poverty"] is cm.compute_poverty_module_us +class TestGetDispatchForCountry: + """Tests for country-filtered dispatch tables.""" - def test_uk_inequality(self): - assert UK_MODULE_DISPATCH["inequality"] is cm.compute_inequality_module_uk + def test_uk_dispatch_has_9_entries(self): + uk = get_dispatch_for_country("uk") + assert len(uk) == 9 - def test_us_inequality(self): - assert US_MODULE_DISPATCH["inequality"] is cm.compute_inequality_module_us + def test_us_dispatch_has_7_entries(self): + us = get_dispatch_for_country("us") + assert len(us) == 7 - def test_uk_budget_summary(self): - assert ( - UK_MODULE_DISPATCH["budget_summary"] is cm.compute_budget_summary_module_uk - ) + def test_uk_dispatch_keys_match_registry(self): + uk = get_dispatch_for_country("uk") + uk_module_names = { + name for name, mod in MODULE_REGISTRY.items() if "uk" in mod.countries + } + assert set(uk.keys()) == uk_module_names - def test_us_budget_summary(self): - assert ( - US_MODULE_DISPATCH["budget_summary"] is cm.compute_budget_summary_module_us - ) + def test_us_dispatch_keys_match_registry(self): + us = get_dispatch_for_country("us") + us_module_names = { + name for name, mod in MODULE_REGISTRY.items() if "us" in mod.countries + } + assert set(us.keys()) == us_module_names def test_constituency_is_uk_only(self): - assert UK_MODULE_DISPATCH["constituency"] is cm.compute_constituency_module - assert "constituency" not in US_MODULE_DISPATCH + assert "constituency" in get_dispatch_for_country("uk") + assert "constituency" not in get_dispatch_for_country("us") def test_local_authority_is_uk_only(self): - assert ( - UK_MODULE_DISPATCH["local_authority"] is cm.compute_local_authority_module - ) - assert "local_authority" not in US_MODULE_DISPATCH + assert "local_authority" in get_dispatch_for_country("uk") + assert "local_authority" not in get_dispatch_for_country("us") def test_wealth_decile_is_uk_only(self): - assert UK_MODULE_DISPATCH["wealth_decile"] is cm.compute_wealth_decile_module - assert "wealth_decile" not in US_MODULE_DISPATCH + assert "wealth_decile" in get_dispatch_for_country("uk") + assert "wealth_decile" not in get_dispatch_for_country("us") def test_congressional_district_is_us_only(self): - assert ( - US_MODULE_DISPATCH["congressional_district"] - is cm.compute_congressional_district_module - ) - assert "congressional_district" not in UK_MODULE_DISPATCH + assert "congressional_district" in get_dispatch_for_country("us") + assert "congressional_district" not in get_dispatch_for_country("uk") + + +class TestUnifiedModuleFunctions: + """Tests that shared modules use the same function for both countries.""" + + def test_decile_is_shared(self): + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["decile"] is us["decile"] + assert uk["decile"] is cm.compute_decile_module + + def test_intra_decile_is_shared(self): + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["intra_decile"] is us["intra_decile"] + assert uk["intra_decile"] is cm.compute_intra_decile_module + + def test_program_statistics_is_shared(self): + """A single function handles both UK and US program statistics.""" + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["program_statistics"] is us["program_statistics"] + assert uk["program_statistics"] is cm.compute_program_statistics_module + + def test_poverty_is_shared(self): + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["poverty"] is us["poverty"] + assert uk["poverty"] is cm.compute_poverty_module + + def test_inequality_is_shared(self): + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["inequality"] is us["inequality"] + assert uk["inequality"] is cm.compute_inequality_module + + def test_budget_summary_is_shared(self): + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["budget_summary"] is us["budget_summary"] + assert uk["budget_summary"] is cm.compute_budget_summary_module class TestModuleFunctionSignatures: - """Tests that all module functions share the expected signature pattern. + """Tests that all module functions share the expected 7-param signature. - Modules use a common 7-param signature pattern: + Modules use a common signature pattern: (pe_baseline_sim, pe_reform_sim, baseline_sim_id, reform_sim_id, - report_id, session, **kwargs) -> None - - run_modules() passes country_id as a kwarg. Modules that need it (e.g. - compute_decile_module) accept it explicitly; others accept **_kwargs. + report_id, session, config) -> None """ - _BASE_PARAMS = [ + _EXPECTED_PARAMS = [ "pe_baseline_sim", "pe_reform_sim", "baseline_sim_id", "reform_sim_id", "report_id", "session", + "config", ] - # 7th param can be either explicit country_id or **_kwargs - _VALID_7TH_PARAMS = {"country_id", "_kwargs"} def _get_all_unique_functions(self): - """Collect all unique module functions from both dispatch tables.""" + """Collect all unique module functions from dispatch.""" seen = set() fns = [] - for fn in list(UK_MODULE_DISPATCH.values()) + list(US_MODULE_DISPATCH.values()): + for fn in MODULE_DISPATCH.values(): if id(fn) not in seen: seen.add(id(fn)) fns.append(fn) @@ -167,19 +155,13 @@ def test_all_functions_have_expected_param_names(self): for fn in self._get_all_unique_functions(): sig = inspect.signature(fn) param_names = list(sig.parameters.keys()) - # First 6 params must match exactly - assert param_names[:6] == self._BASE_PARAMS, ( - f"{fn.__name__} first 6 params {param_names[:6]} != {self._BASE_PARAMS}" - ) - # 7th param can be country_id or _kwargs - assert param_names[6] in self._VALID_7TH_PARAMS, ( - f"{fn.__name__} 7th param '{param_names[6]}' not in {self._VALID_7TH_PARAMS}" + assert param_names == self._EXPECTED_PARAMS, ( + f"{fn.__name__} params {param_names} != {self._EXPECTED_PARAMS}" ) def test_all_functions_return_none(self): for fn in self._get_all_unique_functions(): sig = inspect.signature(fn) - # `from __future__ import annotations` makes annotations strings assert sig.return_annotation in (None, "None", inspect.Parameter.empty), ( f"{fn.__name__} return annotation is {sig.return_annotation!r}, expected None" ) @@ -192,54 +174,84 @@ def _make_mock_dispatch(self, names): """Create a dispatch dict with mock functions.""" return {name: MagicMock(name=f"compute_{name}") for name in names} - def test_runs_all_when_modules_is_none(self): + def _mock_config(self): + """Create a mock CountryConfig.""" + config = MagicMock() + config.country_id = "us" + return config + + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_runs_all_when_modules_is_none(self, mock_get_config, mock_get_dispatch): dispatch = self._make_mock_dispatch(["a", "b", "c"]) + config = self._mock_config() + mock_get_dispatch.return_value = dispatch + mock_get_config.return_value = config session = MagicMock() ids = [uuid4() for _ in range(3)] - run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], session) + run_modules("us", None, "bl", "rf", ids[0], ids[1], ids[2], session) for fn in dispatch.values(): fn.assert_called_once_with( - "bl", "rf", ids[0], ids[1], ids[2], session, country_id="" + "bl", "rf", ids[0], ids[1], ids[2], session, config ) - def test_runs_only_requested_modules(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_runs_only_requested_modules(self, mock_get_config, mock_get_dispatch): dispatch = self._make_mock_dispatch(["a", "b", "c"]) + config = self._mock_config() + mock_get_dispatch.return_value = dispatch + mock_get_config.return_value = config session = MagicMock() ids = [uuid4() for _ in range(3)] - run_modules(dispatch, ["b"], "bl", "rf", ids[0], ids[1], ids[2], session) + run_modules("us", ["b"], "bl", "rf", ids[0], ids[1], ids[2], session) dispatch["a"].assert_not_called() dispatch["b"].assert_called_once() dispatch["c"].assert_not_called() - def test_ignores_unknown_module_names(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_ignores_unknown_module_names(self, mock_get_config, mock_get_dispatch): dispatch = self._make_mock_dispatch(["a"]) + config = self._mock_config() + mock_get_dispatch.return_value = dispatch + mock_get_config.return_value = config session = MagicMock() ids = [uuid4() for _ in range(3)] # Should not raise run_modules( - dispatch, ["a", "nonexistent"], "bl", "rf", ids[0], ids[1], ids[2], session + "us", ["a", "nonexistent"], "bl", "rf", ids[0], ids[1], ids[2], session ) dispatch["a"].assert_called_once() - def test_empty_modules_list_runs_nothing(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_empty_modules_list_runs_nothing(self, mock_get_config, mock_get_dispatch): dispatch = self._make_mock_dispatch(["a", "b"]) + config = self._mock_config() + mock_get_dispatch.return_value = dispatch + mock_get_config.return_value = config session = MagicMock() ids = [uuid4() for _ in range(3)] - run_modules(dispatch, [], "bl", "rf", ids[0], ids[1], ids[2], session) + run_modules("us", [], "bl", "rf", ids[0], ids[1], ids[2], session) for fn in dispatch.values(): fn.assert_not_called() - def test_preserves_call_order(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_preserves_call_order(self, mock_get_config, mock_get_dispatch): """Modules should be called in the order they appear in the modules list.""" call_order = [] + config = self._mock_config() + mock_get_config.return_value = config def make_tracker(name): def fn(*args, **kwargs): @@ -248,17 +260,24 @@ def fn(*args, **kwargs): return fn dispatch = {name: make_tracker(name) for name in ["a", "b", "c"]} + mock_get_dispatch.return_value = dispatch ids = [uuid4() for _ in range(3)] run_modules( - dispatch, ["c", "a", "b"], "bl", "rf", ids[0], ids[1], ids[2], MagicMock() + "us", ["c", "a", "b"], "bl", "rf", ids[0], ids[1], ids[2], MagicMock() ) assert call_order == ["c", "a", "b"] - def test_none_modules_runs_all_in_dispatch_key_order(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_none_modules_runs_all_in_dispatch_key_order( + self, mock_get_config, mock_get_dispatch + ): """When modules is None, all dispatch entries run in dict-iteration order.""" call_order = [] + config = self._mock_config() + mock_get_config.return_value = config def make_tracker(name): def fn(*args, **kwargs): @@ -267,29 +286,40 @@ def fn(*args, **kwargs): return fn dispatch = {name: make_tracker(name) for name in ["x", "y", "z"]} + mock_get_dispatch.return_value = dispatch ids = [uuid4() for _ in range(3)] - run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], MagicMock()) + run_modules("us", None, "bl", "rf", ids[0], ids[1], ids[2], MagicMock()) assert call_order == ["x", "y", "z"] - def test_passes_all_args_correctly(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_passes_config_to_module_functions( + self, mock_get_config, mock_get_dispatch + ): mock_fn = MagicMock() + config = self._mock_config() + mock_get_config.return_value = config dispatch = {"test_mod": mock_fn} + mock_get_dispatch.return_value = dispatch session = MagicMock() bl, rf, b_id, r_id, rep_id = "baseline", "reform", uuid4(), uuid4(), uuid4() - run_modules(dispatch, ["test_mod"], bl, rf, b_id, r_id, rep_id, session) + run_modules("us", ["test_mod"], bl, rf, b_id, r_id, rep_id, session) - mock_fn.assert_called_once_with( - bl, rf, b_id, r_id, rep_id, session, country_id="" - ) + mock_fn.assert_called_once_with(bl, rf, b_id, r_id, rep_id, session, config) - def test_duplicate_module_name_runs_twice(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_duplicate_module_name_runs_twice(self, mock_get_config, mock_get_dispatch): dispatch = self._make_mock_dispatch(["a"]) + config = self._mock_config() + mock_get_dispatch.return_value = dispatch + mock_get_config.return_value = config session = MagicMock() ids = [uuid4() for _ in range(3)] - run_modules(dispatch, ["a", "a"], "bl", "rf", ids[0], ids[1], ids[2], session) + run_modules("us", ["a", "a"], "bl", "rf", ids[0], ids[1], ids[2], session) assert dispatch["a"].call_count == 2 From ac28f319c544f9a89a75d0546f764caa6cdd87c9 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 17 Mar 2026 22:52:46 +0100 Subject: [PATCH 2/2] refactor: Replace inline Modal computation with run_modules() Both economy_comparison_uk() and economy_comparison_us() Modal functions now call run_modules() instead of duplicating all computation logic inline. This eliminates ~740 lines of duplicate code. The setup (DB loading, simulation creation, output dataset saving) and teardown (status marking, error handling) remain in the Modal functions, but all computation (deciles, programs, poverty, inequality, budget, geographic impacts) is delegated to the shared module dispatch system. Co-Authored-By: Claude Opus 4.6 --- src/policyengine_api/modal_app.py | 765 +----------------------------- 1 file changed, 24 insertions(+), 741 deletions(-) diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index d0f23f1..3e12bb4 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -1131,15 +1131,7 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: try: # Import models inline from policyengine_api.models import ( - BudgetSummary, - ConstituencyImpact, Dataset, - DecileImpact, - Inequality, - IntraDecileImpact, - LocalAuthorityImpact, - Poverty, - ProgramStatistics, Report, ReportStatus, Simulation, @@ -1179,20 +1171,10 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: # Import policyengine from policyengine.core import Simulation as PESimulation - from policyengine.outputs import DecileImpact as PEDecileImpact - from policyengine.outputs.aggregate import ( - Aggregate as PEAggregate, - ) - from policyengine.outputs.aggregate import ( - AggregateType as PEAggregateType, - ) from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import ( PolicyEngineUKDataset, ) - from policyengine.tax_benefit_models.uk.outputs import ( - ProgrammeStatistics as PEProgrammeStats, - ) pe_model_version = uk_latest @@ -1337,413 +1319,19 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: session.refresh(reform_output_dataset) reform_sim.output_dataset_id = reform_output_dataset.id - # Calculate decile impacts - with logfire.span("calculate_decile_impacts"): - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, - ) - di.run() - - decile_impact = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable=di.income_variable, - entity=di.entity, - decile=di.decile, - quantiles=di.quantiles, - baseline_mean=di.baseline_mean, - reform_mean=di.reform_mean, - absolute_change=di.absolute_change, - relative_change=di.relative_change, - count_better_off=di.count_better_off, - count_worse_off=di.count_worse_off, - count_no_change=di.count_no_change, - ) - session.add(decile_impact) - - # Calculate program statistics - with logfire.span("calculate_program_statistics"): - PEProgrammeStats.model_rebuild( - _types_namespace={"Simulation": PESimulation} - ) - - programmes = { - "income_tax": {"entity": "person", "is_tax": True}, - "national_insurance": {"entity": "person", "is_tax": True}, - "vat": {"entity": "household", "is_tax": True}, - "council_tax": {"entity": "household", "is_tax": True}, - "universal_credit": {"entity": "person", "is_tax": False}, - "child_benefit": {"entity": "person", "is_tax": False}, - "pension_credit": {"entity": "person", "is_tax": False}, - "income_support": {"entity": "person", "is_tax": False}, - "working_tax_credit": {"entity": "person", "is_tax": False}, - "child_tax_credit": {"entity": "person", "is_tax": False}, - } - - for prog_name, prog_info in programmes.items(): - try: - ps = PEProgrammeStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - programme_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - program_stat = ProgramStatistics( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, - ) - session.add(program_stat) - except KeyError: - pass # Variable not in model, skip silently - - # Calculate poverty rates for baseline and reform - from policyengine.outputs.poverty import ( - calculate_uk_poverty_by_age, - calculate_uk_poverty_by_gender, - calculate_uk_poverty_rates, - ) - - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - poverty_results = calculate_uk_poverty_rates(pe_sim) - for pov in poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by age group - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - age_poverty_results = calculate_uk_poverty_by_age(pe_sim) - for pov in age_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by gender - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - gender_poverty_results = calculate_uk_poverty_by_gender(pe_sim) - for pov in gender_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate inequality for baseline and reform - from policyengine.outputs.inequality import ( - calculate_uk_inequality, - ) - - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - ineq = calculate_uk_inequality(pe_sim) - ineq.run() - inequality_record = Inequality( - simulation_id=db_sim.id, - report_id=report.id, - income_variable=ineq.income_variable, - entity=ineq.entity, - gini=ineq.gini, - top_10_share=ineq.top_10_share, - top_1_share=ineq.top_1_share, - bottom_50_share=ineq.bottom_50_share, - ) - session.add(inequality_record) - - # Calculate budget summary aggregates - # UK budget variables — household-level aggregates - uk_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - } - PEAggregate.model_rebuild( - _types_namespace={"Simulation": PESimulation} - ) - for var_name, entity in uk_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - reform_agg.run() - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), - ) - session.add(budget_record) - - # Household count: bypass Aggregate and compute directly - # from raw numpy values. Using Aggregate(SUM) on - # household_weight would compute sum(weight * weight) - # because MicroSeries.sum() applies weights automatically - # — it's unclear whether Aggregate can be used correctly - # for summing the weight column itself. - baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - reform_hh_count = float( - pe_reform_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name="household_count_total", - entity="household", - baseline_total=baseline_hh_count, - reform_total=reform_hh_count, - change=reform_hh_count - baseline_hh_count, - ) - session.add(budget_record) - - # Calculate intra-decile impact - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts as pe_compute_intra_decile, - ) - - intra_decile_results = pe_compute_intra_decile( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - entity="household", - ) - for r in intra_decile_results.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, + # Run all computation modules + from policyengine_api.api.computation_modules import run_modules + + with logfire.span("run_computation_modules"): + run_modules( + country_id="uk", + modules=None, + pe_baseline_sim=pe_baseline_sim, + pe_reform_sim=pe_reform_sim, + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id, report_id=report.id, - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) - - # Calculate constituency impact - from policyengine.outputs.constituency_impact import ( - compute_uk_constituency_impacts, - ) - - try: - from policyengine_core.tools.google_cloud import ( - download as gcs_download, - ) - - weight_matrix_path = gcs_download( - gcs_bucket="policyengine-uk-data-private", - gcs_key="parliamentary_constituency_weights.h5", - ) - constituency_csv_path = gcs_download( - gcs_bucket="policyengine-uk-data-private", - gcs_key="constituencies_2024.csv", - ) - constituency_impact = compute_uk_constituency_impacts( - pe_baseline_sim, - pe_reform_sim, - weight_matrix_path=weight_matrix_path, - constituency_csv_path=constituency_csv_path, - ) - if constituency_impact.constituency_results: - for cr in constituency_impact.constituency_results: - record = ConstituencyImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - constituency_code=cr["constituency_code"], - constituency_name=cr["constituency_name"], - x=cr["x"], - y=cr["y"], - average_household_income_change=cr[ - "average_household_income_change" - ], - relative_household_income_change=cr[ - "relative_household_income_change" - ], - population=cr["population"], - ) - session.add(record) - except FileNotFoundError: - logfire.warning( - "Weight matrix not available, skipping constituency impact" - ) - - # Calculate local authority impact - from policyengine.outputs.local_authority_impact import ( - compute_uk_local_authority_impacts, - ) - - try: - from policyengine_core.tools.google_cloud import ( - download as gcs_download_la, - ) - - la_weight_matrix_path = gcs_download_la( - gcs_bucket="policyengine-uk-data-private", - gcs_key="local_authority_weights.h5", - ) - la_csv_path = gcs_download_la( - gcs_bucket="policyengine-uk-data-private", - gcs_key="local_authorities_2021.csv", - ) - la_impact = compute_uk_local_authority_impacts( - pe_baseline_sim, - pe_reform_sim, - weight_matrix_path=la_weight_matrix_path, - local_authority_csv_path=la_csv_path, - ) - if la_impact.local_authority_results: - for lr in la_impact.local_authority_results: - record = LocalAuthorityImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - local_authority_code=lr["local_authority_code"], - local_authority_name=lr["local_authority_name"], - x=lr["x"], - y=lr["y"], - average_household_income_change=lr[ - "average_household_income_change" - ], - relative_household_income_change=lr[ - "relative_household_income_change" - ], - population=lr["population"], - ) - session.add(record) - except FileNotFoundError: - logfire.warning( - "Weight matrix not available, skipping local authority impact" - ) - - # Calculate wealth decile impact (UK only) - try: - from policyengine.outputs.decile_impact import ( - DecileImpact as PEDecileImpact, - ) - - PEDecileImpact.model_rebuild( - _types_namespace={"Simulation": PESimulation} - ) - for decile_num in range(1, 11): - wealth_di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - decile_variable="household_wealth_decile", - entity="household", - decile=decile_num, - ) - wealth_di.run() - record = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable="household_wealth_decile", - entity="household", - decile=decile_num, - quantiles=10, - baseline_mean=wealth_di.baseline_mean, - reform_mean=wealth_di.reform_mean, - absolute_change=wealth_di.absolute_change, - relative_change=wealth_di.relative_change, - ) - session.add(record) - - # Calculate intra-wealth-decile impact - intra_wealth_results = pe_compute_intra_decile( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - decile_variable="household_wealth_decile", - entity="household", - ) - for r in intra_wealth_results.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - decile_type="wealth", - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) - except KeyError: - logfire.warning( - "household_wealth_decile not available, skipping wealth decile impact" + session=session, ) # Mark simulations and report as completed @@ -1814,14 +1402,7 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: try: # Import models inline from policyengine_api.models import ( - BudgetSummary, - CongressionalDistrictImpact, Dataset, - DecileImpact, - Inequality, - IntraDecileImpact, - Poverty, - ProgramStatistics, Report, ReportStatus, Simulation, @@ -1854,20 +1435,10 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: # Import policyengine from policyengine.core import Simulation as PESimulation - from policyengine.outputs import DecileImpact as PEDecileImpact - from policyengine.outputs.aggregate import ( - Aggregate as PEAggregate, - ) - from policyengine.outputs.aggregate import ( - AggregateType as PEAggregateType, - ) from policyengine.tax_benefit_models.us import us_latest from policyengine.tax_benefit_models.us.datasets import ( PolicyEngineUSDataset, ) - from policyengine.tax_benefit_models.us.outputs import ( - ProgramStatistics as PEProgramStats, - ) pe_model_version = us_latest @@ -2010,308 +1581,20 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: session.refresh(reform_output_dataset) reform_sim.output_dataset_id = reform_output_dataset.id - # Calculate decile impacts - with logfire.span("calculate_decile_impacts"): - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, - income_variable="household_net_income", - ) - di.run() - - decile_impact = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable=di.income_variable, - entity=di.entity, - decile=di.decile, - quantiles=di.quantiles, - baseline_mean=di.baseline_mean, - reform_mean=di.reform_mean, - absolute_change=di.absolute_change, - relative_change=di.relative_change, - count_better_off=di.count_better_off, - count_worse_off=di.count_worse_off, - count_no_change=di.count_no_change, - ) - session.add(decile_impact) - - # Calculate program statistics - with logfire.span("calculate_program_statistics"): - PEProgramStats.model_rebuild( - _types_namespace={"Simulation": PESimulation} - ) - - programs = { - "income_tax": {"entity": "tax_unit", "is_tax": True}, - "employee_payroll_tax": { - "entity": "person", - "is_tax": True, - }, - "snap": {"entity": "spm_unit", "is_tax": False}, - "tanf": {"entity": "spm_unit", "is_tax": False}, - "ssi": {"entity": "spm_unit", "is_tax": False}, - "social_security": {"entity": "person", "is_tax": False}, - } - - for prog_name, prog_info in programs.items(): - try: - ps = PEProgramStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - program_stat = ProgramStatistics( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, - ) - session.add(program_stat) - except KeyError: - pass # Variable not in model, skip silently - - # Calculate poverty rates for baseline and reform - from policyengine.outputs.poverty import ( - calculate_us_poverty_by_age, - calculate_us_poverty_by_gender, - calculate_us_poverty_by_race, - calculate_us_poverty_rates, - ) - - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - poverty_results = calculate_us_poverty_rates(pe_sim) - for pov in poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by age group - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - age_poverty_results = calculate_us_poverty_by_age(pe_sim) - for pov in age_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by gender - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - gender_poverty_results = calculate_us_poverty_by_gender(pe_sim) - for pov in gender_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by race (US only) - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - race_poverty_results = calculate_us_poverty_by_race(pe_sim) - for pov in race_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate inequality for baseline and reform - from policyengine.outputs.inequality import ( - calculate_us_inequality, - ) - - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - ineq = calculate_us_inequality(pe_sim) - ineq.run() - inequality_record = Inequality( - simulation_id=db_sim.id, + # Run all computation modules + from policyengine_api.api.computation_modules import run_modules + + with logfire.span("run_computation_modules"): + run_modules( + country_id="us", + modules=None, + pe_baseline_sim=pe_baseline_sim, + pe_reform_sim=pe_reform_sim, + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id, report_id=report.id, - income_variable=ineq.income_variable, - entity=ineq.entity, - gini=ineq.gini, - top_10_share=ineq.top_10_share, - top_1_share=ineq.top_1_share, - bottom_50_share=ineq.bottom_50_share, - ) - session.add(inequality_record) - - # Calculate budget summary aggregates - # US budget variables — household-level plus state tax - us_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - "household_state_income_tax": "tax_unit", - } - PEAggregate.model_rebuild( - _types_namespace={"Simulation": PESimulation} - ) - for var_name, entity in us_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, + session=session, ) - reform_agg.run() - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), - ) - session.add(budget_record) - - # Household count: bypass Aggregate and compute directly - # from raw numpy values. Using Aggregate(SUM) on - # household_weight would compute sum(weight * weight) - # because MicroSeries.sum() applies weights automatically - # — it's unclear whether Aggregate can be used correctly - # for summing the weight column itself. - baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - reform_hh_count = float( - pe_reform_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name="household_count_total", - entity="household", - baseline_total=baseline_hh_count, - reform_total=reform_hh_count, - change=reform_hh_count - baseline_hh_count, - ) - session.add(budget_record) - - # Calculate intra-decile impact - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts as pe_compute_intra_decile_us, - ) - - intra_decile_results_us = pe_compute_intra_decile_us( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - entity="household", - ) - for r in intra_decile_results_us.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) - - # Calculate congressional district impact - from policyengine.outputs.congressional_district_impact import ( - compute_us_congressional_district_impacts, - ) - - try: - district_impact = compute_us_congressional_district_impacts( - pe_baseline_sim, pe_reform_sim - ) - if district_impact.district_results: - for dr in district_impact.district_results: - record = CongressionalDistrictImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - district_geoid=dr["district_geoid"], - state_fips=dr["state_fips"], - district_number=dr["district_number"], - average_household_income_change=dr[ - "average_household_income_change" - ], - relative_household_income_change=dr[ - "relative_household_income_change" - ], - population=dr["population"], - ) - session.add(record) - except KeyError: - pass # congressional_district_geoid not in dataset # Mark simulations and report as completed baseline_sim.status = SimulationStatus.COMPLETED