diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index c83c950..d36eeba 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -897,10 +897,10 @@ def build_dynamic(dynamic_id): pe_reform_sim.ensure() # Run computation modules - from policyengine_api.api.computation_modules import UK_MODULE_DISPATCH, run_modules + from policyengine_api.api.computation_modules import run_modules run_modules( - dispatch=UK_MODULE_DISPATCH, + country_id="uk", modules=modules, pe_baseline_sim=pe_baseline_sim, pe_reform_sim=pe_reform_sim, @@ -908,7 +908,6 @@ def build_dynamic(dynamic_id): reform_sim_id=reform_sim.id, report_id=report.id, session=session, - country_id="uk", ) # Mark completed @@ -1082,10 +1081,10 @@ def build_dynamic(dynamic_id): pe_reform_sim.ensure() # Run computation modules - from policyengine_api.api.computation_modules import US_MODULE_DISPATCH, run_modules + from policyengine_api.api.computation_modules import run_modules run_modules( - dispatch=US_MODULE_DISPATCH, + country_id="us", modules=modules, pe_baseline_sim=pe_baseline_sim, pe_reform_sim=pe_reform_sim, @@ -1093,7 +1092,6 @@ def build_dynamic(dynamic_id): reform_sim_id=reform_sim.id, report_id=report.id, session=session, - country_id="us", ) # Mark completed diff --git a/src/policyengine_api/api/computation_modules.py b/src/policyengine_api/api/computation_modules.py index 4c6c52d..60931b3 100644 --- a/src/policyengine_api/api/computation_modules.py +++ b/src/policyengine_api/api/computation_modules.py @@ -3,10 +3,10 @@ Each function computes a single module's results and writes DB records. They share a common signature pattern: (pe_baseline_sim, pe_reform_sim, baseline_sim_id, reform_sim_id, - report_id, session, **kwargs) -> None + report_id, session, config) -> None -run_modules() passes country_id as a kwarg. Modules that need it (e.g. -compute_decile_module) accept it explicitly; others accept **_kwargs. +run_modules() resolves the country's dispatch table and passes a +CountryConfig from the policyengine library to each module function. Used by _run_local_economy_comparison_uk/us to run modules selectively. """ @@ -17,6 +17,7 @@ from sqlmodel import Session +from policyengine_api.api.module_registry import MODULE_REGISTRY from policyengine_api.models import ( BudgetSummary, CongressionalDistrictImpact, @@ -30,14 +31,30 @@ ) # --------------------------------------------------------------------------- -# Shared modules (UK + US) +# Country configs — imported from the policyengine library # --------------------------------------------------------------------------- +# Lazy-loaded to avoid importing heavy policyengine modules at import time. +# Callers should use get_country_config() instead. +_COUNTRY_CONFIGS: dict | None = None -DECILE_INCOME_VARIABLE: dict[str, str] = { - "us": "household_net_income", - "uk": "equiv_household_net_income", -} + +def get_country_config(country_id: str): + """Return the CountryConfig for the given country_id ('uk' or 'us').""" + global _COUNTRY_CONFIGS + if _COUNTRY_CONFIGS is None: + from policyengine.outputs.country_config import UK_CONFIG, US_CONFIG + + _COUNTRY_CONFIGS = {"us": US_CONFIG, "uk": UK_CONFIG} + config = _COUNTRY_CONFIGS.get(country_id) + if config is None: + raise ValueError(f"No CountryConfig for country '{country_id}'") + return config + + +# --------------------------------------------------------------------------- +# Config-driven modules (shared UK + US) +# --------------------------------------------------------------------------- def compute_decile_module( @@ -47,26 +64,17 @@ def compute_decile_module( reform_sim_id: UUID, report_id: UUID, session: Session, - country_id: str = "", + config, ) -> None: - """Compute income decile impacts (1-10).""" - from policyengine.outputs import DecileImpact as PEDecileImpact - - if country_id not in DECILE_INCOME_VARIABLE: - raise ValueError( - f"No decile income variable configured for country '{country_id}'" - ) + """Compute income decile impacts (1-10) using config.income_variable.""" + from policyengine.outputs.decile_impact import compute_decile_impacts - income_variable = DECILE_INCOME_VARIABLE[country_id] - - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, - income_variable=income_variable, - ) - di.run() + results = compute_decile_impacts( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + income_variable=config.income_variable, + ) + for di in results.outputs: record = DecileImpact( baseline_simulation_id=baseline_sim_id, reform_simulation_id=reform_sim_id, @@ -93,7 +101,7 @@ def compute_intra_decile_module( reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: """Compute intra-decile income change distribution (5 bands).""" from policyengine.outputs.intra_decile_impact import ( @@ -103,8 +111,7 @@ def compute_intra_decile_module( results = pe_compute_intra_decile( baseline_simulation=pe_baseline_sim, reform_simulation=pe_reform_sim, - income_variable="household_net_income", - entity="household", + income_variable=config.income_variable, ) for r in results.outputs: record = IntraDecileImpact( @@ -121,61 +128,35 @@ def compute_intra_decile_module( session.add(record) -# --------------------------------------------------------------------------- -# UK-specific modules -# --------------------------------------------------------------------------- - - -def compute_program_statistics_module_uk( +def compute_program_statistics_module( pe_baseline_sim, pe_reform_sim, baseline_sim_id: UUID, reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: - """Compute UK programme statistics.""" - from policyengine.core import Simulation as PESimulation - from policyengine.tax_benefit_models.uk.outputs import ( - ProgrammeStatistics as PEProgrammeStats, - ) + """Compute program statistics using config.programs.""" + from policyengine.outputs.program_statistics import compute_program_statistics - PEProgrammeStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) - programmes = { - "income_tax": {"entity": "person", "is_tax": True}, - "national_insurance": {"entity": "person", "is_tax": True}, - "vat": {"entity": "household", "is_tax": True}, - "council_tax": {"entity": "household", "is_tax": True}, - "universal_credit": {"entity": "person", "is_tax": False}, - "child_benefit": {"entity": "person", "is_tax": False}, - "pension_credit": {"entity": "person", "is_tax": False}, - "income_support": {"entity": "person", "is_tax": False}, - "working_tax_credit": {"entity": "person", "is_tax": False}, - "child_tax_credit": {"entity": "person", "is_tax": False}, - } - for prog_name, prog_info in programmes.items(): - try: - ps = PEProgrammeStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - programme_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - except KeyError: - import logfire - - logfire.warning(f"Program variable not found, skipping: {prog_name}") - continue + results = compute_program_statistics( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + programs=config.programs, + ) + for ps in results.outputs: + # Detect the name field (US: program_name, UK: programme_name) + prog_name = getattr(ps, "program_name", None) or getattr( + ps, "programme_name", None + ) record = ProgramStatistics( baseline_simulation_id=baseline_sim_id, reform_simulation_id=reform_sim_id, report_id=report_id, program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], + entity=ps.entity, + is_tax=ps.is_tax, baseline_total=ps.baseline_total, reform_total=ps.reform_total, change=ps.change, @@ -187,32 +168,49 @@ def compute_program_statistics_module_uk( session.add(record) -def compute_poverty_module_uk( +def compute_poverty_module( pe_baseline_sim, pe_reform_sim, baseline_sim_id: UUID, reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: - """Compute UK poverty rates (overall, by age, by gender).""" - from policyengine.outputs.poverty import ( - calculate_uk_poverty_by_age, - calculate_uk_poverty_by_gender, - calculate_uk_poverty_rates, - ) + """Compute poverty rates using config.country_id to select calculators.""" + if config.country_id == "uk": + from policyengine.outputs.poverty import ( + calculate_uk_poverty_by_age, + calculate_uk_poverty_by_gender, + calculate_uk_poverty_rates, + ) + + calculators = [ + calculate_uk_poverty_rates, + calculate_uk_poverty_by_age, + calculate_uk_poverty_by_gender, + ] + else: + from policyengine.outputs.poverty import ( + calculate_us_poverty_by_age, + calculate_us_poverty_by_gender, + calculate_us_poverty_by_race, + calculate_us_poverty_rates, + ) + + calculators = [ + calculate_us_poverty_rates, + calculate_us_poverty_by_age, + calculate_us_poverty_by_gender, + calculate_us_poverty_by_race, + ] sim_pairs = [ (pe_baseline_sim, baseline_sim_id), (pe_reform_sim, reform_sim_id), ] - for calculator in [ - calculate_uk_poverty_rates, - calculate_uk_poverty_by_age, - calculate_uk_poverty_by_gender, - ]: + for calculator in calculators: for pe_sim, db_sim_id in sim_pairs: results = calculator(pe_sim) for pov in results.outputs: @@ -229,23 +227,26 @@ def compute_poverty_module_uk( session.add(record) -def compute_inequality_module_uk( +def compute_inequality_module( pe_baseline_sim, pe_reform_sim, baseline_sim_id: UUID, reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: - """Compute UK inequality metrics.""" - from policyengine.outputs.inequality import calculate_uk_inequality + """Compute inequality metrics using config.country_id to select calculator.""" + if config.country_id == "uk": + from policyengine.outputs.inequality import calculate_uk_inequality as calc_fn + else: + from policyengine.outputs.inequality import calculate_us_inequality as calc_fn for pe_sim, db_sim_id in [ (pe_baseline_sim, baseline_sim_id), (pe_reform_sim, reform_sim_id), ]: - ineq = calculate_uk_inequality(pe_sim) + ineq = calc_fn(pe_sim) ineq.run() record = Inequality( simulation_id=db_sim_id, @@ -260,51 +261,33 @@ def compute_inequality_module_uk( session.add(record) -def compute_budget_summary_module_uk( +def compute_budget_summary_module( pe_baseline_sim, pe_reform_sim, baseline_sim_id: UUID, reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: - """Compute UK budget summary aggregates.""" - from policyengine.core import Simulation as PESimulation - from policyengine.outputs.aggregate import Aggregate as PEAggregate - from policyengine.outputs.aggregate import AggregateType as PEAggregateType - - PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) - - uk_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - } - for var_name, entity in uk_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - reform_agg.run() + """Compute budget summary using config.budget_variables.""" + from policyengine.outputs.budget_summary import compute_budget_summary + + results = compute_budget_summary( + baseline_simulation=pe_baseline_sim, + reform_simulation=pe_reform_sim, + variables=config.budget_variables, + ) + for item in results.outputs: record = BudgetSummary( baseline_simulation_id=baseline_sim_id, reform_simulation_id=reform_sim_id, report_id=report_id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), + variable_name=item.variable_name, + entity=item.entity, + baseline_total=item.baseline_total, + reform_total=item.reform_total, + change=item.change, ) session.add(record) @@ -328,6 +311,11 @@ def compute_budget_summary_module_uk( session.add(record) +# --------------------------------------------------------------------------- +# Geographic / country-specific modules (unchanged structure) +# --------------------------------------------------------------------------- + + def compute_constituency_module( pe_baseline_sim, pe_reform_sim, @@ -335,7 +323,7 @@ def compute_constituency_module( reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: """Compute UK parliamentary constituency impact.""" from policyengine.outputs.constituency_impact import ( @@ -391,7 +379,7 @@ def compute_local_authority_module( reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: """Compute UK local authority impact.""" from policyengine.outputs.local_authority_impact import ( @@ -447,7 +435,7 @@ def compute_wealth_decile_module( reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: """Compute UK wealth decile impact and intra-wealth-decile breakdown.""" from policyengine.core import Simulation as PESimulation @@ -513,212 +501,6 @@ def compute_wealth_decile_module( ) -# --------------------------------------------------------------------------- -# US-specific modules -# --------------------------------------------------------------------------- - - -def compute_program_statistics_module_us( - pe_baseline_sim, - pe_reform_sim, - baseline_sim_id: UUID, - reform_sim_id: UUID, - report_id: UUID, - session: Session, - **_kwargs, -) -> None: - """Compute US program statistics.""" - from policyengine.core import Simulation as PESimulation - from policyengine.tax_benefit_models.us.outputs import ( - ProgramStatistics as PEProgramStats, - ) - - PEProgramStats.model_rebuild(_types_namespace={"Simulation": PESimulation}) - programs = { - "income_tax": {"entity": "tax_unit", "is_tax": True}, - "employee_payroll_tax": {"entity": "person", "is_tax": True}, - "snap": {"entity": "spm_unit", "is_tax": False}, - "tanf": {"entity": "spm_unit", "is_tax": False}, - "ssi": {"entity": "spm_unit", "is_tax": False}, - "social_security": {"entity": "person", "is_tax": False}, - } - for prog_name, prog_info in programs.items(): - try: - ps = PEProgramStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - except KeyError: - import logfire - - logfire.warning(f"Program variable not found, skipping: {prog_name}") - continue - record = ProgramStatistics( - baseline_simulation_id=baseline_sim_id, - reform_simulation_id=reform_sim_id, - report_id=report_id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, - ) - session.add(record) - - -def compute_poverty_module_us( - pe_baseline_sim, - pe_reform_sim, - baseline_sim_id: UUID, - reform_sim_id: UUID, - report_id: UUID, - session: Session, - **_kwargs, -) -> None: - """Compute US poverty rates (overall, by age, gender, race).""" - from policyengine.outputs.poverty import ( - calculate_us_poverty_by_age, - calculate_us_poverty_by_gender, - calculate_us_poverty_by_race, - calculate_us_poverty_rates, - ) - - sim_pairs = [ - (pe_baseline_sim, baseline_sim_id), - (pe_reform_sim, reform_sim_id), - ] - - for calculator in [ - calculate_us_poverty_rates, - calculate_us_poverty_by_age, - calculate_us_poverty_by_gender, - calculate_us_poverty_by_race, - ]: - for pe_sim, db_sim_id in sim_pairs: - results = calculator(pe_sim) - for pov in results.outputs: - record = Poverty( - simulation_id=db_sim_id, - report_id=report_id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(record) - - -def compute_inequality_module_us( - pe_baseline_sim, - pe_reform_sim, - baseline_sim_id: UUID, - reform_sim_id: UUID, - report_id: UUID, - session: Session, - **_kwargs, -) -> None: - """Compute US inequality metrics.""" - from policyengine.outputs.inequality import calculate_us_inequality - - for pe_sim, db_sim_id in [ - (pe_baseline_sim, baseline_sim_id), - (pe_reform_sim, reform_sim_id), - ]: - ineq = calculate_us_inequality(pe_sim) - ineq.run() - record = Inequality( - simulation_id=db_sim_id, - report_id=report_id, - income_variable=ineq.income_variable, - entity=ineq.entity, - gini=ineq.gini, - top_10_share=ineq.top_10_share, - top_1_share=ineq.top_1_share, - bottom_50_share=ineq.bottom_50_share, - ) - session.add(record) - - -def compute_budget_summary_module_us( - pe_baseline_sim, - pe_reform_sim, - baseline_sim_id: UUID, - reform_sim_id: UUID, - report_id: UUID, - session: Session, - **_kwargs, -) -> None: - """Compute US budget summary aggregates.""" - from policyengine.core import Simulation as PESimulation - from policyengine.outputs.aggregate import Aggregate as PEAggregate - from policyengine.outputs.aggregate import AggregateType as PEAggregateType - - PEAggregate.model_rebuild(_types_namespace={"Simulation": PESimulation}) - - us_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - "household_state_income_tax": "tax_unit", - } - for var_name, entity in us_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - reform_agg.run() - record = BudgetSummary( - baseline_simulation_id=baseline_sim_id, - reform_simulation_id=reform_sim_id, - report_id=report_id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), - ) - session.add(record) - - # Household count: raw sum of weights - baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household["household_weight"].values.sum() - ) - reform_hh_count = float( - pe_reform_sim.output_dataset.data.household["household_weight"].values.sum() - ) - record = BudgetSummary( - baseline_simulation_id=baseline_sim_id, - reform_simulation_id=reform_sim_id, - report_id=report_id, - variable_name="household_count_total", - entity="household", - baseline_total=baseline_hh_count, - reform_total=reform_hh_count, - change=reform_hh_count - baseline_hh_count, - ) - session.add(record) - - def compute_congressional_district_module( pe_baseline_sim, pe_reform_sim, @@ -726,7 +508,7 @@ def compute_congressional_district_module( reform_sim_id: UUID, report_id: UUID, session: Session, - **_kwargs, + config, ) -> None: """Compute US congressional district impact.""" from policyengine.outputs.congressional_district_impact import ( @@ -761,37 +543,31 @@ def compute_congressional_district_module( # --------------------------------------------------------------------------- -# Dispatch tables: module name -> computation function +# Single dispatch table: module name -> computation function # --------------------------------------------------------------------------- -# Type alias for module computation functions -ModuleFunction = type(compute_decile_module) - -UK_MODULE_DISPATCH: dict[str, ModuleFunction] = { +MODULE_DISPATCH: dict[str, type(compute_decile_module)] = { "decile": compute_decile_module, - "program_statistics": compute_program_statistics_module_uk, - "poverty": compute_poverty_module_uk, - "inequality": compute_inequality_module_uk, - "budget_summary": compute_budget_summary_module_uk, "intra_decile": compute_intra_decile_module, + "program_statistics": compute_program_statistics_module, + "poverty": compute_poverty_module, + "inequality": compute_inequality_module, + "budget_summary": compute_budget_summary_module, "constituency": compute_constituency_module, "local_authority": compute_local_authority_module, "wealth_decile": compute_wealth_decile_module, -} - -US_MODULE_DISPATCH: dict[str, ModuleFunction] = { - "decile": compute_decile_module, - "program_statistics": compute_program_statistics_module_us, - "poverty": compute_poverty_module_us, - "inequality": compute_inequality_module_us, - "budget_summary": compute_budget_summary_module_us, - "intra_decile": compute_intra_decile_module, "congressional_district": compute_congressional_district_module, } +def get_dispatch_for_country(country_id: str) -> dict: + """Return the subset of MODULE_DISPATCH applicable to a country.""" + available = {m.name for m in MODULE_REGISTRY.values() if country_id in m.countries} + return {k: v for k, v in MODULE_DISPATCH.items() if k in available} + + def run_modules( - dispatch: dict[str, ModuleFunction], + country_id: str, modules: list[str] | None, pe_baseline_sim, pe_reform_sim, @@ -799,9 +575,13 @@ def run_modules( reform_sim_id: UUID, report_id: UUID, session: Session, - country_id: str = "", ) -> None: - """Run the requested modules (or all if modules is None).""" + """Run the requested modules (or all applicable) for a country. + + Resolves the country's dispatch table and CountryConfig automatically. + """ + dispatch = get_dispatch_for_country(country_id) + config = get_country_config(country_id) to_run = modules if modules is not None else list(dispatch.keys()) for mod_name in to_run: fn = dispatch.get(mod_name) @@ -813,5 +593,5 @@ def run_modules( reform_sim_id, report_id, session, - country_id=country_id, + config, ) diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index d0f23f1..3e12bb4 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -1131,15 +1131,7 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: try: # Import models inline from policyengine_api.models import ( - BudgetSummary, - ConstituencyImpact, Dataset, - DecileImpact, - Inequality, - IntraDecileImpact, - LocalAuthorityImpact, - Poverty, - ProgramStatistics, Report, ReportStatus, Simulation, @@ -1179,20 +1171,10 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: # Import policyengine from policyengine.core import Simulation as PESimulation - from policyengine.outputs import DecileImpact as PEDecileImpact - from policyengine.outputs.aggregate import ( - Aggregate as PEAggregate, - ) - from policyengine.outputs.aggregate import ( - AggregateType as PEAggregateType, - ) from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import ( PolicyEngineUKDataset, ) - from policyengine.tax_benefit_models.uk.outputs import ( - ProgrammeStatistics as PEProgrammeStats, - ) pe_model_version = uk_latest @@ -1337,413 +1319,19 @@ def economy_comparison_uk(job_id: str, traceparent: str | None = None) -> None: session.refresh(reform_output_dataset) reform_sim.output_dataset_id = reform_output_dataset.id - # Calculate decile impacts - with logfire.span("calculate_decile_impacts"): - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, - ) - di.run() - - decile_impact = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable=di.income_variable, - entity=di.entity, - decile=di.decile, - quantiles=di.quantiles, - baseline_mean=di.baseline_mean, - reform_mean=di.reform_mean, - absolute_change=di.absolute_change, - relative_change=di.relative_change, - count_better_off=di.count_better_off, - count_worse_off=di.count_worse_off, - count_no_change=di.count_no_change, - ) - session.add(decile_impact) - - # Calculate program statistics - with logfire.span("calculate_program_statistics"): - PEProgrammeStats.model_rebuild( - _types_namespace={"Simulation": PESimulation} - ) - - programmes = { - "income_tax": {"entity": "person", "is_tax": True}, - "national_insurance": {"entity": "person", "is_tax": True}, - "vat": {"entity": "household", "is_tax": True}, - "council_tax": {"entity": "household", "is_tax": True}, - "universal_credit": {"entity": "person", "is_tax": False}, - "child_benefit": {"entity": "person", "is_tax": False}, - "pension_credit": {"entity": "person", "is_tax": False}, - "income_support": {"entity": "person", "is_tax": False}, - "working_tax_credit": {"entity": "person", "is_tax": False}, - "child_tax_credit": {"entity": "person", "is_tax": False}, - } - - for prog_name, prog_info in programmes.items(): - try: - ps = PEProgrammeStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - programme_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - program_stat = ProgramStatistics( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, - ) - session.add(program_stat) - except KeyError: - pass # Variable not in model, skip silently - - # Calculate poverty rates for baseline and reform - from policyengine.outputs.poverty import ( - calculate_uk_poverty_by_age, - calculate_uk_poverty_by_gender, - calculate_uk_poverty_rates, - ) - - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - poverty_results = calculate_uk_poverty_rates(pe_sim) - for pov in poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by age group - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - age_poverty_results = calculate_uk_poverty_by_age(pe_sim) - for pov in age_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by gender - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - gender_poverty_results = calculate_uk_poverty_by_gender(pe_sim) - for pov in gender_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate inequality for baseline and reform - from policyengine.outputs.inequality import ( - calculate_uk_inequality, - ) - - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - ineq = calculate_uk_inequality(pe_sim) - ineq.run() - inequality_record = Inequality( - simulation_id=db_sim.id, - report_id=report.id, - income_variable=ineq.income_variable, - entity=ineq.entity, - gini=ineq.gini, - top_10_share=ineq.top_10_share, - top_1_share=ineq.top_1_share, - bottom_50_share=ineq.bottom_50_share, - ) - session.add(inequality_record) - - # Calculate budget summary aggregates - # UK budget variables — household-level aggregates - uk_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - } - PEAggregate.model_rebuild( - _types_namespace={"Simulation": PESimulation} - ) - for var_name, entity in uk_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - reform_agg.run() - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), - ) - session.add(budget_record) - - # Household count: bypass Aggregate and compute directly - # from raw numpy values. Using Aggregate(SUM) on - # household_weight would compute sum(weight * weight) - # because MicroSeries.sum() applies weights automatically - # — it's unclear whether Aggregate can be used correctly - # for summing the weight column itself. - baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - reform_hh_count = float( - pe_reform_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name="household_count_total", - entity="household", - baseline_total=baseline_hh_count, - reform_total=reform_hh_count, - change=reform_hh_count - baseline_hh_count, - ) - session.add(budget_record) - - # Calculate intra-decile impact - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts as pe_compute_intra_decile, - ) - - intra_decile_results = pe_compute_intra_decile( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - entity="household", - ) - for r in intra_decile_results.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, + # Run all computation modules + from policyengine_api.api.computation_modules import run_modules + + with logfire.span("run_computation_modules"): + run_modules( + country_id="uk", + modules=None, + pe_baseline_sim=pe_baseline_sim, + pe_reform_sim=pe_reform_sim, + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id, report_id=report.id, - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) - - # Calculate constituency impact - from policyengine.outputs.constituency_impact import ( - compute_uk_constituency_impacts, - ) - - try: - from policyengine_core.tools.google_cloud import ( - download as gcs_download, - ) - - weight_matrix_path = gcs_download( - gcs_bucket="policyengine-uk-data-private", - gcs_key="parliamentary_constituency_weights.h5", - ) - constituency_csv_path = gcs_download( - gcs_bucket="policyengine-uk-data-private", - gcs_key="constituencies_2024.csv", - ) - constituency_impact = compute_uk_constituency_impacts( - pe_baseline_sim, - pe_reform_sim, - weight_matrix_path=weight_matrix_path, - constituency_csv_path=constituency_csv_path, - ) - if constituency_impact.constituency_results: - for cr in constituency_impact.constituency_results: - record = ConstituencyImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - constituency_code=cr["constituency_code"], - constituency_name=cr["constituency_name"], - x=cr["x"], - y=cr["y"], - average_household_income_change=cr[ - "average_household_income_change" - ], - relative_household_income_change=cr[ - "relative_household_income_change" - ], - population=cr["population"], - ) - session.add(record) - except FileNotFoundError: - logfire.warning( - "Weight matrix not available, skipping constituency impact" - ) - - # Calculate local authority impact - from policyengine.outputs.local_authority_impact import ( - compute_uk_local_authority_impacts, - ) - - try: - from policyengine_core.tools.google_cloud import ( - download as gcs_download_la, - ) - - la_weight_matrix_path = gcs_download_la( - gcs_bucket="policyengine-uk-data-private", - gcs_key="local_authority_weights.h5", - ) - la_csv_path = gcs_download_la( - gcs_bucket="policyengine-uk-data-private", - gcs_key="local_authorities_2021.csv", - ) - la_impact = compute_uk_local_authority_impacts( - pe_baseline_sim, - pe_reform_sim, - weight_matrix_path=la_weight_matrix_path, - local_authority_csv_path=la_csv_path, - ) - if la_impact.local_authority_results: - for lr in la_impact.local_authority_results: - record = LocalAuthorityImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - local_authority_code=lr["local_authority_code"], - local_authority_name=lr["local_authority_name"], - x=lr["x"], - y=lr["y"], - average_household_income_change=lr[ - "average_household_income_change" - ], - relative_household_income_change=lr[ - "relative_household_income_change" - ], - population=lr["population"], - ) - session.add(record) - except FileNotFoundError: - logfire.warning( - "Weight matrix not available, skipping local authority impact" - ) - - # Calculate wealth decile impact (UK only) - try: - from policyengine.outputs.decile_impact import ( - DecileImpact as PEDecileImpact, - ) - - PEDecileImpact.model_rebuild( - _types_namespace={"Simulation": PESimulation} - ) - for decile_num in range(1, 11): - wealth_di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - decile_variable="household_wealth_decile", - entity="household", - decile=decile_num, - ) - wealth_di.run() - record = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable="household_wealth_decile", - entity="household", - decile=decile_num, - quantiles=10, - baseline_mean=wealth_di.baseline_mean, - reform_mean=wealth_di.reform_mean, - absolute_change=wealth_di.absolute_change, - relative_change=wealth_di.relative_change, - ) - session.add(record) - - # Calculate intra-wealth-decile impact - intra_wealth_results = pe_compute_intra_decile( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - decile_variable="household_wealth_decile", - entity="household", - ) - for r in intra_wealth_results.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - decile_type="wealth", - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) - except KeyError: - logfire.warning( - "household_wealth_decile not available, skipping wealth decile impact" + session=session, ) # Mark simulations and report as completed @@ -1814,14 +1402,7 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: try: # Import models inline from policyengine_api.models import ( - BudgetSummary, - CongressionalDistrictImpact, Dataset, - DecileImpact, - Inequality, - IntraDecileImpact, - Poverty, - ProgramStatistics, Report, ReportStatus, Simulation, @@ -1854,20 +1435,10 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: # Import policyengine from policyengine.core import Simulation as PESimulation - from policyengine.outputs import DecileImpact as PEDecileImpact - from policyengine.outputs.aggregate import ( - Aggregate as PEAggregate, - ) - from policyengine.outputs.aggregate import ( - AggregateType as PEAggregateType, - ) from policyengine.tax_benefit_models.us import us_latest from policyengine.tax_benefit_models.us.datasets import ( PolicyEngineUSDataset, ) - from policyengine.tax_benefit_models.us.outputs import ( - ProgramStatistics as PEProgramStats, - ) pe_model_version = us_latest @@ -2010,308 +1581,20 @@ def economy_comparison_us(job_id: str, traceparent: str | None = None) -> None: session.refresh(reform_output_dataset) reform_sim.output_dataset_id = reform_output_dataset.id - # Calculate decile impacts - with logfire.span("calculate_decile_impacts"): - for decile_num in range(1, 11): - di = PEDecileImpact( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - decile=decile_num, - income_variable="household_net_income", - ) - di.run() - - decile_impact = DecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - income_variable=di.income_variable, - entity=di.entity, - decile=di.decile, - quantiles=di.quantiles, - baseline_mean=di.baseline_mean, - reform_mean=di.reform_mean, - absolute_change=di.absolute_change, - relative_change=di.relative_change, - count_better_off=di.count_better_off, - count_worse_off=di.count_worse_off, - count_no_change=di.count_no_change, - ) - session.add(decile_impact) - - # Calculate program statistics - with logfire.span("calculate_program_statistics"): - PEProgramStats.model_rebuild( - _types_namespace={"Simulation": PESimulation} - ) - - programs = { - "income_tax": {"entity": "tax_unit", "is_tax": True}, - "employee_payroll_tax": { - "entity": "person", - "is_tax": True, - }, - "snap": {"entity": "spm_unit", "is_tax": False}, - "tanf": {"entity": "spm_unit", "is_tax": False}, - "ssi": {"entity": "spm_unit", "is_tax": False}, - "social_security": {"entity": "person", "is_tax": False}, - } - - for prog_name, prog_info in programs.items(): - try: - ps = PEProgramStats( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - ) - ps.run() - program_stat = ProgramStatistics( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - program_name=prog_name, - entity=prog_info["entity"], - is_tax=prog_info["is_tax"], - baseline_total=ps.baseline_total, - reform_total=ps.reform_total, - change=ps.change, - baseline_count=ps.baseline_count, - reform_count=ps.reform_count, - winners=ps.winners, - losers=ps.losers, - ) - session.add(program_stat) - except KeyError: - pass # Variable not in model, skip silently - - # Calculate poverty rates for baseline and reform - from policyengine.outputs.poverty import ( - calculate_us_poverty_by_age, - calculate_us_poverty_by_gender, - calculate_us_poverty_by_race, - calculate_us_poverty_rates, - ) - - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - poverty_results = calculate_us_poverty_rates(pe_sim) - for pov in poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by age group - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - age_poverty_results = calculate_us_poverty_by_age(pe_sim) - for pov in age_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by gender - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - gender_poverty_results = calculate_us_poverty_by_gender(pe_sim) - for pov in gender_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate poverty rates by race (US only) - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - race_poverty_results = calculate_us_poverty_by_race(pe_sim) - for pov in race_poverty_results.outputs: - poverty_record = Poverty( - simulation_id=db_sim.id, - report_id=report.id, - poverty_type=pov.poverty_type, - entity=pov.entity, - filter_variable=pov.filter_variable, - headcount=pov.headcount, - total_population=pov.total_population, - rate=pov.rate, - ) - session.add(poverty_record) - - # Calculate inequality for baseline and reform - from policyengine.outputs.inequality import ( - calculate_us_inequality, - ) - - for pe_sim, db_sim in [ - (pe_baseline_sim, baseline_sim), - (pe_reform_sim, reform_sim), - ]: - ineq = calculate_us_inequality(pe_sim) - ineq.run() - inequality_record = Inequality( - simulation_id=db_sim.id, + # Run all computation modules + from policyengine_api.api.computation_modules import run_modules + + with logfire.span("run_computation_modules"): + run_modules( + country_id="us", + modules=None, + pe_baseline_sim=pe_baseline_sim, + pe_reform_sim=pe_reform_sim, + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id, report_id=report.id, - income_variable=ineq.income_variable, - entity=ineq.entity, - gini=ineq.gini, - top_10_share=ineq.top_10_share, - top_1_share=ineq.top_1_share, - bottom_50_share=ineq.bottom_50_share, - ) - session.add(inequality_record) - - # Calculate budget summary aggregates - # US budget variables — household-level plus state tax - us_budget_variables = { - "household_tax": "household", - "household_benefits": "household", - "household_net_income": "household", - "household_state_income_tax": "tax_unit", - } - PEAggregate.model_rebuild( - _types_namespace={"Simulation": PESimulation} - ) - for var_name, entity in us_budget_variables.items(): - baseline_agg = PEAggregate( - simulation=pe_baseline_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, - ) - baseline_agg.run() - reform_agg = PEAggregate( - simulation=pe_reform_sim, - variable=var_name, - aggregate_type=PEAggregateType.SUM, - entity=entity, + session=session, ) - reform_agg.run() - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name=var_name, - entity=entity, - baseline_total=float(baseline_agg.result), - reform_total=float(reform_agg.result), - change=float(reform_agg.result - baseline_agg.result), - ) - session.add(budget_record) - - # Household count: bypass Aggregate and compute directly - # from raw numpy values. Using Aggregate(SUM) on - # household_weight would compute sum(weight * weight) - # because MicroSeries.sum() applies weights automatically - # — it's unclear whether Aggregate can be used correctly - # for summing the weight column itself. - baseline_hh_count = float( - pe_baseline_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - reform_hh_count = float( - pe_reform_sim.output_dataset.data.household[ - "household_weight" - ].values.sum() - ) - budget_record = BudgetSummary( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - variable_name="household_count_total", - entity="household", - baseline_total=baseline_hh_count, - reform_total=reform_hh_count, - change=reform_hh_count - baseline_hh_count, - ) - session.add(budget_record) - - # Calculate intra-decile impact - from policyengine.outputs.intra_decile_impact import ( - compute_intra_decile_impacts as pe_compute_intra_decile_us, - ) - - intra_decile_results_us = pe_compute_intra_decile_us( - baseline_simulation=pe_baseline_sim, - reform_simulation=pe_reform_sim, - income_variable="household_net_income", - entity="household", - ) - for r in intra_decile_results_us.outputs: - record = IntraDecileImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - decile=r.decile, - lose_more_than_5pct=r.lose_more_than_5pct, - lose_less_than_5pct=r.lose_less_than_5pct, - no_change=r.no_change, - gain_less_than_5pct=r.gain_less_than_5pct, - gain_more_than_5pct=r.gain_more_than_5pct, - ) - session.add(record) - - # Calculate congressional district impact - from policyengine.outputs.congressional_district_impact import ( - compute_us_congressional_district_impacts, - ) - - try: - district_impact = compute_us_congressional_district_impacts( - pe_baseline_sim, pe_reform_sim - ) - if district_impact.district_results: - for dr in district_impact.district_results: - record = CongressionalDistrictImpact( - baseline_simulation_id=baseline_sim.id, - reform_simulation_id=reform_sim.id, - report_id=report.id, - district_geoid=dr["district_geoid"], - state_fips=dr["state_fips"], - district_number=dr["district_number"], - average_household_income_change=dr[ - "average_household_income_change" - ], - relative_household_income_change=dr[ - "relative_household_income_change" - ], - population=dr["population"], - ) - session.add(record) - except KeyError: - pass # congressional_district_geoid not in dataset # Mark simulations and report as completed baseline_sim.status = SimulationStatus.COMPLETED diff --git a/tests/test_computation_modules.py b/tests/test_computation_modules.py index b316296..d946043 100644 --- a/tests/test_computation_modules.py +++ b/tests/test_computation_modules.py @@ -1,156 +1,144 @@ """Tests for the composable computation module dispatch system.""" import inspect -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from uuid import uuid4 from policyengine_api.api import computation_modules as cm from policyengine_api.api.computation_modules import ( - UK_MODULE_DISPATCH, - US_MODULE_DISPATCH, + MODULE_DISPATCH, + get_dispatch_for_country, run_modules, ) from policyengine_api.api.module_registry import MODULE_REGISTRY -class TestDispatchTables: - """Tests for UK_MODULE_DISPATCH and US_MODULE_DISPATCH.""" +class TestModuleDispatch: + """Tests for the unified MODULE_DISPATCH table.""" - def test_uk_dispatch_keys_match_registry(self): - """Every UK dispatch key should be a valid module in the registry.""" - for key in UK_MODULE_DISPATCH: - assert key in MODULE_REGISTRY, f"UK dispatch key {key!r} not in registry" - - def test_us_dispatch_keys_match_registry(self): - """Every US dispatch key should be a valid module in the registry.""" - for key in US_MODULE_DISPATCH: - assert key in MODULE_REGISTRY, f"US dispatch key {key!r} not in registry" + def test_dispatch_keys_match_registry(self): + """Every dispatch key should be a valid module in the registry.""" + for key in MODULE_DISPATCH: + assert key in MODULE_REGISTRY, f"Dispatch key {key!r} not in registry" - def test_uk_dispatch_covers_uk_modules(self): - """UK dispatch should have an entry for every UK-applicable module.""" - uk_module_names = { - name for name, mod in MODULE_REGISTRY.items() if "uk" in mod.countries - } - assert set(UK_MODULE_DISPATCH.keys()) == uk_module_names - - def test_us_dispatch_covers_us_modules(self): - """US dispatch should have an entry for every US-applicable module.""" - us_module_names = { - name for name, mod in MODULE_REGISTRY.items() if "us" in mod.countries - } - assert set(US_MODULE_DISPATCH.keys()) == us_module_names + def test_dispatch_covers_all_registry_modules(self): + """Dispatch should have an entry for every module in the registry.""" + assert set(MODULE_DISPATCH.keys()) == set(MODULE_REGISTRY.keys()) def test_all_dispatch_values_are_callable(self): - for fn in UK_MODULE_DISPATCH.values(): - assert callable(fn) - for fn in US_MODULE_DISPATCH.values(): + for fn in MODULE_DISPATCH.values(): assert callable(fn) - def test_uk_dispatch_has_9_entries(self): - assert len(UK_MODULE_DISPATCH) == 9 - - def test_us_dispatch_has_7_entries(self): - assert len(US_MODULE_DISPATCH) == 7 - - -class TestSharedModuleFunctions: - """Tests that shared modules reference the same function objects.""" - - def test_decile_function_shared_between_uk_and_us(self): - assert UK_MODULE_DISPATCH["decile"] is US_MODULE_DISPATCH["decile"] - assert UK_MODULE_DISPATCH["decile"] is cm.compute_decile_module - - def test_intra_decile_function_shared_between_uk_and_us(self): - assert UK_MODULE_DISPATCH["intra_decile"] is US_MODULE_DISPATCH["intra_decile"] - assert UK_MODULE_DISPATCH["intra_decile"] is cm.compute_intra_decile_module - - -class TestCountrySpecificFunctions: - """Tests that UK/US specific modules use the correct country-specific functions.""" - - def test_uk_program_statistics(self): - assert ( - UK_MODULE_DISPATCH["program_statistics"] - is cm.compute_program_statistics_module_uk - ) - - def test_us_program_statistics(self): - assert ( - US_MODULE_DISPATCH["program_statistics"] - is cm.compute_program_statistics_module_us - ) + def test_dispatch_has_10_entries(self): + assert len(MODULE_DISPATCH) == 10 - def test_uk_poverty(self): - assert UK_MODULE_DISPATCH["poverty"] is cm.compute_poverty_module_uk - def test_us_poverty(self): - assert US_MODULE_DISPATCH["poverty"] is cm.compute_poverty_module_us +class TestGetDispatchForCountry: + """Tests for country-filtered dispatch tables.""" - def test_uk_inequality(self): - assert UK_MODULE_DISPATCH["inequality"] is cm.compute_inequality_module_uk + def test_uk_dispatch_has_9_entries(self): + uk = get_dispatch_for_country("uk") + assert len(uk) == 9 - def test_us_inequality(self): - assert US_MODULE_DISPATCH["inequality"] is cm.compute_inequality_module_us + def test_us_dispatch_has_7_entries(self): + us = get_dispatch_for_country("us") + assert len(us) == 7 - def test_uk_budget_summary(self): - assert ( - UK_MODULE_DISPATCH["budget_summary"] is cm.compute_budget_summary_module_uk - ) + def test_uk_dispatch_keys_match_registry(self): + uk = get_dispatch_for_country("uk") + uk_module_names = { + name for name, mod in MODULE_REGISTRY.items() if "uk" in mod.countries + } + assert set(uk.keys()) == uk_module_names - def test_us_budget_summary(self): - assert ( - US_MODULE_DISPATCH["budget_summary"] is cm.compute_budget_summary_module_us - ) + def test_us_dispatch_keys_match_registry(self): + us = get_dispatch_for_country("us") + us_module_names = { + name for name, mod in MODULE_REGISTRY.items() if "us" in mod.countries + } + assert set(us.keys()) == us_module_names def test_constituency_is_uk_only(self): - assert UK_MODULE_DISPATCH["constituency"] is cm.compute_constituency_module - assert "constituency" not in US_MODULE_DISPATCH + assert "constituency" in get_dispatch_for_country("uk") + assert "constituency" not in get_dispatch_for_country("us") def test_local_authority_is_uk_only(self): - assert ( - UK_MODULE_DISPATCH["local_authority"] is cm.compute_local_authority_module - ) - assert "local_authority" not in US_MODULE_DISPATCH + assert "local_authority" in get_dispatch_for_country("uk") + assert "local_authority" not in get_dispatch_for_country("us") def test_wealth_decile_is_uk_only(self): - assert UK_MODULE_DISPATCH["wealth_decile"] is cm.compute_wealth_decile_module - assert "wealth_decile" not in US_MODULE_DISPATCH + assert "wealth_decile" in get_dispatch_for_country("uk") + assert "wealth_decile" not in get_dispatch_for_country("us") def test_congressional_district_is_us_only(self): - assert ( - US_MODULE_DISPATCH["congressional_district"] - is cm.compute_congressional_district_module - ) - assert "congressional_district" not in UK_MODULE_DISPATCH + assert "congressional_district" in get_dispatch_for_country("us") + assert "congressional_district" not in get_dispatch_for_country("uk") + + +class TestUnifiedModuleFunctions: + """Tests that shared modules use the same function for both countries.""" + + def test_decile_is_shared(self): + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["decile"] is us["decile"] + assert uk["decile"] is cm.compute_decile_module + + def test_intra_decile_is_shared(self): + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["intra_decile"] is us["intra_decile"] + assert uk["intra_decile"] is cm.compute_intra_decile_module + + def test_program_statistics_is_shared(self): + """A single function handles both UK and US program statistics.""" + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["program_statistics"] is us["program_statistics"] + assert uk["program_statistics"] is cm.compute_program_statistics_module + + def test_poverty_is_shared(self): + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["poverty"] is us["poverty"] + assert uk["poverty"] is cm.compute_poverty_module + + def test_inequality_is_shared(self): + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["inequality"] is us["inequality"] + assert uk["inequality"] is cm.compute_inequality_module + + def test_budget_summary_is_shared(self): + uk = get_dispatch_for_country("uk") + us = get_dispatch_for_country("us") + assert uk["budget_summary"] is us["budget_summary"] + assert uk["budget_summary"] is cm.compute_budget_summary_module class TestModuleFunctionSignatures: - """Tests that all module functions share the expected signature pattern. + """Tests that all module functions share the expected 7-param signature. - Modules use a common 7-param signature pattern: + Modules use a common signature pattern: (pe_baseline_sim, pe_reform_sim, baseline_sim_id, reform_sim_id, - report_id, session, **kwargs) -> None - - run_modules() passes country_id as a kwarg. Modules that need it (e.g. - compute_decile_module) accept it explicitly; others accept **_kwargs. + report_id, session, config) -> None """ - _BASE_PARAMS = [ + _EXPECTED_PARAMS = [ "pe_baseline_sim", "pe_reform_sim", "baseline_sim_id", "reform_sim_id", "report_id", "session", + "config", ] - # 7th param can be either explicit country_id or **_kwargs - _VALID_7TH_PARAMS = {"country_id", "_kwargs"} def _get_all_unique_functions(self): - """Collect all unique module functions from both dispatch tables.""" + """Collect all unique module functions from dispatch.""" seen = set() fns = [] - for fn in list(UK_MODULE_DISPATCH.values()) + list(US_MODULE_DISPATCH.values()): + for fn in MODULE_DISPATCH.values(): if id(fn) not in seen: seen.add(id(fn)) fns.append(fn) @@ -167,19 +155,13 @@ def test_all_functions_have_expected_param_names(self): for fn in self._get_all_unique_functions(): sig = inspect.signature(fn) param_names = list(sig.parameters.keys()) - # First 6 params must match exactly - assert param_names[:6] == self._BASE_PARAMS, ( - f"{fn.__name__} first 6 params {param_names[:6]} != {self._BASE_PARAMS}" - ) - # 7th param can be country_id or _kwargs - assert param_names[6] in self._VALID_7TH_PARAMS, ( - f"{fn.__name__} 7th param '{param_names[6]}' not in {self._VALID_7TH_PARAMS}" + assert param_names == self._EXPECTED_PARAMS, ( + f"{fn.__name__} params {param_names} != {self._EXPECTED_PARAMS}" ) def test_all_functions_return_none(self): for fn in self._get_all_unique_functions(): sig = inspect.signature(fn) - # `from __future__ import annotations` makes annotations strings assert sig.return_annotation in (None, "None", inspect.Parameter.empty), ( f"{fn.__name__} return annotation is {sig.return_annotation!r}, expected None" ) @@ -192,54 +174,84 @@ def _make_mock_dispatch(self, names): """Create a dispatch dict with mock functions.""" return {name: MagicMock(name=f"compute_{name}") for name in names} - def test_runs_all_when_modules_is_none(self): + def _mock_config(self): + """Create a mock CountryConfig.""" + config = MagicMock() + config.country_id = "us" + return config + + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_runs_all_when_modules_is_none(self, mock_get_config, mock_get_dispatch): dispatch = self._make_mock_dispatch(["a", "b", "c"]) + config = self._mock_config() + mock_get_dispatch.return_value = dispatch + mock_get_config.return_value = config session = MagicMock() ids = [uuid4() for _ in range(3)] - run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], session) + run_modules("us", None, "bl", "rf", ids[0], ids[1], ids[2], session) for fn in dispatch.values(): fn.assert_called_once_with( - "bl", "rf", ids[0], ids[1], ids[2], session, country_id="" + "bl", "rf", ids[0], ids[1], ids[2], session, config ) - def test_runs_only_requested_modules(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_runs_only_requested_modules(self, mock_get_config, mock_get_dispatch): dispatch = self._make_mock_dispatch(["a", "b", "c"]) + config = self._mock_config() + mock_get_dispatch.return_value = dispatch + mock_get_config.return_value = config session = MagicMock() ids = [uuid4() for _ in range(3)] - run_modules(dispatch, ["b"], "bl", "rf", ids[0], ids[1], ids[2], session) + run_modules("us", ["b"], "bl", "rf", ids[0], ids[1], ids[2], session) dispatch["a"].assert_not_called() dispatch["b"].assert_called_once() dispatch["c"].assert_not_called() - def test_ignores_unknown_module_names(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_ignores_unknown_module_names(self, mock_get_config, mock_get_dispatch): dispatch = self._make_mock_dispatch(["a"]) + config = self._mock_config() + mock_get_dispatch.return_value = dispatch + mock_get_config.return_value = config session = MagicMock() ids = [uuid4() for _ in range(3)] # Should not raise run_modules( - dispatch, ["a", "nonexistent"], "bl", "rf", ids[0], ids[1], ids[2], session + "us", ["a", "nonexistent"], "bl", "rf", ids[0], ids[1], ids[2], session ) dispatch["a"].assert_called_once() - def test_empty_modules_list_runs_nothing(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_empty_modules_list_runs_nothing(self, mock_get_config, mock_get_dispatch): dispatch = self._make_mock_dispatch(["a", "b"]) + config = self._mock_config() + mock_get_dispatch.return_value = dispatch + mock_get_config.return_value = config session = MagicMock() ids = [uuid4() for _ in range(3)] - run_modules(dispatch, [], "bl", "rf", ids[0], ids[1], ids[2], session) + run_modules("us", [], "bl", "rf", ids[0], ids[1], ids[2], session) for fn in dispatch.values(): fn.assert_not_called() - def test_preserves_call_order(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_preserves_call_order(self, mock_get_config, mock_get_dispatch): """Modules should be called in the order they appear in the modules list.""" call_order = [] + config = self._mock_config() + mock_get_config.return_value = config def make_tracker(name): def fn(*args, **kwargs): @@ -248,17 +260,24 @@ def fn(*args, **kwargs): return fn dispatch = {name: make_tracker(name) for name in ["a", "b", "c"]} + mock_get_dispatch.return_value = dispatch ids = [uuid4() for _ in range(3)] run_modules( - dispatch, ["c", "a", "b"], "bl", "rf", ids[0], ids[1], ids[2], MagicMock() + "us", ["c", "a", "b"], "bl", "rf", ids[0], ids[1], ids[2], MagicMock() ) assert call_order == ["c", "a", "b"] - def test_none_modules_runs_all_in_dispatch_key_order(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_none_modules_runs_all_in_dispatch_key_order( + self, mock_get_config, mock_get_dispatch + ): """When modules is None, all dispatch entries run in dict-iteration order.""" call_order = [] + config = self._mock_config() + mock_get_config.return_value = config def make_tracker(name): def fn(*args, **kwargs): @@ -267,29 +286,40 @@ def fn(*args, **kwargs): return fn dispatch = {name: make_tracker(name) for name in ["x", "y", "z"]} + mock_get_dispatch.return_value = dispatch ids = [uuid4() for _ in range(3)] - run_modules(dispatch, None, "bl", "rf", ids[0], ids[1], ids[2], MagicMock()) + run_modules("us", None, "bl", "rf", ids[0], ids[1], ids[2], MagicMock()) assert call_order == ["x", "y", "z"] - def test_passes_all_args_correctly(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_passes_config_to_module_functions( + self, mock_get_config, mock_get_dispatch + ): mock_fn = MagicMock() + config = self._mock_config() + mock_get_config.return_value = config dispatch = {"test_mod": mock_fn} + mock_get_dispatch.return_value = dispatch session = MagicMock() bl, rf, b_id, r_id, rep_id = "baseline", "reform", uuid4(), uuid4(), uuid4() - run_modules(dispatch, ["test_mod"], bl, rf, b_id, r_id, rep_id, session) + run_modules("us", ["test_mod"], bl, rf, b_id, r_id, rep_id, session) - mock_fn.assert_called_once_with( - bl, rf, b_id, r_id, rep_id, session, country_id="" - ) + mock_fn.assert_called_once_with(bl, rf, b_id, r_id, rep_id, session, config) - def test_duplicate_module_name_runs_twice(self): + @patch.object(cm, "get_dispatch_for_country") + @patch.object(cm, "get_country_config") + def test_duplicate_module_name_runs_twice(self, mock_get_config, mock_get_dispatch): dispatch = self._make_mock_dispatch(["a"]) + config = self._mock_config() + mock_get_dispatch.return_value = dispatch + mock_get_config.return_value = config session = MagicMock() ids = [uuid4() for _ in range(3)] - run_modules(dispatch, ["a", "a"], "bl", "rf", ids[0], ids[1], ids[2], session) + run_modules("us", ["a", "a"], "bl", "rf", ids[0], ids[1], ids[2], session) assert dispatch["a"].call_count == 2