From 51b4a17b3739cd330d753974ea83a1b078f883a4 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 3 Feb 2026 00:55:19 +0300 Subject: [PATCH 01/19] feat: First household CRUD setup --- src/policyengine_api/api/__init__.py | 2 + src/policyengine_api/api/households.py | 119 ++++++++++++++ src/policyengine_api/models/__init__.py | 4 + src/policyengine_api/models/household.py | 54 ++++++ .../20260203000000_create_households.sql | 14 ++ test_fixtures/fixtures_households.py | 66 ++++++++ tests/test_households.py | 155 ++++++++++++++++++ 7 files changed, 414 insertions(+) create mode 100644 src/policyengine_api/api/households.py create mode 100644 src/policyengine_api/models/household.py create mode 100644 supabase/migrations/20260203000000_create_households.sql create mode 100644 test_fixtures/fixtures_households.py create mode 100644 tests/test_households.py diff --git a/src/policyengine_api/api/__init__.py b/src/policyengine_api/api/__init__.py index 881af99..e688814 100644 --- a/src/policyengine_api/api/__init__.py +++ b/src/policyengine_api/api/__init__.py @@ -9,6 +9,7 @@ datasets, dynamics, household, + households, outputs, parameter_values, parameters, @@ -33,6 +34,7 @@ api_router.include_router(tax_benefit_model_versions.router) api_router.include_router(change_aggregates.router) api_router.include_router(household.router) +api_router.include_router(households.router) api_router.include_router(analysis.router) api_router.include_router(agent.router) diff --git a/src/policyengine_api/api/households.py b/src/policyengine_api/api/households.py new file mode 100644 index 0000000..fdee1f7 --- /dev/null +++ b/src/policyengine_api/api/households.py @@ -0,0 +1,119 @@ +"""Stored household CRUD endpoints. + +Households represent saved household definitions that can be reused across +calculations and impact analyses. Create a household once, then reference +it by ID for repeated simulations. 
+ +These endpoints manage stored household *definitions* (people, entity groups, +model name, year). For running calculations on a household, use the +/household/calculate and /household/impact endpoints instead. +""" + +from typing import Any +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.models import Household, HouseholdCreate, HouseholdRead +from policyengine_api.services.database import get_session + +router = APIRouter(prefix="/households", tags=["households"]) + +_ENTITY_GROUP_KEYS = ( + "tax_unit", + "family", + "spm_unit", + "marital_unit", + "household", + "benunit", +) + + +def _pack_household_data(body: HouseholdCreate) -> dict[str, Any]: + """Pack the flat request fields into a single JSON blob for storage.""" + data: dict[str, Any] = {"people": body.people} + for key in _ENTITY_GROUP_KEYS: + val = getattr(body, key) + if val is not None: + data[key] = val + return data + + +def _to_read(record: Household) -> HouseholdRead: + """Unpack the JSON blob back into the flat response shape.""" + data = record.household_data + return HouseholdRead( + id=record.id, + tax_benefit_model_name=record.tax_benefit_model_name, + year=record.year, + label=record.label, + people=data["people"], + tax_unit=data.get("tax_unit"), + family=data.get("family"), + spm_unit=data.get("spm_unit"), + marital_unit=data.get("marital_unit"), + household=data.get("household"), + benunit=data.get("benunit"), + created_at=record.created_at, + updated_at=record.updated_at, + ) + + +@router.post("/", response_model=HouseholdRead, status_code=201) +def create_household(body: HouseholdCreate, session: Session = Depends(get_session)): + """Create a stored household definition. + + The household data (people + entity groups) is persisted so it can be + retrieved later by ID. Use the returned ID with /household/calculate + or /household/impact to run simulations. 
+ """ + record = Household( + tax_benefit_model_name=body.tax_benefit_model_name, + year=body.year, + label=body.label, + household_data=_pack_household_data(body), + ) + session.add(record) + session.commit() + session.refresh(record) + return _to_read(record) + + +@router.get("/", response_model=list[HouseholdRead]) +def list_households( + tax_benefit_model_name: str | None = None, + limit: int = Query(default=50, le=200), + offset: int = Query(default=0, ge=0), + session: Session = Depends(get_session), +): + """List stored households with optional filtering.""" + query = select(Household) + if tax_benefit_model_name is not None: + query = query.where(Household.tax_benefit_model_name == tax_benefit_model_name) + query = query.offset(offset).limit(limit) + records = session.exec(query).all() + return [_to_read(r) for r in records] + + +@router.get("/{household_id}", response_model=HouseholdRead) +def get_household(household_id: UUID, session: Session = Depends(get_session)): + """Get a stored household by ID.""" + record = session.get(Household, household_id) + if not record: + raise HTTPException( + status_code=404, detail=f"Household {household_id} not found" + ) + return _to_read(record) + + +@router.delete("/{household_id}", status_code=204) +def delete_household(household_id: UUID, session: Session = Depends(get_session)): + """Delete a stored household.""" + record = session.get(Household, household_id) + if not record: + raise HTTPException( + status_code=404, detail=f"Household {household_id} not found" + ) + session.delete(record) + session.commit() diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index 4d64c02..7e76baa 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -11,6 +11,7 @@ from .dataset_version import DatasetVersion, DatasetVersionCreate, DatasetVersionRead from .decile_impact import DecileImpact, DecileImpactCreate, DecileImpactRead from .dynamic 
import Dynamic, DynamicCreate, DynamicRead +from .household import Household, HouseholdCreate, HouseholdRead from .household_job import ( HouseholdJob, HouseholdJobCreate, @@ -72,6 +73,9 @@ "Dynamic", "DynamicCreate", "DynamicRead", + "Household", + "HouseholdCreate", + "HouseholdRead", "HouseholdJob", "HouseholdJobCreate", "HouseholdJobRead", diff --git a/src/policyengine_api/models/household.py b/src/policyengine_api/models/household.py new file mode 100644 index 0000000..8a96850 --- /dev/null +++ b/src/policyengine_api/models/household.py @@ -0,0 +1,54 @@ +"""Stored household definition model.""" + +from datetime import datetime, timezone +from typing import Any, Literal +from uuid import UUID, uuid4 + +from sqlalchemy import JSON +from sqlmodel import Column, Field, SQLModel + + +class HouseholdBase(SQLModel): + """Base household fields.""" + + tax_benefit_model_name: str + year: int + label: str | None = None + household_data: dict[str, Any] = Field(sa_column=Column(JSON, nullable=False)) + + +class Household(HouseholdBase, table=True): + """Stored household database model.""" + + __tablename__ = "households" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class HouseholdCreate(SQLModel): + """Schema for creating a stored household. + + Accepts the flat structure matching the frontend Household interface: + people as an array, entity groups as optional dicts. 
+ """ + + tax_benefit_model_name: Literal["policyengine_us", "policyengine_uk"] + year: int + label: str | None = None + people: list[dict[str, Any]] + tax_unit: dict[str, Any] | None = None + family: dict[str, Any] | None = None + spm_unit: dict[str, Any] | None = None + marital_unit: dict[str, Any] | None = None + household: dict[str, Any] | None = None + benunit: dict[str, Any] | None = None + + +class HouseholdRead(HouseholdCreate): + """Schema for reading a stored household.""" + + id: UUID + created_at: datetime + updated_at: datetime diff --git a/supabase/migrations/20260203000000_create_households.sql b/supabase/migrations/20260203000000_create_households.sql new file mode 100644 index 0000000..cc1907f --- /dev/null +++ b/supabase/migrations/20260203000000_create_households.sql @@ -0,0 +1,14 @@ +-- Create stored households table for persisting household definitions. + +CREATE TABLE IF NOT EXISTS households ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + tax_benefit_model_name TEXT NOT NULL, + year INTEGER NOT NULL, + label TEXT, + household_data JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX idx_households_model_name ON households (tax_benefit_model_name); +CREATE INDEX idx_households_year ON households (year); diff --git a/test_fixtures/fixtures_households.py b/test_fixtures/fixtures_households.py new file mode 100644 index 0000000..4e676f4 --- /dev/null +++ b/test_fixtures/fixtures_households.py @@ -0,0 +1,66 @@ +"""Fixtures and helpers for household CRUD tests.""" + +from policyengine_api.models import Household + +# ----------------------------------------------------------------------------- +# Request payloads (match HouseholdCreate schema) +# ----------------------------------------------------------------------------- + +MOCK_US_HOUSEHOLD_CREATE = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "label": "US test household", + "people": [ + 
{"age": 30, "employment_income": 50000}, + {"age": 28, "employment_income": 30000}, + ], + "tax_unit": {}, + "family": {}, + "household": {"state_name": "CA"}, +} + +MOCK_UK_HOUSEHOLD_CREATE = { + "tax_benefit_model_name": "policyengine_uk", + "year": 2024, + "label": "UK test household", + "people": [ + {"age": 40, "employment_income": 35000}, + ], + "benunit": {"is_married": False}, + "household": {"region": "LONDON"}, +} + +MOCK_HOUSEHOLD_MINIMAL = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "people": [{"age": 25}], +} + + +# ----------------------------------------------------------------------------- +# Factory functions +# ----------------------------------------------------------------------------- + + +def create_household( + session, + tax_benefit_model_name: str = "policyengine_us", + year: int = 2024, + label: str | None = "Test household", + people: list | None = None, + **entity_groups, +) -> Household: + """Create and persist a Household record.""" + household_data = {"people": people or [{"age": 30}]} + household_data.update(entity_groups) + + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data=household_data, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/tests/test_households.py b/tests/test_households.py new file mode 100644 index 0000000..4c60062 --- /dev/null +++ b/tests/test_households.py @@ -0,0 +1,155 @@ +"""Tests for stored household CRUD endpoints.""" + +from uuid import uuid4 + +from test_fixtures.fixtures_households import ( + MOCK_HOUSEHOLD_MINIMAL, + MOCK_UK_HOUSEHOLD_CREATE, + MOCK_US_HOUSEHOLD_CREATE, + create_household, +) + +# --------------------------------------------------------------------------- +# POST /households +# --------------------------------------------------------------------------- + + +def test_create_us_household(client): + """Create a US household returns 201 with id and 
timestamps.""" + response = client.post("/households", json=MOCK_US_HOUSEHOLD_CREATE) + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["tax_benefit_model_name"] == "policyengine_us" + assert data["year"] == 2024 + assert data["label"] == "US test household" + + +def test_create_household_returns_people_and_entities(client): + """Created household response includes people and entity groups.""" + response = client.post("/households", json=MOCK_US_HOUSEHOLD_CREATE) + data = response.json() + assert len(data["people"]) == 2 + assert data["people"][0]["age"] == 30 + assert data["people"][0]["employment_income"] == 50000 + assert data["household"] == {"state_name": "CA"} + assert data["tax_unit"] == {} + assert data["family"] == {} + + +def test_create_uk_household(client): + """Create a UK household with benunit.""" + response = client.post("/households", json=MOCK_UK_HOUSEHOLD_CREATE) + assert response.status_code == 201 + data = response.json() + assert data["tax_benefit_model_name"] == "policyengine_uk" + assert data["benunit"] == {"is_married": False} + assert data["household"] == {"region": "LONDON"} + + +def test_create_household_minimal(client): + """Create a household with minimal fields.""" + response = client.post("/households", json=MOCK_HOUSEHOLD_MINIMAL) + assert response.status_code == 201 + data = response.json() + assert data["label"] is None + assert data["tax_unit"] is None + assert data["benunit"] is None + + +def test_create_household_invalid_model_name(client): + """Reject invalid tax_benefit_model_name.""" + payload = {**MOCK_HOUSEHOLD_MINIMAL, "tax_benefit_model_name": "invalid"} + response = client.post("/households", json=payload) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /households/{id} +# 
--------------------------------------------------------------------------- + + +def test_get_household(client, session): + """Get a stored household by ID.""" + record = create_household(session) + response = client.get(f"/households/{record.id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(record.id) + assert data["tax_benefit_model_name"] == "policyengine_us" + + +def test_get_household_not_found(client): + """Get a non-existent household returns 404.""" + fake_id = uuid4() + response = client.get(f"/households/{fake_id}") + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# GET /households +# --------------------------------------------------------------------------- + + +def test_list_households_empty(client): + """List households returns empty list when none exist.""" + response = client.get("/households") + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_households_with_data(client, session): + """List households returns all stored households.""" + create_household(session, label="first") + create_household(session, label="second") + response = client.get("/households") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_households_filter_by_model_name(client, session): + """Filter households by tax_benefit_model_name.""" + create_household(session, tax_benefit_model_name="policyengine_us") + create_household(session, tax_benefit_model_name="policyengine_uk") + response = client.get( + "/households", params={"tax_benefit_model_name": "policyengine_uk"} + ) + data = response.json() + assert len(data) == 1 + assert data[0]["tax_benefit_model_name"] == "policyengine_uk" + + +def test_list_households_limit_and_offset(client, session): + """Respect limit and offset pagination.""" + for i in range(5): + 
create_household(session, label=f"household-{i}") + response = client.get("/households", params={"limit": 2, "offset": 1}) + data = response.json() + assert len(data) == 2 + + +# --------------------------------------------------------------------------- +# DELETE /households/{id} +# --------------------------------------------------------------------------- + + +def test_delete_household(client, session): + """Delete a household returns 204.""" + record = create_household(session) + response = client.delete(f"/households/{record.id}") + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/households/{record.id}") + assert response.status_code == 404 + + +def test_delete_household_not_found(client): + """Delete a non-existent household returns 404.""" + fake_id = uuid4() + response = client.delete(f"/households/{fake_id}") + assert response.status_code == 404 From 58cd69158770dc12ad844fc8dc5f986d0ae74cf0 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 3 Feb 2026 01:16:25 +0300 Subject: [PATCH 02/19] feat: User-household associations --- src/policyengine_api/api/__init__.py | 2 + .../api/user_household_associations.py | 125 ++++++++++++ src/policyengine_api/models/__init__.py | 10 + .../models/user_household_association.py | 48 +++++ ...001_create_user_household_associations.sql | 14 ++ .../fixtures_user_household_associations.py | 62 ++++++ tests/test_user_household_associations.py | 189 ++++++++++++++++++ 7 files changed, 450 insertions(+) create mode 100644 src/policyengine_api/api/user_household_associations.py create mode 100644 src/policyengine_api/models/user_household_association.py create mode 100644 supabase/migrations/20260203000001_create_user_household_associations.sql create mode 100644 test_fixtures/fixtures_user_household_associations.py create mode 100644 tests/test_user_household_associations.py diff --git a/src/policyengine_api/api/__init__.py b/src/policyengine_api/api/__init__.py index e688814..92f5ea5 
100644 --- a/src/policyengine_api/api/__init__.py +++ b/src/policyengine_api/api/__init__.py @@ -17,6 +17,7 @@ simulations, tax_benefit_model_versions, tax_benefit_models, + user_household_associations, variables, ) @@ -37,5 +38,6 @@ api_router.include_router(households.router) api_router.include_router(analysis.router) api_router.include_router(agent.router) +api_router.include_router(user_household_associations.router) __all__ = ["api_router"] diff --git a/src/policyengine_api/api/user_household_associations.py b/src/policyengine_api/api/user_household_associations.py new file mode 100644 index 0000000..fa40e06 --- /dev/null +++ b/src/policyengine_api/api/user_household_associations.py @@ -0,0 +1,125 @@ +"""User-household association endpoints. + +Associations link a user to a stored household definition with metadata +(label, country). A user can have multiple associations to the same +household (e.g. different labels or configurations). +""" + +from datetime import datetime, timezone +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.models import ( + Household, + UserHouseholdAssociation, + UserHouseholdAssociationCreate, + UserHouseholdAssociationRead, + UserHouseholdAssociationUpdate, +) +from policyengine_api.services.database import get_session + +router = APIRouter( + prefix="/user-household-associations", + tags=["user-household-associations"], +) + + +@router.post("/", response_model=UserHouseholdAssociationRead, status_code=201) +def create_association( + body: UserHouseholdAssociationCreate, + session: Session = Depends(get_session), +): + """Create a user-household association.""" + household = session.get(Household, body.household_id) + if not household: + raise HTTPException( + status_code=404, + detail=f"Household {body.household_id} not found", + ) + + record = UserHouseholdAssociation( + user_id=body.user_id, + household_id=body.household_id, + 
country_id=body.country_id, + label=body.label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.get("/user/{user_id}", response_model=list[UserHouseholdAssociationRead]) +def list_by_user( + user_id: UUID, + country_id: str | None = None, + limit: int = Query(default=50, le=200), + offset: int = Query(default=0, ge=0), + session: Session = Depends(get_session), +): + """List all associations for a user, optionally filtered by country.""" + query = select(UserHouseholdAssociation).where( + UserHouseholdAssociation.user_id == user_id + ) + if country_id is not None: + query = query.where(UserHouseholdAssociation.country_id == country_id) + query = query.offset(offset).limit(limit) + return session.exec(query).all() + + +@router.get( + "/{user_id}/{household_id}", + response_model=list[UserHouseholdAssociationRead], +) +def list_by_user_and_household( + user_id: UUID, + household_id: UUID, + session: Session = Depends(get_session), +): + """List all associations for a specific user+household pair.""" + query = select(UserHouseholdAssociation).where( + UserHouseholdAssociation.user_id == user_id, + UserHouseholdAssociation.household_id == household_id, + ) + return session.exec(query).all() + + +@router.put("/{association_id}", response_model=UserHouseholdAssociationRead) +def update_association( + association_id: UUID, + body: UserHouseholdAssociationUpdate, + session: Session = Depends(get_session), +): + """Update a user-household association (label).""" + record = session.get(UserHouseholdAssociation, association_id) + if not record: + raise HTTPException( + status_code=404, + detail=f"Association {association_id} not found", + ) + update_data = body.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(record, key, value) + record.updated_at = datetime.now(timezone.utc) + session.add(record) + session.commit() + session.refresh(record) + return record + + 
+@router.delete("/{association_id}", status_code=204) +def delete_association( + association_id: UUID, + session: Session = Depends(get_session), +): + """Delete a user-household association.""" + record = session.get(UserHouseholdAssociation, association_id) + if not record: + raise HTTPException( + status_code=404, + detail=f"Association {association_id} not found", + ) + session.delete(record) + session.commit() diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index 7e76baa..546c538 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -48,6 +48,12 @@ TaxBenefitModelVersionRead, ) from .user import User, UserCreate, UserRead +from .user_household_association import ( + UserHouseholdAssociation, + UserHouseholdAssociationCreate, + UserHouseholdAssociationRead, + UserHouseholdAssociationUpdate, +) from .variable import Variable, VariableCreate, VariableRead __all__ = [ @@ -114,6 +120,10 @@ "TaxBenefitModelVersionRead", "User", "UserCreate", + "UserHouseholdAssociation", + "UserHouseholdAssociationCreate", + "UserHouseholdAssociationRead", + "UserHouseholdAssociationUpdate", "UserRead", "Variable", "VariableCreate", diff --git a/src/policyengine_api/models/user_household_association.py b/src/policyengine_api/models/user_household_association.py new file mode 100644 index 0000000..208279a --- /dev/null +++ b/src/policyengine_api/models/user_household_association.py @@ -0,0 +1,48 @@ +"""User-household association model.""" + +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + + +class UserHouseholdAssociationBase(SQLModel): + """Base association fields.""" + + user_id: UUID = Field(foreign_key="users.id", index=True) + household_id: UUID = Field(foreign_key="households.id", index=True) + country_id: str + label: str | None = None + + +class UserHouseholdAssociation(UserHouseholdAssociationBase, table=True): + 
"""User-household association database model.""" + + __tablename__ = "user_household_associations" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class UserHouseholdAssociationCreate(SQLModel): + """Schema for creating a user-household association.""" + + user_id: UUID + household_id: UUID + country_id: str + label: str | None = None + + +class UserHouseholdAssociationUpdate(SQLModel): + """Schema for updating a user-household association.""" + + label: str | None = None + + +class UserHouseholdAssociationRead(UserHouseholdAssociationBase): + """Schema for reading a user-household association.""" + + id: UUID + created_at: datetime + updated_at: datetime diff --git a/supabase/migrations/20260203000001_create_user_household_associations.sql b/supabase/migrations/20260203000001_create_user_household_associations.sql new file mode 100644 index 0000000..3fdcb03 --- /dev/null +++ b/supabase/migrations/20260203000001_create_user_household_associations.sql @@ -0,0 +1,14 @@ +-- Create user-household associations table for linking users to saved households. 
+ +CREATE TABLE IF NOT EXISTS user_household_associations ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + household_id UUID NOT NULL REFERENCES households(id) ON DELETE CASCADE, + country_id TEXT NOT NULL, + label TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX idx_user_household_assoc_user ON user_household_associations (user_id); +CREATE INDEX idx_user_household_assoc_household ON user_household_associations (household_id); diff --git a/test_fixtures/fixtures_user_household_associations.py b/test_fixtures/fixtures_user_household_associations.py new file mode 100644 index 0000000..66b0835 --- /dev/null +++ b/test_fixtures/fixtures_user_household_associations.py @@ -0,0 +1,62 @@ +"""Fixtures and helpers for user-household association tests.""" + +from uuid import UUID + +from policyengine_api.models import Household, User, UserHouseholdAssociation + +# ----------------------------------------------------------------------------- +# Factory functions +# ----------------------------------------------------------------------------- + + +def create_user( + session, + first_name: str = "Test", + last_name: str = "User", + email: str = "test@example.com", +) -> User: + """Create and persist a User record.""" + record = User(first_name=first_name, last_name=last_name, email=email) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_household( + session, + tax_benefit_model_name: str = "policyengine_us", + year: int = 2024, + label: str | None = "Test household", +) -> Household: + """Create and persist a Household record.""" + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data={"people": [{"age": 30}]}, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_association( + 
session, + user_id: UUID, + household_id: UUID, + country_id: str = "us", + label: str | None = "My household", +) -> UserHouseholdAssociation: + """Create and persist a UserHouseholdAssociation record.""" + record = UserHouseholdAssociation( + user_id=user_id, + household_id=household_id, + country_id=country_id, + label=label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/tests/test_user_household_associations.py b/tests/test_user_household_associations.py new file mode 100644 index 0000000..25d8989 --- /dev/null +++ b/tests/test_user_household_associations.py @@ -0,0 +1,189 @@ +"""Tests for user-household association endpoints.""" + +from uuid import uuid4 + +from test_fixtures.fixtures_user_household_associations import ( + create_association, + create_household, + create_user, +) + +# --------------------------------------------------------------------------- +# POST /user-household-associations +# --------------------------------------------------------------------------- + + +def test_create_association(client, session): + """Create an association returns 201 with id and timestamps.""" + user = create_user(session) + household = create_household(session) + payload = { + "user_id": str(user.id), + "household_id": str(household.id), + "country_id": "us", + "label": "My US household", + } + response = client.post("/user-household-associations", json=payload) + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["user_id"] == str(user.id) + assert data["household_id"] == str(household.id) + assert data["country_id"] == "us" + assert data["label"] == "My US household" + + +def test_create_association_allows_duplicates(client, session): + """Multiple associations to the same household are allowed.""" + user = create_user(session) + household = create_household(session) + payload = { + "user_id": str(user.id), 
+ "household_id": str(household.id), + "country_id": "us", + "label": "First label", + } + r1 = client.post("/user-household-associations", json=payload) + assert r1.status_code == 201 + + payload["label"] = "Second label" + r2 = client.post("/user-household-associations", json=payload) + assert r2.status_code == 201 + assert r1.json()["id"] != r2.json()["id"] + + +def test_create_association_household_not_found(client, session): + """Creating with a non-existent household returns 404.""" + user = create_user(session) + payload = { + "user_id": str(user.id), + "household_id": str(uuid4()), + "country_id": "us", + } + response = client.post("/user-household-associations", json=payload) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# GET /user-household-associations/user/{user_id} +# --------------------------------------------------------------------------- + + +def test_list_by_user_empty(client): + """List associations for a user with none returns empty list.""" + response = client.get(f"/user-household-associations/user/{uuid4()}") + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_by_user(client, session): + """List all associations for a user.""" + user = create_user(session) + h1 = create_household(session, label="H1") + h2 = create_household(session, label="H2") + create_association(session, user.id, h1.id, label="First") + create_association(session, user.id, h2.id, label="Second") + + response = client.get(f"/user-household-associations/user/{user.id}") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_by_user_filter_country(client, session): + """Filter associations by country_id.""" + user = create_user(session) + household = create_household(session) + create_association(session, user.id, household.id, country_id="us") + 
create_association(session, user.id, household.id, country_id="uk") + + response = client.get( + f"/user-household-associations/user/{user.id}", + params={"country_id": "uk"}, + ) + data = response.json() + assert len(data) == 1 + assert data[0]["country_id"] == "uk" + + +# --------------------------------------------------------------------------- +# GET /user-household-associations/{user_id}/{household_id} +# --------------------------------------------------------------------------- + + +def test_list_by_user_and_household(client, session): + """List associations for a specific user+household pair.""" + user = create_user(session) + household = create_household(session) + create_association(session, user.id, household.id, label="Label A") + create_association(session, user.id, household.id, label="Label B") + + response = client.get(f"/user-household-associations/{user.id}/{household.id}") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_by_user_and_household_empty(client): + """Returns empty list when no associations exist for the pair.""" + response = client.get(f"/user-household-associations/{uuid4()}/{uuid4()}") + assert response.status_code == 200 + assert response.json() == [] + + +# --------------------------------------------------------------------------- +# PUT /user-household-associations/{association_id} +# --------------------------------------------------------------------------- + + +def test_update_association_label(client, session): + """Update label and verify updated_at changes.""" + user = create_user(session) + household = create_household(session) + assoc = create_association(session, user.id, household.id, label="Old") + + response = client.put( + f"/user-household-associations/{assoc.id}", + json={"label": "New label"}, + ) + assert response.status_code == 200 + data = response.json() + assert data["label"] == "New label" + + +def test_update_association_not_found(client): + 
"""Update a non-existent association returns 404.""" + response = client.put( + f"/user-household-associations/{uuid4()}", + json={"label": "Something"}, + ) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# DELETE /user-household-associations/{association_id} +# --------------------------------------------------------------------------- + + +def test_delete_association(client, session): + """Delete an association returns 204.""" + user = create_user(session) + household = create_household(session) + assoc = create_association(session, user.id, household.id) + + response = client.delete(f"/user-household-associations/{assoc.id}") + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/user-household-associations/{user.id}/{household.id}") + assert response.json() == [] + + +def test_delete_association_not_found(client): + """Delete a non-existent association returns 404.""" + response = client.delete(f"/user-household-associations/{uuid4()}") + assert response.status_code == 404 From 77c5a3db89d096c73a3e283f83dfaaf897185baa Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 3 Feb 2026 22:13:16 +0300 Subject: [PATCH 03/19] feat: Household analysis --- src/policyengine_api/api/analysis.py | 400 +++++++++++++++++- src/policyengine_api/models/__init__.py | 9 +- src/policyengine_api/models/report.py | 1 + src/policyengine_api/models/simulation.py | 28 +- ...203000002_simulation_household_support.sql | 16 + test_fixtures/fixtures_analysis.py | 164 +++++++ tests/test_analysis_household_impact.py | 297 +++++++++++++ 7 files changed, 901 insertions(+), 14 deletions(-) create mode 100644 supabase/migrations/20260203000002_simulation_household_support.sql create mode 100644 test_fixtures/fixtures_analysis.py create mode 100644 tests/test_analysis_household_impact.py diff --git a/src/policyengine_api/api/analysis.py 
b/src/policyengine_api/api/analysis.py index c9aa86d..b1ab584 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -16,7 +16,8 @@ """ import math -from typing import Literal +from datetime import datetime, timezone +from typing import Any, Literal from uuid import UUID, uuid5 import logfire @@ -29,12 +30,15 @@ Dataset, DecileImpact, DecileImpactRead, + Household, + Policy, ProgramStatistics, ProgramStatisticsRead, Report, ReportStatus, Simulation, SimulationStatus, + SimulationType, TaxBenefitModel, TaxBenefitModelVersion, ) @@ -138,19 +142,24 @@ def _get_model_version( def _get_deterministic_simulation_id( - dataset_id: UUID, + simulation_type: SimulationType, model_version_id: UUID, policy_id: UUID | None, dynamic_id: UUID | None, + dataset_id: UUID | None = None, + household_id: UUID | None = None, ) -> UUID: """Generate a deterministic UUID from simulation parameters.""" - key = f"{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}" + if simulation_type == SimulationType.ECONOMY: + key = f"economy:{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}" + else: + key = f"household:{household_id}:{model_version_id}:{policy_id}:{dynamic_id}" return uuid5(SIMULATION_NAMESPACE, key) def _get_deterministic_report_id( baseline_sim_id: UUID, - reform_sim_id: UUID, + reform_sim_id: UUID | None, ) -> UUID: """Generate a deterministic UUID from report parameters.""" key = f"{baseline_sim_id}:{reform_sim_id}" @@ -158,15 +167,22 @@ def _get_deterministic_report_id( def _get_or_create_simulation( - dataset_id: UUID, + simulation_type: SimulationType, model_version_id: UUID, policy_id: UUID | None, dynamic_id: UUID | None, session: Session, + dataset_id: UUID | None = None, + household_id: UUID | None = None, ) -> Simulation: """Get existing simulation or create a new one.""" sim_id = _get_deterministic_simulation_id( - dataset_id, model_version_id, policy_id, dynamic_id + simulation_type, + model_version_id, + policy_id, + 
dynamic_id, + dataset_id=dataset_id, + household_id=household_id, ) existing = session.get(Simulation, sim_id) @@ -175,7 +191,9 @@ def _get_or_create_simulation( simulation = Simulation( id=sim_id, + simulation_type=simulation_type, dataset_id=dataset_id, + household_id=household_id, tax_benefit_model_version_id=model_version_id, policy_id=policy_id, dynamic_id=dynamic_id, @@ -189,8 +207,9 @@ def _get_or_create_simulation( def _get_or_create_report( baseline_sim_id: UUID, - reform_sim_id: UUID, + reform_sim_id: UUID | None, label: str, + report_type: str, session: Session, ) -> Report: """Get existing report or create a new one.""" @@ -203,6 +222,7 @@ def _get_or_create_report( report = Report( id=report_id, label=label, + report_type=report_type, baseline_simulation_id=baseline_sim_id, reform_simulation_id=reform_sim_id, status=ReportStatus.PENDING, @@ -554,6 +574,362 @@ def _trigger_economy_comparison( fn.spawn(job_id=job_id, traceparent=traceparent) +# Entity types by country +UK_ENTITIES = ["person", "benunit", "household"] +US_ENTITIES = ["person", "tax_unit", "spm_unit", "family", "marital_unit", "household"] + + +def _compute_entity_diff( + baseline_list: list[dict], + reform_list: list[dict], +) -> list[dict]: + """Compute per-variable diffs for a list of entity instances.""" + entity_impact = [] + + for b_entity, r_entity in zip(baseline_list, reform_list): + entity_diff = {} + for key in b_entity: + if key in r_entity: + baseline_val = b_entity[key] + reform_val = r_entity[key] + if isinstance(baseline_val, (int, float)) and isinstance( + reform_val, (int, float) + ): + entity_diff[key] = { + "baseline": baseline_val, + "reform": reform_val, + "change": reform_val - baseline_val, + } + entity_impact.append(entity_diff) + + return entity_impact + + +def _compute_household_impact( + baseline_result: dict, + reform_result: dict, + country: str, +) -> dict[str, Any]: + """Compute difference between baseline and reform for all entity types.""" + entities = 
UK_ENTITIES if country == "uk" else US_ENTITIES + + impact: dict[str, Any] = {} + + for entity in entities: + if entity in baseline_result and entity in reform_result: + impact[entity] = _compute_entity_diff( + baseline_result[entity], + reform_result[entity], + ) + + return impact + + +def _ensure_list(value: Any) -> list: + """Ensure value is a list; wrap dict in list if needed.""" + if value is None: + return [] + if isinstance(value, list): + return value + return [value] + + +def _run_household_simulation(simulation_id: UUID, session: Session) -> None: + """Run a single household simulation and store result.""" + from policyengine_api.api.household import ( + _calculate_household_uk, + _calculate_household_us, + ) + + simulation = session.get(Simulation, simulation_id) + if not simulation: + raise ValueError(f"Simulation {simulation_id} not found") + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError(f"Household {simulation.household_id} not found") + + # Update status + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + # Get policy if set + policy_data = None + if simulation.policy_id: + policy = session.get(Policy, simulation.policy_id) + if policy and policy.parameter_values: + policy_data = {} + for pv in policy.parameter_values: + if pv.parameter: + param_name = pv.parameter.name + policy_data[param_name] = { + "value": pv.value_json.get("value") + if isinstance(pv.value_json, dict) + else pv.value_json, + "start_date": pv.start_date.isoformat() + if pv.start_date + else None, + "end_date": pv.end_date.isoformat() + if pv.end_date + else None, + } + + # Extract household data with list conversion + data = household.household_data + people = data.get("people", []) + + # Run calculation based on country + if household.tax_benefit_model_name == "policyengine_uk": + result = _calculate_household_uk( + 
people=people, + benunit=_ensure_list(data.get("benunit")), + household=_ensure_list(data.get("household")), + year=household.year, + policy_data=policy_data, + ) + else: + result = _calculate_household_us( + people=people, + marital_unit=_ensure_list(data.get("marital_unit")), + family=_ensure_list(data.get("family")), + spm_unit=_ensure_list(data.get("spm_unit")), + tax_unit=_ensure_list(data.get("tax_unit")), + household=_ensure_list(data.get("household")), + year=household.year, + policy_data=policy_data, + ) + + # Store result + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + + session.add(simulation) + session.commit() + + +def _trigger_household_report(report_id: UUID, session: Session) -> None: + """Trigger household simulation(s) for a report.""" + report = session.get(Report, report_id) + if not report: + raise ValueError(f"Report {report_id} not found") + + # Update report status + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + try: + # Run baseline + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + if baseline_sim and baseline_sim.status == SimulationStatus.PENDING: + _run_household_simulation(baseline_sim.id, session) + + # Run reform if exists + if report.reform_simulation_id: + reform_sim = session.get(Simulation, report.reform_simulation_id) + if reform_sim and reform_sim.status == SimulationStatus.PENDING: + _run_household_simulation(reform_sim.id, session) + + # Update report status + report.status = ReportStatus.COMPLETED + except Exception as e: + report.status = ReportStatus.FAILED + report.error_message = str(e) + + session.add(report) + session.commit() + + +# Household impact request/response schemas +class 
HouseholdImpactRequest(BaseModel): + """Request for household impact analysis.""" + + household_id: UUID = Field(description="ID of the household to analyze") + policy_id: UUID | None = Field( + default=None, + description="Reform policy ID. If None, runs single calculation under current law.", + ) + dynamic_id: UUID | None = Field( + default=None, description="Optional behavioural response specification ID" + ) + + +class HouseholdSimulationInfo(BaseModel): + """Info about a household simulation.""" + + id: UUID + status: SimulationStatus + error_message: str | None = None + + +class HouseholdImpactResponse(BaseModel): + """Response for household impact analysis.""" + + report_id: UUID + report_type: str + status: ReportStatus + baseline_simulation: HouseholdSimulationInfo | None = None + reform_simulation: HouseholdSimulationInfo | None = None + baseline_result: dict | None = None + reform_result: dict | None = None + impact: dict | None = None + error_message: str | None = None + + +def _build_household_response( + report: Report, + baseline_sim: Simulation, + reform_sim: Simulation | None, + session: Session, +) -> HouseholdImpactResponse: + """Build response including computed impact for comparisons.""" + baseline_result = baseline_sim.household_result if baseline_sim else None + reform_result = reform_sim.household_result if reform_sim else None + + # Compute impact if comparison and both complete + impact = None + if reform_sim and baseline_result and reform_result: + # Determine country from household + household = session.get(Household, baseline_sim.household_id) + if household: + country = ( + "uk" if household.tax_benefit_model_name == "policyengine_uk" else "us" + ) + impact = _compute_household_impact(baseline_result, reform_result, country) + + return HouseholdImpactResponse( + report_id=report.id, + report_type=report.report_type or "household_single", + status=report.status, + baseline_simulation=HouseholdSimulationInfo( + id=baseline_sim.id, + 
status=baseline_sim.status, + error_message=baseline_sim.error_message, + ) + if baseline_sim + else None, + reform_simulation=HouseholdSimulationInfo( + id=reform_sim.id, + status=reform_sim.status, + error_message=reform_sim.error_message, + ) + if reform_sim + else None, + baseline_result=baseline_result, + reform_result=reform_result, + impact=impact, + error_message=report.error_message, + ) + + +@router.post("/household-impact", response_model=HouseholdImpactResponse) +def household_impact( + request: HouseholdImpactRequest, + session: Session = Depends(get_session), +) -> HouseholdImpactResponse: + """Run household impact analysis. + + If policy_id is None: single run under current law. + If policy_id is set: comparison (baseline vs reform). + + This is a synchronous operation for household calculations. + """ + # Validate household exists + household = session.get(Household, request.household_id) + if not household: + raise HTTPException( + status_code=404, detail=f"Household {request.household_id} not found" + ) + + # Validate policy if provided + if request.policy_id: + policy = session.get(Policy, request.policy_id) + if not policy: + raise HTTPException( + status_code=404, detail=f"Policy {request.policy_id} not found" + ) + + # Get model version from household's tax_benefit_model_name + model_version = _get_model_version(household.tax_benefit_model_name, session) + + # Create baseline simulation + baseline_sim = _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version.id, + policy_id=None, + dynamic_id=request.dynamic_id, + session=session, + household_id=request.household_id, + ) + + # Create reform simulation if policy_id provided + reform_sim = None + if request.policy_id: + reform_sim = _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version.id, + policy_id=request.policy_id, + dynamic_id=request.dynamic_id, + session=session, + 
household_id=request.household_id, + ) + + # Determine report type + report_type = "household_comparison" if request.policy_id else "household_single" + + # Create report + label = f"Household impact: {household.tax_benefit_model_name}" + report = _get_or_create_report( + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id if reform_sim else None, + label=label, + report_type=report_type, + session=session, + ) + + # Trigger compute if pending + if report.status == ReportStatus.PENDING: + with logfire.span("trigger_household_report", job_id=str(report.id)): + _trigger_household_report(report.id, session) + + return _build_household_response(report, baseline_sim, reform_sim, session) + + +@router.get("/household-impact/{report_id}", response_model=HouseholdImpactResponse) +def get_household_impact( + report_id: UUID, + session: Session = Depends(get_session), +) -> HouseholdImpactResponse: + """Get household impact analysis status and results.""" + report = session.get(Report, report_id) + if not report: + raise HTTPException(status_code=404, detail=f"Report {report_id} not found") + + if not report.baseline_simulation_id: + raise HTTPException( + status_code=500, detail="Report missing baseline simulation ID" + ) + + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + if not baseline_sim: + raise HTTPException(status_code=500, detail="Baseline simulation data missing") + + reform_sim = None + if report.reform_simulation_id: + reform_sim = session.get(Simulation, report.reform_simulation_id) + + return _build_household_response(report, baseline_sim, reform_sim, session) + + @router.post("/economic-impact", response_model=EconomicImpactResponse) def economic_impact( request: EconomicImpactRequest, @@ -580,19 +956,21 @@ def economic_impact( # Get or create simulations baseline_sim = _get_or_create_simulation( - dataset_id=request.dataset_id, + simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, policy_id=None, 
dynamic_id=request.dynamic_id, session=session, + dataset_id=request.dataset_id, ) reform_sim = _get_or_create_simulation( - dataset_id=request.dataset_id, + simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, policy_id=request.policy_id, dynamic_id=request.dynamic_id, session=session, + dataset_id=request.dataset_id, ) # Get or create report @@ -600,7 +978,9 @@ def economic_impact( if request.policy_id: label += f" (policy {request.policy_id})" - report = _get_or_create_report(baseline_sim.id, reform_sim.id, label, session) + report = _get_or_create_report( + baseline_sim.id, reform_sim.id, label, "economy_comparison", session + ) # Trigger computation if report is pending if report.status == ReportStatus.PENDING: diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index 546c538..c49b457 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -36,7 +36,13 @@ ProgramStatisticsRead, ) from .report import Report, ReportCreate, ReportRead, ReportStatus -from .simulation import Simulation, SimulationCreate, SimulationRead, SimulationStatus +from .simulation import ( + Simulation, + SimulationCreate, + SimulationRead, + SimulationStatus, + SimulationType, +) from .tax_benefit_model import ( TaxBenefitModel, TaxBenefitModelCreate, @@ -112,6 +118,7 @@ "SimulationCreate", "SimulationRead", "SimulationStatus", + "SimulationType", "TaxBenefitModel", "TaxBenefitModelCreate", "TaxBenefitModelRead", diff --git a/src/policyengine_api/models/report.py b/src/policyengine_api/models/report.py index ee1b678..bc2cd40 100644 --- a/src/policyengine_api/models/report.py +++ b/src/policyengine_api/models/report.py @@ -19,6 +19,7 @@ class ReportBase(SQLModel): label: str description: str | None = None + report_type: str | None = None user_id: UUID | None = Field(default=None, foreign_key="users.id") markdown: str | None = Field(default=None, sa_column=Column(Text)) 
parent_report_id: UUID | None = Field(default=None, foreign_key="reports.id") diff --git a/src/policyengine_api/models/simulation.py b/src/policyengine_api/models/simulation.py index b23141e..985db3e 100644 --- a/src/policyengine_api/models/simulation.py +++ b/src/policyengine_api/models/simulation.py @@ -1,13 +1,16 @@ from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING -from uuid import UUID, uuid4 +from typing import TYPE_CHECKING, Any +from sqlalchemy import Column +from sqlalchemy.dialects.postgresql import JSON from sqlmodel import Field, Relationship, SQLModel +from uuid import UUID, uuid4 if TYPE_CHECKING: from .dataset import Dataset from .dynamic import Dynamic + from .household import Household from .policy import Policy from .tax_benefit_model_version import TaxBenefitModelVersion @@ -21,10 +24,19 @@ class SimulationStatus(str, Enum): FAILED = "failed" +class SimulationType(str, Enum): + """Type of simulation.""" + + HOUSEHOLD = "household" + ECONOMY = "economy" + + class SimulationBase(SQLModel): """Base simulation fields.""" - dataset_id: UUID = Field(foreign_key="datasets.id") + simulation_type: SimulationType = SimulationType.ECONOMY + dataset_id: UUID | None = Field(default=None, foreign_key="datasets.id") + household_id: UUID | None = Field(default=None, foreign_key="households.id") policy_id: UUID | None = Field(default=None, foreign_key="policies.id") dynamic_id: UUID | None = Field(default=None, foreign_key="dynamics.id") tax_benefit_model_version_id: UUID = Field( @@ -45,6 +57,9 @@ class Simulation(SimulationBase, table=True): updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) started_at: datetime | None = None completed_at: datetime | None = None + household_result: dict[str, Any] | None = Field( + default=None, sa_column=Column(JSON) + ) # Relationships dataset: "Dataset" = Relationship( @@ -53,6 +68,12 @@ class Simulation(SimulationBase, table=True): "primaryjoin": 
"Simulation.dataset_id==Dataset.id", } ) + household: "Household" = Relationship( + sa_relationship_kwargs={ + "foreign_keys": "[Simulation.household_id]", + "primaryjoin": "Simulation.household_id==Household.id", + } + ) policy: "Policy" = Relationship() dynamic: "Dynamic" = Relationship() tax_benefit_model_version: "TaxBenefitModelVersion" = Relationship() @@ -78,3 +99,4 @@ class SimulationRead(SimulationBase): updated_at: datetime started_at: datetime | None completed_at: datetime | None + household_result: dict[str, Any] | None = None diff --git a/supabase/migrations/20260203000002_simulation_household_support.sql b/supabase/migrations/20260203000002_simulation_household_support.sql new file mode 100644 index 0000000..6813f07 --- /dev/null +++ b/supabase/migrations/20260203000002_simulation_household_support.sql @@ -0,0 +1,16 @@ +-- Add simulation_type as TEXT (SQLModel enum maps to text) +ALTER TABLE simulations ADD COLUMN simulation_type TEXT NOT NULL DEFAULT 'economy'; + +-- Make dataset_id nullable (was required) +ALTER TABLE simulations ALTER COLUMN dataset_id DROP NOT NULL; + +-- Add household support columns +ALTER TABLE simulations ADD COLUMN household_id UUID REFERENCES households(id); +ALTER TABLE simulations ADD COLUMN household_result JSONB; + +-- Indexes +CREATE INDEX idx_simulations_household ON simulations (household_id); +CREATE INDEX idx_simulations_type ON simulations (simulation_type); + +-- Add report_type to reports +ALTER TABLE reports ADD COLUMN report_type TEXT; diff --git a/test_fixtures/fixtures_analysis.py b/test_fixtures/fixtures_analysis.py new file mode 100644 index 0000000..d56b702 --- /dev/null +++ b/test_fixtures/fixtures_analysis.py @@ -0,0 +1,164 @@ +"""Fixtures and helpers for analysis endpoint tests.""" + +from uuid import UUID + +from sqlmodel import Session + +from policyengine_api.models import ( + Household, + Parameter, + ParameterValue, + Policy, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +def 
create_tax_benefit_model( + session: Session, + name: str = "policyengine-uk", + description: str = "UK tax benefit model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + model = TaxBenefitModel( + name=name, + description=description, + ) + session.add(model) + session.commit() + session.refresh(model) + return model + + +def create_model_version( + session: Session, + model_id: UUID, + version: str = "1.0.0", + description: str = "Test version", +) -> TaxBenefitModelVersion: + """Create and persist a TaxBenefitModelVersion record.""" + model_version = TaxBenefitModelVersion( + model_id=model_id, + version=version, + description=description, + ) + session.add(model_version) + session.commit() + session.refresh(model_version) + return model_version + + +def create_parameter( + session: Session, + model_version_id: UUID, + name: str = "test_parameter", + label: str = "Test Parameter", + description: str = "A test parameter", +) -> Parameter: + """Create and persist a Parameter record.""" + param = Parameter( + tax_benefit_model_version_id=model_version_id, + name=name, + label=label, + description=description, + ) + session.add(param) + session.commit() + session.refresh(param) + return param + + +def create_policy( + session: Session, + model_version_id: UUID, + name: str = "Test Policy", + description: str = "A test policy", +) -> Policy: + """Create and persist a Policy record.""" + policy = Policy( + tax_benefit_model_version_id=model_version_id, + name=name, + description=description, + ) + session.add(policy) + session.commit() + session.refresh(policy) + return policy + + +def create_policy_with_parameter_value( + session: Session, + model_version_id: UUID, + parameter_id: UUID, + value: float, + name: str = "Test Policy", +) -> Policy: + """Create a Policy with an associated ParameterValue.""" + policy = create_policy(session, model_version_id, name=name) + + param_value = ParameterValue( + policy_id=policy.id, + 
parameter_id=parameter_id, + value_json={"value": value}, + ) + session.add(param_value) + session.commit() + session.refresh(policy) + return policy + + +def create_household_for_analysis( + session: Session, + tax_benefit_model_name: str = "policyengine_uk", + year: int = 2024, + label: str = "Test household for analysis", +) -> Household: + """Create a household suitable for analysis testing.""" + if tax_benefit_model_name == "policyengine_uk": + household_data = { + "people": [{"age": 30, "employment_income": 35000}], + "benunit": {}, + "household": {"region": "LONDON"}, + } + else: + household_data = { + "people": [{"age": 30, "employment_income": 50000}], + "tax_unit": {"state_code": "CA"}, + "family": {}, + "spm_unit": {}, + "marital_unit": {}, + "household": {"state_fips": 6}, + } + + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data=household_data, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def setup_uk_model_and_version( + session: Session, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create UK model and version for testing.""" + model = create_tax_benefit_model( + session, name="policyengine-uk", description="UK model" + ) + version = create_model_version(session, model.id) + return model, version + + +def setup_us_model_and_version( + session: Session, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create US model and version for testing.""" + model = create_tax_benefit_model( + session, name="policyengine-us", description="US model" + ) + version = create_model_version(session, model.id) + return model, version diff --git a/tests/test_analysis_household_impact.py b/tests/test_analysis_household_impact.py new file mode 100644 index 0000000..e8a614b --- /dev/null +++ b/tests/test_analysis_household_impact.py @@ -0,0 +1,297 @@ +"""Tests for household impact analysis endpoints.""" + +from uuid import uuid4 + +import 
pytest + +from test_fixtures.fixtures_analysis import ( + create_household_for_analysis, + create_policy, + setup_uk_model_and_version, + setup_us_model_and_version, +) +from policyengine_api.models import Report, ReportStatus, Simulation, SimulationType + + +# --------------------------------------------------------------------------- +# Validation tests (no database required beyond session fixture) +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactValidation: + """Tests for request validation.""" + + def test_missing_household_id(self, client): + """Test that missing household_id returns 422.""" + response = client.post( + "/analysis/household-impact", + json={}, + ) + assert response.status_code == 422 + + def test_invalid_uuid(self, client): + """Test that invalid UUID returns 422.""" + response = client.post( + "/analysis/household-impact", + json={ + "household_id": "not-a-uuid", + }, + ) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# 404 tests +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactNotFound: + """Tests for 404 responses.""" + + def test_household_not_found(self, client, session): + """Test that non-existent household returns 404.""" + # Need model for the model version lookup + setup_uk_model_and_version(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(uuid4()), + }, + ) + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_policy_not_found(self, client, session): + """Test that non-existent policy returns 404.""" + setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(uuid4()), + }, + ) + assert 
response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_get_report_not_found(self, client): + """Test that GET with non-existent report_id returns 404.""" + response = client.get(f"/analysis/household-impact/{uuid4()}") + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# Record creation tests +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactRecordCreation: + """Tests for correct record creation.""" + + def test_single_run_creates_one_simulation(self, client, session): + """Single run (no policy_id) creates one simulation.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + }, + ) + # May fail during calculation since policyengine not available, + # but should create records + data = response.json() + assert "report_id" in data + assert data["report_type"] == "household_single" + assert data["baseline_simulation"] is not None + assert data["reform_simulation"] is None + + def test_comparison_creates_two_simulations(self, client, session): + """Comparison (with policy_id) creates two simulations.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + policy = create_policy(session, version.id) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy.id), + }, + ) + data = response.json() + assert "report_id" in data + assert data["report_type"] == "household_comparison" + assert data["baseline_simulation"] is not None + assert data["reform_simulation"] is not None + + def test_simulation_type_is_household(self, client, session): + """Created simulations have simulation_type=HOUSEHOLD.""" + _, version = 
setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + }, + ) + data = response.json() + + # Check simulation in database + sim_id = data["baseline_simulation"]["id"] + sim = session.get(Simulation, sim_id) + assert sim is not None + assert sim.simulation_type == SimulationType.HOUSEHOLD + assert sim.household_id == household.id + assert sim.dataset_id is None + + def test_report_links_simulations(self, client, session): + """Report correctly links baseline and reform simulations.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + policy = create_policy(session, version.id) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy.id), + }, + ) + data = response.json() + + # Check report in database + report = session.get(Report, data["report_id"]) + assert report is not None + assert report.baseline_simulation_id == data["baseline_simulation"]["id"] + assert report.reform_simulation_id == data["reform_simulation"]["id"] + assert report.report_type == "household_comparison" + + +# --------------------------------------------------------------------------- +# Deduplication tests +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactDeduplication: + """Tests for simulation/report deduplication.""" + + def test_same_request_returns_same_simulation(self, client, session): + """Same household + same parameters returns same simulation ID.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + # First request + response1 = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data1 = response1.json() + + # Second request with same parameters + response2 
= client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data2 = response2.json() + + # Should return same IDs + assert data1["report_id"] == data2["report_id"] + assert data1["baseline_simulation"]["id"] == data2["baseline_simulation"]["id"] + + def test_different_policy_creates_different_simulation(self, client, session): + """Different policy creates different simulation.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + policy1 = create_policy(session, version.id, name="Policy 1") + policy2 = create_policy(session, version.id, name="Policy 2") + + # Request with policy1 + response1 = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy1.id), + }, + ) + data1 = response1.json() + + # Request with policy2 + response2 = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy2.id), + }, + ) + data2 = response2.json() + + # Reports should be different + assert data1["report_id"] != data2["report_id"] + # Reform simulations should be different + assert ( + data1["reform_simulation"]["id"] != data2["reform_simulation"]["id"] + ) + # Baseline simulations should be the same (same household, no policy) + assert ( + data1["baseline_simulation"]["id"] == data2["baseline_simulation"]["id"] + ) + + +# --------------------------------------------------------------------------- +# GET endpoint tests +# --------------------------------------------------------------------------- + + +class TestGetHouseholdImpact: + """Tests for GET /analysis/household-impact/{report_id}.""" + + def test_get_returns_report_data(self, client, session): + """GET returns report with simulation info.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + # Create report via POST + post_response = client.post( + 
"/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + report_id = post_response.json()["report_id"] + + # GET the report + get_response = client.get(f"/analysis/household-impact/{report_id}") + assert get_response.status_code == 200 + + data = get_response.json() + assert data["report_id"] == report_id + assert data["report_type"] == "household_single" + assert data["baseline_simulation"] is not None + + +# --------------------------------------------------------------------------- +# US household tests +# --------------------------------------------------------------------------- + + +class TestUSHouseholdImpact: + """Tests specific to US households.""" + + def test_us_household_creates_simulation(self, client, session): + """US household creates simulation with correct model.""" + _, version = setup_us_model_and_version(session) + household = create_household_for_analysis( + session, tax_benefit_model_name="policyengine_us" + ) + + response = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data = response.json() + assert "report_id" in data + assert data["baseline_simulation"] is not None From 3c28466a9623f3266a3430a910a8968383e99ad3 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 3 Feb 2026 22:34:31 +0300 Subject: [PATCH 04/19] fix: Improve code quality --- src/policyengine_api/api/__init__.py | 2 + src/policyengine_api/api/analysis.py | 361 +--------- .../api/household_analysis.py | 640 ++++++++++++++++++ 3 files changed, 643 insertions(+), 360 deletions(-) create mode 100644 src/policyengine_api/api/household_analysis.py diff --git a/src/policyengine_api/api/__init__.py b/src/policyengine_api/api/__init__.py index 92f5ea5..c3e0353 100644 --- a/src/policyengine_api/api/__init__.py +++ b/src/policyengine_api/api/__init__.py @@ -9,6 +9,7 @@ datasets, dynamics, household, + household_analysis, households, outputs, parameter_values, @@ -35,6 +36,7 @@ 
api_router.include_router(tax_benefit_model_versions.router) api_router.include_router(change_aggregates.router) api_router.include_router(household.router) +api_router.include_router(household_analysis.router) api_router.include_router(households.router) api_router.include_router(analysis.router) api_router.include_router(agent.router) diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index b1ab584..10e6fc5 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -16,8 +16,7 @@ """ import math -from datetime import datetime, timezone -from typing import Any, Literal +from typing import Literal from uuid import UUID, uuid5 import logfire @@ -30,8 +29,6 @@ Dataset, DecileImpact, DecileImpactRead, - Household, - Policy, ProgramStatistics, ProgramStatisticsRead, Report, @@ -574,362 +571,6 @@ def _trigger_economy_comparison( fn.spawn(job_id=job_id, traceparent=traceparent) -# Entity types by country -UK_ENTITIES = ["person", "benunit", "household"] -US_ENTITIES = ["person", "tax_unit", "spm_unit", "family", "marital_unit", "household"] - - -def _compute_entity_diff( - baseline_list: list[dict], - reform_list: list[dict], -) -> list[dict]: - """Compute per-variable diffs for a list of entity instances.""" - entity_impact = [] - - for b_entity, r_entity in zip(baseline_list, reform_list): - entity_diff = {} - for key in b_entity: - if key in r_entity: - baseline_val = b_entity[key] - reform_val = r_entity[key] - if isinstance(baseline_val, (int, float)) and isinstance( - reform_val, (int, float) - ): - entity_diff[key] = { - "baseline": baseline_val, - "reform": reform_val, - "change": reform_val - baseline_val, - } - entity_impact.append(entity_diff) - - return entity_impact - - -def _compute_household_impact( - baseline_result: dict, - reform_result: dict, - country: str, -) -> dict[str, Any]: - """Compute difference between baseline and reform for all entity types.""" - entities = 
UK_ENTITIES if country == "uk" else US_ENTITIES - - impact: dict[str, Any] = {} - - for entity in entities: - if entity in baseline_result and entity in reform_result: - impact[entity] = _compute_entity_diff( - baseline_result[entity], - reform_result[entity], - ) - - return impact - - -def _ensure_list(value: Any) -> list: - """Ensure value is a list; wrap dict in list if needed.""" - if value is None: - return [] - if isinstance(value, list): - return value - return [value] - - -def _run_household_simulation(simulation_id: UUID, session: Session) -> None: - """Run a single household simulation and store result.""" - from policyengine_api.api.household import ( - _calculate_household_uk, - _calculate_household_us, - ) - - simulation = session.get(Simulation, simulation_id) - if not simulation: - raise ValueError(f"Simulation {simulation_id} not found") - - household = session.get(Household, simulation.household_id) - if not household: - raise ValueError(f"Household {simulation.household_id} not found") - - # Update status - simulation.status = SimulationStatus.RUNNING - simulation.started_at = datetime.now(timezone.utc) - session.add(simulation) - session.commit() - - try: - # Get policy if set - policy_data = None - if simulation.policy_id: - policy = session.get(Policy, simulation.policy_id) - if policy and policy.parameter_values: - policy_data = {} - for pv in policy.parameter_values: - if pv.parameter: - param_name = pv.parameter.name - policy_data[param_name] = { - "value": pv.value_json.get("value") - if isinstance(pv.value_json, dict) - else pv.value_json, - "start_date": pv.start_date.isoformat() - if pv.start_date - else None, - "end_date": pv.end_date.isoformat() - if pv.end_date - else None, - } - - # Extract household data with list conversion - data = household.household_data - people = data.get("people", []) - - # Run calculation based on country - if household.tax_benefit_model_name == "policyengine_uk": - result = _calculate_household_uk( - 
people=people, - benunit=_ensure_list(data.get("benunit")), - household=_ensure_list(data.get("household")), - year=household.year, - policy_data=policy_data, - ) - else: - result = _calculate_household_us( - people=people, - marital_unit=_ensure_list(data.get("marital_unit")), - family=_ensure_list(data.get("family")), - spm_unit=_ensure_list(data.get("spm_unit")), - tax_unit=_ensure_list(data.get("tax_unit")), - household=_ensure_list(data.get("household")), - year=household.year, - policy_data=policy_data, - ) - - # Store result - simulation.household_result = result - simulation.status = SimulationStatus.COMPLETED - simulation.completed_at = datetime.now(timezone.utc) - - except Exception as e: - simulation.status = SimulationStatus.FAILED - simulation.error_message = str(e) - simulation.completed_at = datetime.now(timezone.utc) - - session.add(simulation) - session.commit() - - -def _trigger_household_report(report_id: UUID, session: Session) -> None: - """Trigger household simulation(s) for a report.""" - report = session.get(Report, report_id) - if not report: - raise ValueError(f"Report {report_id} not found") - - # Update report status - report.status = ReportStatus.RUNNING - session.add(report) - session.commit() - - try: - # Run baseline - baseline_sim = session.get(Simulation, report.baseline_simulation_id) - if baseline_sim and baseline_sim.status == SimulationStatus.PENDING: - _run_household_simulation(baseline_sim.id, session) - - # Run reform if exists - if report.reform_simulation_id: - reform_sim = session.get(Simulation, report.reform_simulation_id) - if reform_sim and reform_sim.status == SimulationStatus.PENDING: - _run_household_simulation(reform_sim.id, session) - - # Update report status - report.status = ReportStatus.COMPLETED - except Exception as e: - report.status = ReportStatus.FAILED - report.error_message = str(e) - - session.add(report) - session.commit() - - -# Household impact request/response schemas -class 
HouseholdImpactRequest(BaseModel): - """Request for household impact analysis.""" - - household_id: UUID = Field(description="ID of the household to analyze") - policy_id: UUID | None = Field( - default=None, - description="Reform policy ID. If None, runs single calculation under current law.", - ) - dynamic_id: UUID | None = Field( - default=None, description="Optional behavioural response specification ID" - ) - - -class HouseholdSimulationInfo(BaseModel): - """Info about a household simulation.""" - - id: UUID - status: SimulationStatus - error_message: str | None = None - - -class HouseholdImpactResponse(BaseModel): - """Response for household impact analysis.""" - - report_id: UUID - report_type: str - status: ReportStatus - baseline_simulation: HouseholdSimulationInfo | None = None - reform_simulation: HouseholdSimulationInfo | None = None - baseline_result: dict | None = None - reform_result: dict | None = None - impact: dict | None = None - error_message: str | None = None - - -def _build_household_response( - report: Report, - baseline_sim: Simulation, - reform_sim: Simulation | None, - session: Session, -) -> HouseholdImpactResponse: - """Build response including computed impact for comparisons.""" - baseline_result = baseline_sim.household_result if baseline_sim else None - reform_result = reform_sim.household_result if reform_sim else None - - # Compute impact if comparison and both complete - impact = None - if reform_sim and baseline_result and reform_result: - # Determine country from household - household = session.get(Household, baseline_sim.household_id) - if household: - country = ( - "uk" if household.tax_benefit_model_name == "policyengine_uk" else "us" - ) - impact = _compute_household_impact(baseline_result, reform_result, country) - - return HouseholdImpactResponse( - report_id=report.id, - report_type=report.report_type or "household_single", - status=report.status, - baseline_simulation=HouseholdSimulationInfo( - id=baseline_sim.id, - 
status=baseline_sim.status, - error_message=baseline_sim.error_message, - ) - if baseline_sim - else None, - reform_simulation=HouseholdSimulationInfo( - id=reform_sim.id, - status=reform_sim.status, - error_message=reform_sim.error_message, - ) - if reform_sim - else None, - baseline_result=baseline_result, - reform_result=reform_result, - impact=impact, - error_message=report.error_message, - ) - - -@router.post("/household-impact", response_model=HouseholdImpactResponse) -def household_impact( - request: HouseholdImpactRequest, - session: Session = Depends(get_session), -) -> HouseholdImpactResponse: - """Run household impact analysis. - - If policy_id is None: single run under current law. - If policy_id is set: comparison (baseline vs reform). - - This is a synchronous operation for household calculations. - """ - # Validate household exists - household = session.get(Household, request.household_id) - if not household: - raise HTTPException( - status_code=404, detail=f"Household {request.household_id} not found" - ) - - # Validate policy if provided - if request.policy_id: - policy = session.get(Policy, request.policy_id) - if not policy: - raise HTTPException( - status_code=404, detail=f"Policy {request.policy_id} not found" - ) - - # Get model version from household's tax_benefit_model_name - model_version = _get_model_version(household.tax_benefit_model_name, session) - - # Create baseline simulation - baseline_sim = _get_or_create_simulation( - simulation_type=SimulationType.HOUSEHOLD, - model_version_id=model_version.id, - policy_id=None, - dynamic_id=request.dynamic_id, - session=session, - household_id=request.household_id, - ) - - # Create reform simulation if policy_id provided - reform_sim = None - if request.policy_id: - reform_sim = _get_or_create_simulation( - simulation_type=SimulationType.HOUSEHOLD, - model_version_id=model_version.id, - policy_id=request.policy_id, - dynamic_id=request.dynamic_id, - session=session, - 
household_id=request.household_id, - ) - - # Determine report type - report_type = "household_comparison" if request.policy_id else "household_single" - - # Create report - label = f"Household impact: {household.tax_benefit_model_name}" - report = _get_or_create_report( - baseline_sim_id=baseline_sim.id, - reform_sim_id=reform_sim.id if reform_sim else None, - label=label, - report_type=report_type, - session=session, - ) - - # Trigger compute if pending - if report.status == ReportStatus.PENDING: - with logfire.span("trigger_household_report", job_id=str(report.id)): - _trigger_household_report(report.id, session) - - return _build_household_response(report, baseline_sim, reform_sim, session) - - -@router.get("/household-impact/{report_id}", response_model=HouseholdImpactResponse) -def get_household_impact( - report_id: UUID, - session: Session = Depends(get_session), -) -> HouseholdImpactResponse: - """Get household impact analysis status and results.""" - report = session.get(Report, report_id) - if not report: - raise HTTPException(status_code=404, detail=f"Report {report_id} not found") - - if not report.baseline_simulation_id: - raise HTTPException( - status_code=500, detail="Report missing baseline simulation ID" - ) - - baseline_sim = session.get(Simulation, report.baseline_simulation_id) - if not baseline_sim: - raise HTTPException(status_code=500, detail="Baseline simulation data missing") - - reform_sim = None - if report.reform_simulation_id: - reform_sim = session.get(Simulation, report.reform_simulation_id) - - return _build_household_response(report, baseline_sim, reform_sim, session) - - @router.post("/economic-impact", response_model=EconomicImpactResponse) def economic_impact( request: EconomicImpactRequest, diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py new file mode 100644 index 0000000..29ea89e --- /dev/null +++ b/src/policyengine_api/api/household_analysis.py @@ -0,0 +1,640 @@ 
"""Household impact analysis endpoints.

Use these endpoints to analyze household-level effects of policy reforms.
Supports single runs (current law) and comparisons (baseline vs reform).

WORKFLOW:
1. Create a stored household: POST /households
2. Optionally create a reform policy: POST /policies
3. Run analysis: POST /analysis/household-impact
4. Results are synchronous - the response includes computed values
"""

# NOTE(review): ABC/abstractmethod and dataclass-field helpers below — ABC and
# abstractmethod do not appear to be used in this module; confirm before
# removing the import.
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Protocol
from uuid import UUID

import logfire
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlmodel import Session

from policyengine_api.models import (
    Household,
    Policy,
    Report,
    ReportStatus,
    Simulation,
    SimulationStatus,
    SimulationType,
)
from policyengine_api.services.database import get_session

# Private helpers shared with the economy-wide analysis module; these provide
# model-version lookup and simulation/report deduplication.
from .analysis import (
    _get_model_version,
    _get_or_create_report,
    _get_or_create_simulation,
)

# Mounted under the same /analysis prefix as the economy-impact endpoints.
router = APIRouter(prefix="/analysis", tags=["analysis"])


# =============================================================================
# Country Strategy Pattern
# =============================================================================


@dataclass(frozen=True)
class CountryConfig:
    """Configuration for a country's household calculation.

    Frozen so configs can be shared as module-level constants.
    """

    # Short country code ("uk"/"us"); not read within this chunk — confirm
    # downstream usage before renaming.
    name: str
    # Entity types to diff in impact computation; must match the keys the
    # country calculator emits in its result dict.
    entity_types: tuple[str, ...]
UK_CONFIG = CountryConfig(
    name="uk",
    entity_types=("person", "benunit", "household"),
)

US_CONFIG = CountryConfig(
    name="us",
    entity_types=(
        "person",
        "tax_unit",
        "spm_unit",
        "family",
        "marital_unit",
        "household",
    ),
)


def get_country_config(tax_benefit_model_name: str) -> CountryConfig:
    """Return the country configuration for a tax-benefit model name.

    NOTE(review): any name other than "policyengine_uk" falls through to the
    US config — confirm this default is intended for future model names.
    """
    if tax_benefit_model_name == "policyengine_uk":
        return UK_CONFIG
    return US_CONFIG


class HouseholdCalculator(Protocol):
    """Structural type for country-specific household calculators."""

    def __call__(
        self,
        household_data: dict[str, Any],
        year: int,
        policy_data: dict | None,
    ) -> dict: ...


def calculate_uk_household(
    household_data: dict[str, Any],
    year: int,
    policy_data: dict | None,
) -> dict:
    """Run the UK household calculation via the existing implementation.

    The import is local — presumably to avoid a circular import with the
    household API module; confirm before hoisting it to the top of the file.
    """
    from policyengine_api.api.household import _calculate_household_uk

    return _calculate_household_uk(
        people=household_data.get("people", []),
        benunit=_ensure_list(household_data.get("benunit")),
        household=_ensure_list(household_data.get("household")),
        year=year,
        policy_data=policy_data,
    )


def calculate_us_household(
    household_data: dict[str, Any],
    year: int,
    policy_data: dict | None,
) -> dict:
    """Run the US household calculation via the existing implementation."""
    from policyengine_api.api.household import _calculate_household_us

    return _calculate_household_us(
        people=household_data.get("people", []),
        marital_unit=_ensure_list(household_data.get("marital_unit")),
        family=_ensure_list(household_data.get("family")),
        spm_unit=_ensure_list(household_data.get("spm_unit")),
        tax_unit=_ensure_list(household_data.get("tax_unit")),
        household=_ensure_list(household_data.get("household")),
        year=year,
        policy_data=policy_data,
    )


def get_calculator(tax_benefit_model_name: str) -> HouseholdCalculator:
    """Return the calculator for a model name (non-UK defaults to US)."""
    if tax_benefit_model_name == "policyengine_uk":
        return calculate_uk_household
    return calculate_us_household


# =============================================================================
# Data Transformation Helpers
# =============================================================================


def _ensure_list(value: Any) -> list:
    """Return *value* as a list: None -> [], list -> itself, other -> [value]."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def _extract_policy_data(policy: Policy | None) -> dict | None:
    """Convert a Policy row into the calculator's parameter-override format.

    Returns None when there is no policy or no usable parameter values, so
    callers can treat "no reform" uniformly.
    """
    if not policy or not policy.parameter_values:
        return None

    policy_data = {}
    for pv in policy.parameter_values:
        # Skip orphaned parameter values with no parameter relation.
        if not pv.parameter:
            continue

        policy_data[pv.parameter.name] = {
            "value": _extract_value(pv.value_json),
            "start_date": _format_date(pv.start_date),
            "end_date": _format_date(pv.end_date),
        }

    return policy_data if policy_data else None


def _extract_value(value_json: Any) -> Any:
    """Unwrap {"value": ...} blobs; pass scalar values through unchanged."""
    if isinstance(value_json, dict):
        return value_json.get("value")
    return value_json


def _format_date(date: Any) -> str | None:
    """Render a date-like value as an ISO string (None stays None)."""
    if date is None:
        return None
    if hasattr(date, "isoformat"):
        return date.isoformat()
    return str(date)


# =============================================================================
# Impact Computation
# =============================================================================


def compute_variable_diff(baseline_val: Any, reform_val: Any) -> dict | None:
    """Diff a single variable; returns None unless both values are numeric.

    NOTE(review): bool is a subclass of int, so boolean variables are also
    diffed numerically (True -> 1). Confirm consumers expect that.
    """
    if not isinstance(baseline_val, (int, float)):
        return None
    if not isinstance(reform_val, (int, float)):
        return None

    return {
        "baseline": baseline_val,
        "reform": reform_val,
        "change": reform_val - baseline_val,
    }


def compute_entity_diff(baseline_entity: dict, reform_entity: dict) -> dict:
    """Compute per-variable diffs for a single entity instance.

    Variables absent from (or None in) the reform result are skipped.
    """
    entity_diff = {}

    for key, baseline_val in baseline_entity.items():
        reform_val = reform_entity.get(key)
        if reform_val is None:
            continue

        diff = compute_variable_diff(baseline_val, reform_val)
        if diff is not None:
            entity_diff[key] = diff

    return entity_diff


def compute_entity_list_diff(
    baseline_list: list[dict],
    reform_list: list[dict],
) -> list[dict]:
    """Diff entity instances pairwise.

    zip() silently truncates to the shorter list; baseline and reform are
    expected to describe the same household structure.
    """
    return [
        compute_entity_diff(b_entity, r_entity)
        for b_entity, r_entity in zip(baseline_list, reform_list)
    ]


def compute_household_impact(
    baseline_result: dict,
    reform_result: dict,
    config: CountryConfig,
) -> dict[str, Any]:
    """Compute baseline-vs-reform diffs for every entity type in *config*.

    Entity types missing from either result are omitted from the impact.
    """
    impact: dict[str, Any] = {}

    for entity in config.entity_types:
        baseline_entities = baseline_result.get(entity)
        reform_entities = reform_result.get(entity)

        if baseline_entities is None or reform_entities is None:
            continue

        impact[entity] = compute_entity_list_diff(baseline_entities, reform_entities)

    return impact


# =============================================================================
# Simulation Execution
# =============================================================================


def mark_simulation_running(simulation: Simulation, session: Session) -> None:
    """Persist the RUNNING status and start timestamp."""
    simulation.status = SimulationStatus.RUNNING
    simulation.started_at = datetime.now(timezone.utc)
    session.add(simulation)
    session.commit()


def mark_simulation_completed(
    simulation: Simulation,
    result: dict,
    session: Session,
) -> None:
    """Persist the COMPLETED status, result payload, and end timestamp."""
    simulation.household_result = result
    simulation.status = SimulationStatus.COMPLETED
    simulation.completed_at = datetime.now(timezone.utc)
    session.add(simulation)
    session.commit()


def mark_simulation_failed(
    simulation: Simulation,
    error: Exception,
    session: Session,
) -> None:
    """Persist the FAILED status, error message, and end timestamp."""
    simulation.status = SimulationStatus.FAILED
    simulation.error_message = str(error)
    simulation.completed_at = datetime.now(timezone.utc)
    session.add(simulation)
    session.commit()


def run_household_simulation(simulation_id: UUID, session: Session) -> None:
    """Run a single household simulation and store its result.

    FIX: policy extraction now happens inside the try block, after the
    simulation is marked RUNNING. Previously _load_policy_data was called
    before the try, so a failure there (e.g. malformed parameter values)
    propagated out of this function without recording a FAILED status —
    the pre-refactor implementation extracted policy data inside its
    try block.

    Raises ValueError if the simulation or its household does not exist.
    """
    simulation = _load_simulation(simulation_id, session)
    household = _load_household(simulation.household_id, session)

    mark_simulation_running(simulation, session)

    try:
        policy_data = _load_policy_data(simulation.policy_id, session)
        calculator = get_calculator(household.tax_benefit_model_name)
        result = calculator(household.household_data, household.year, policy_data)
        mark_simulation_completed(simulation, result, session)
    except Exception as e:
        # Record the failure on the simulation row rather than propagating,
        # so callers polling the simulation see the error message.
        mark_simulation_failed(simulation, e, session)


def _load_simulation(simulation_id: UUID, session: Session) -> Simulation:
    """Load a simulation or raise ValueError if it does not exist."""
    simulation = session.get(Simulation, simulation_id)
    if not simulation:
        raise ValueError(f"Simulation {simulation_id} not found")
    return simulation


def _load_household(household_id: UUID | None, session: Session) -> Household:
    """Load a household or raise ValueError if unset or missing."""
    if not household_id:
        raise ValueError("Simulation has no household_id")

    household = session.get(Household, household_id)
    if not household:
        raise ValueError(f"Household {household_id} not found")
    return household


def _load_policy_data(policy_id: UUID | None, session: Session) -> dict | None:
    """Load and extract policy data if policy_id is set."""
    if not policy_id:
        return None

    policy = session.get(Policy, policy_id)
    return _extract_policy_data(policy)


#
# =============================================================================
# Report Orchestration
# =============================================================================


def trigger_household_report(report_id: UUID, session: Session) -> None:
    """Trigger household simulation(s) for a report.

    Runs synchronously: baseline first, then reform (if any). Any exception
    is recorded on the report as FAILED rather than propagated.

    Raises ValueError if the report does not exist.
    """
    report = _load_report(report_id, session)
    _mark_report_running(report, session)

    try:
        _run_report_simulations(report, session)
        _mark_report_completed(report, session)
    except Exception as e:
        _mark_report_failed(report, e, session)


def _load_report(report_id: UUID, session: Session) -> Report:
    """Load a report or raise ValueError if it does not exist."""
    report = session.get(Report, report_id)
    if not report:
        raise ValueError(f"Report {report_id} not found")
    return report


def _mark_report_running(report: Report, session: Session) -> None:
    """Persist the RUNNING status on the report."""
    report.status = ReportStatus.RUNNING
    session.add(report)
    session.commit()


def _mark_report_completed(report: Report, session: Session) -> None:
    """Persist the COMPLETED status on the report."""
    report.status = ReportStatus.COMPLETED
    session.add(report)
    session.commit()


def _mark_report_failed(report: Report, error: Exception, session: Session) -> None:
    """Persist the FAILED status and error message on the report."""
    report.status = ReportStatus.FAILED
    report.error_message = str(error)
    session.add(report)
    session.commit()


def _run_report_simulations(report: Report, session: Session) -> None:
    """Run all pending simulations for a report (baseline, then reform)."""
    _run_simulation_if_pending(report.baseline_simulation_id, session)

    if report.reform_simulation_id:
        _run_simulation_if_pending(report.reform_simulation_id, session)


def _run_simulation_if_pending(simulation_id: UUID | None, session: Session) -> None:
    """Run simulation if it exists and is pending.

    Non-PENDING simulations are skipped — this is how deduplicated reports
    avoid recomputing results that already exist.
    """
    if not simulation_id:
        return

    simulation = session.get(Simulation, simulation_id)
    if simulation and simulation.status == SimulationStatus.PENDING:
        run_household_simulation(simulation.id, session)


# =============================================================================
# Request/Response Schemas
# =============================================================================


class HouseholdImpactRequest(BaseModel):
    """Request for household impact analysis."""

    household_id: UUID = Field(description="ID of the household to analyze")
    policy_id: UUID | None = Field(
        default=None,
        description="Reform policy ID. If None, runs single calculation under current law.",
    )
    dynamic_id: UUID | None = Field(
        default=None,
        description="Optional behavioural response specification ID",
    )


class HouseholdSimulationInfo(BaseModel):
    """Info about a household simulation."""

    id: UUID
    status: SimulationStatus
    error_message: str | None = None


class HouseholdImpactResponse(BaseModel):
    """Response for household impact analysis.

    reform_simulation/reform_result/impact are only populated for
    comparison reports (policy_id was provided).
    """

    report_id: UUID
    report_type: str
    status: ReportStatus
    baseline_simulation: HouseholdSimulationInfo | None = None
    reform_simulation: HouseholdSimulationInfo | None = None
    baseline_result: dict | None = None
    reform_result: dict | None = None
    impact: dict | None = None
    error_message: str | None = None


# =============================================================================
# Response Building
# =============================================================================


def build_simulation_info(simulation: Simulation | None) -> HouseholdSimulationInfo | None:
    """Build simulation info from a simulation (None passes through)."""
    if not simulation:
        return None

    return HouseholdSimulationInfo(
        id=simulation.id,
        status=simulation.status,
        error_message=simulation.error_message,
    )


def build_household_response(
    report: Report,
    baseline_sim: Simulation,
    reform_sim: Simulation | None,
    session: Session,
) -> HouseholdImpactResponse:
    """Build response including computed impact for comparisons."""
    baseline_result = baseline_sim.household_result
    reform_result = reform_sim.household_result if reform_sim else None

    impact = _compute_impact_if_comparison(
        baseline_sim, reform_sim, baseline_result, reform_result, session
    )

    return HouseholdImpactResponse(
        report_id=report.id,
        # Older reports may predate report_type — default to single-run.
        report_type=report.report_type or "household_single",
        status=report.status,
        baseline_simulation=build_simulation_info(baseline_sim),
        reform_simulation=build_simulation_info(reform_sim),
        baseline_result=baseline_result,
        reform_result=reform_result,
        impact=impact,
        error_message=report.error_message,
    )


def _compute_impact_if_comparison(
    baseline_sim: Simulation,
    reform_sim: Simulation | None,
    baseline_result: dict | None,
    reform_result: dict | None,
    session: Session,
) -> dict | None:
    """Compute impact only if this is a comparison with both results present."""
    if not reform_sim:
        return None
    if not baseline_result or not reform_result:
        return None

    # Country (and thus entity types) is derived from the stored household.
    household = session.get(Household, baseline_sim.household_id)
    if not household:
        return None

    config = get_country_config(household.tax_benefit_model_name)
    return compute_household_impact(baseline_result, reform_result, config)


# =============================================================================
# Validation Helpers
# =============================================================================


def validate_household_exists(household_id: UUID, session: Session) -> Household:
    """Validate household exists and return it; raises 404 otherwise."""
    household = session.get(Household, household_id)
    if not household:
        raise HTTPException(
            status_code=404,
            detail=f"Household {household_id} not found",
        )
    return household


def validate_policy_exists(policy_id: UUID | None, session: Session) -> None:
    """Validate policy exists if provided; raises 404 otherwise."""
    if not policy_id:
        return

    policy = session.get(Policy, policy_id)
    if not policy:
        raise HTTPException(
            status_code=404,
            detail=f"Policy {policy_id} not found",
        )


# =============================================================================
# Endpoints
# =============================================================================


@router.post("/household-impact", response_model=HouseholdImpactResponse)
def household_impact(
    request: HouseholdImpactRequest,
    session: Session = Depends(get_session),
) -> HouseholdImpactResponse:
    """Run household impact analysis.

    If policy_id is None: single run under current law.
    If policy_id is set: comparison (baseline vs reform).

    This is a synchronous operation for household calculations.
    Simulations and reports are deduplicated: repeating the same request
    returns the existing report rather than recomputing.
    """
    household = validate_household_exists(request.household_id, session)
    validate_policy_exists(request.policy_id, session)

    model_version = _get_model_version(household.tax_benefit_model_name, session)

    baseline_sim = _create_baseline_simulation(
        household, model_version.id, request.dynamic_id, session
    )
    # None when no policy_id was provided (single-run report).
    reform_sim = _create_reform_simulation(
        household, model_version.id, request.policy_id, request.dynamic_id, session
    )

    report_type = "household_comparison" if request.policy_id else "household_single"
    report = _get_or_create_report(
        baseline_sim_id=baseline_sim.id,
        reform_sim_id=reform_sim.id if reform_sim else None,
        label=f"Household impact: {household.tax_benefit_model_name}",
        report_type=report_type,
        session=session,
    )

    # Only a freshly-created (PENDING) report triggers computation; a
    # deduplicated report already has (or is computing) its results.
    if report.status == ReportStatus.PENDING:
        with logfire.span("trigger_household_report", job_id=str(report.id)):
            trigger_household_report(report.id, session)

    return build_household_response(report, baseline_sim, reform_sim, session)


@router.get("/household-impact/{report_id}", response_model=HouseholdImpactResponse)
def get_household_impact(
    report_id: UUID,
    session: Session = Depends(get_session),
) -> HouseholdImpactResponse:
    """Get household impact analysis status and results."""
    report = session.get(Report, report_id)
    if not report:
        raise
HTTPException(status_code=404, detail=f"Report {report_id} not found") + + if not report.baseline_simulation_id: + raise HTTPException( + status_code=500, + detail="Report missing baseline simulation ID", + ) + + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + if not baseline_sim: + raise HTTPException(status_code=500, detail="Baseline simulation data missing") + + reform_sim = None + if report.reform_simulation_id: + reform_sim = session.get(Simulation, report.reform_simulation_id) + + return build_household_response(report, baseline_sim, reform_sim, session) + + +# ============================================================================= +# Simulation Creation Helpers +# ============================================================================= + + +def _create_baseline_simulation( + household: Household, + model_version_id: UUID, + dynamic_id: UUID | None, + session: Session, +) -> Simulation: + """Create baseline simulation (current law, no policy).""" + return _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version_id, + policy_id=None, + dynamic_id=dynamic_id, + session=session, + household_id=household.id, + ) + + +def _create_reform_simulation( + household: Household, + model_version_id: UUID, + policy_id: UUID | None, + dynamic_id: UUID | None, + session: Session, +) -> Simulation | None: + """Create reform simulation if policy_id is provided.""" + if not policy_id: + return None + + return _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version_id, + policy_id=policy_id, + dynamic_id=dynamic_id, + session=session, + household_id=household.id, + ) From 78f87804a30dd3c56419fb4c80842f153d1cceac Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 3 Feb 2026 22:59:28 +0300 Subject: [PATCH 05/19] test: Add tests --- test_fixtures/fixtures_analysis.py | 164 --------- test_fixtures/fixtures_household_analysis.py | 366 
+++++++++++++++++++ tests/test_analysis_household_impact.py | 245 ++++++++++++- 3 files changed, 603 insertions(+), 172 deletions(-) delete mode 100644 test_fixtures/fixtures_analysis.py create mode 100644 test_fixtures/fixtures_household_analysis.py diff --git a/test_fixtures/fixtures_analysis.py b/test_fixtures/fixtures_analysis.py deleted file mode 100644 index d56b702..0000000 --- a/test_fixtures/fixtures_analysis.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Fixtures and helpers for analysis endpoint tests.""" - -from uuid import UUID - -from sqlmodel import Session - -from policyengine_api.models import ( - Household, - Parameter, - ParameterValue, - Policy, - TaxBenefitModel, - TaxBenefitModelVersion, -) - - -def create_tax_benefit_model( - session: Session, - name: str = "policyengine-uk", - description: str = "UK tax benefit model", -) -> TaxBenefitModel: - """Create and persist a TaxBenefitModel record.""" - model = TaxBenefitModel( - name=name, - description=description, - ) - session.add(model) - session.commit() - session.refresh(model) - return model - - -def create_model_version( - session: Session, - model_id: UUID, - version: str = "1.0.0", - description: str = "Test version", -) -> TaxBenefitModelVersion: - """Create and persist a TaxBenefitModelVersion record.""" - model_version = TaxBenefitModelVersion( - model_id=model_id, - version=version, - description=description, - ) - session.add(model_version) - session.commit() - session.refresh(model_version) - return model_version - - -def create_parameter( - session: Session, - model_version_id: UUID, - name: str = "test_parameter", - label: str = "Test Parameter", - description: str = "A test parameter", -) -> Parameter: - """Create and persist a Parameter record.""" - param = Parameter( - tax_benefit_model_version_id=model_version_id, - name=name, - label=label, - description=description, - ) - session.add(param) - session.commit() - session.refresh(param) - return param - - -def create_policy( - session: 
Session, - model_version_id: UUID, - name: str = "Test Policy", - description: str = "A test policy", -) -> Policy: - """Create and persist a Policy record.""" - policy = Policy( - tax_benefit_model_version_id=model_version_id, - name=name, - description=description, - ) - session.add(policy) - session.commit() - session.refresh(policy) - return policy - - -def create_policy_with_parameter_value( - session: Session, - model_version_id: UUID, - parameter_id: UUID, - value: float, - name: str = "Test Policy", -) -> Policy: - """Create a Policy with an associated ParameterValue.""" - policy = create_policy(session, model_version_id, name=name) - - param_value = ParameterValue( - policy_id=policy.id, - parameter_id=parameter_id, - value_json={"value": value}, - ) - session.add(param_value) - session.commit() - session.refresh(policy) - return policy - - -def create_household_for_analysis( - session: Session, - tax_benefit_model_name: str = "policyengine_uk", - year: int = 2024, - label: str = "Test household for analysis", -) -> Household: - """Create a household suitable for analysis testing.""" - if tax_benefit_model_name == "policyengine_uk": - household_data = { - "people": [{"age": 30, "employment_income": 35000}], - "benunit": {}, - "household": {"region": "LONDON"}, - } - else: - household_data = { - "people": [{"age": 30, "employment_income": 50000}], - "tax_unit": {"state_code": "CA"}, - "family": {}, - "spm_unit": {}, - "marital_unit": {}, - "household": {"state_fips": 6}, - } - - record = Household( - tax_benefit_model_name=tax_benefit_model_name, - year=year, - label=label, - household_data=household_data, - ) - session.add(record) - session.commit() - session.refresh(record) - return record - - -def setup_uk_model_and_version( - session: Session, -) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: - """Create UK model and version for testing.""" - model = create_tax_benefit_model( - session, name="policyengine-uk", description="UK model" - ) - version = 
create_model_version(session, model.id) - return model, version - - -def setup_us_model_and_version( - session: Session, -) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: - """Create US model and version for testing.""" - model = create_tax_benefit_model( - session, name="policyengine-us", description="US model" - ) - version = create_model_version(session, model.id) - return model, version diff --git a/test_fixtures/fixtures_household_analysis.py b/test_fixtures/fixtures_household_analysis.py new file mode 100644 index 0000000..573930a --- /dev/null +++ b/test_fixtures/fixtures_household_analysis.py @@ -0,0 +1,366 @@ +"""Fixtures and helpers for household analysis endpoint tests.""" + +from typing import Any +from unittest.mock import patch +from uuid import UUID + +import pytest +from sqlmodel import Session + +from policyengine_api.models import ( + Household, + Parameter, + ParameterValue, + Policy, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +# ============================================================================= +# Sample Calculation Results +# ============================================================================= + + +SAMPLE_UK_BASELINE_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 35000.0, + "income_tax": 4500.0, + "national_insurance": 2800.0, + "net_income": 27700.0, + } + ], + "benunit": [ + { + "universal_credit": 0.0, + "child_benefit": 0.0, + } + ], + "household": [ + { + "region": "LONDON", + "council_tax": 1500.0, + } + ], +} + + +SAMPLE_UK_REFORM_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 35000.0, + "income_tax": 4000.0, + "national_insurance": 2800.0, + "net_income": 28200.0, + } + ], + "benunit": [ + { + "universal_credit": 0.0, + "child_benefit": 0.0, + } + ], + "household": [ + { + "region": "LONDON", + "council_tax": 1500.0, + } + ], +} + + +SAMPLE_US_BASELINE_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 
50000.0, + "income_tax": 6000.0, + "fica": 3825.0, + "net_income": 40175.0, + } + ], + "tax_unit": [ + { + "state_code": "CA", + "state_income_tax": 2500.0, + } + ], + "spm_unit": [{"snap": 0.0}], + "family": [{}], + "marital_unit": [{}], + "household": [{"state_fips": 6}], +} + + +SAMPLE_US_REFORM_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 50000.0, + "income_tax": 5500.0, + "fica": 3825.0, + "net_income": 40675.0, + } + ], + "tax_unit": [ + { + "state_code": "CA", + "state_income_tax": 2500.0, + } + ], + "spm_unit": [{"snap": 0.0}], + "family": [{}], + "marital_unit": [{}], + "household": [{"state_fips": 6}], +} + + +# ============================================================================= +# Mock Calculator Functions +# ============================================================================= + + +def mock_calculate_uk_household( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock UK calculator that returns sample results.""" + if policy_data: + return SAMPLE_UK_REFORM_RESULT + return SAMPLE_UK_BASELINE_RESULT + + +def mock_calculate_us_household( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock US calculator that returns sample results.""" + if policy_data: + return SAMPLE_US_REFORM_RESULT + return SAMPLE_US_BASELINE_RESULT + + +def mock_calculate_household_failing( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock calculator that raises an exception.""" + raise RuntimeError("Calculation failed") + + +# ============================================================================= +# Pytest Fixtures for Mocking +# ============================================================================= + + +@pytest.fixture +def mock_uk_calculator(): + """Fixture that patches UK calculator with mock.""" + with patch( + "policyengine_api.api.household_analysis.calculate_uk_household", + 
side_effect=mock_calculate_uk_household, + ) as mock: + yield mock + + +@pytest.fixture +def mock_us_calculator(): + """Fixture that patches US calculator with mock.""" + with patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_us_household, + ) as mock: + yield mock + + +@pytest.fixture +def mock_calculators(): + """Fixture that patches both UK and US calculators.""" + with ( + patch( + "policyengine_api.api.household_analysis.calculate_uk_household", + side_effect=mock_calculate_uk_household, + ) as uk_mock, + patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_us_household, + ) as us_mock, + ): + yield {"uk": uk_mock, "us": us_mock} + + +@pytest.fixture +def mock_failing_calculator(): + """Fixture that patches calculators to fail.""" + with ( + patch( + "policyengine_api.api.household_analysis.calculate_uk_household", + side_effect=mock_calculate_household_failing, + ), + patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_household_failing, + ), + ): + yield + + +# ============================================================================= +# Database Factory Functions +# ============================================================================= + + +def create_tax_benefit_model( + session: Session, + name: str = "policyengine-uk", + description: str = "UK tax benefit model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + model = TaxBenefitModel( + name=name, + description=description, + ) + session.add(model) + session.commit() + session.refresh(model) + return model + + +def create_model_version( + session: Session, + model_id: UUID, + version: str = "1.0.0", + description: str = "Test version", +) -> TaxBenefitModelVersion: + """Create and persist a TaxBenefitModelVersion record.""" + model_version = TaxBenefitModelVersion( + model_id=model_id, + version=version, + 
description=description, + ) + session.add(model_version) + session.commit() + session.refresh(model_version) + return model_version + + +def create_parameter( + session: Session, + model_version_id: UUID, + name: str = "test_parameter", + label: str = "Test Parameter", + description: str = "A test parameter", +) -> Parameter: + """Create and persist a Parameter record.""" + param = Parameter( + tax_benefit_model_version_id=model_version_id, + name=name, + label=label, + description=description, + ) + session.add(param) + session.commit() + session.refresh(param) + return param + + +def create_policy( + session: Session, + model_version_id: UUID, + name: str = "Test Policy", + description: str = "A test policy", +) -> Policy: + """Create and persist a Policy record.""" + policy = Policy( + tax_benefit_model_version_id=model_version_id, + name=name, + description=description, + ) + session.add(policy) + session.commit() + session.refresh(policy) + return policy + + +def create_policy_with_parameter_value( + session: Session, + model_version_id: UUID, + parameter_id: UUID, + value: float, + name: str = "Test Policy", +) -> Policy: + """Create a Policy with an associated ParameterValue.""" + policy = create_policy(session, model_version_id, name=name) + + param_value = ParameterValue( + policy_id=policy.id, + parameter_id=parameter_id, + value_json={"value": value}, + ) + session.add(param_value) + session.commit() + session.refresh(policy) + return policy + + +def create_household_for_analysis( + session: Session, + tax_benefit_model_name: str = "policyengine_uk", + year: int = 2024, + label: str = "Test household for analysis", +) -> Household: + """Create a household suitable for analysis testing.""" + if tax_benefit_model_name == "policyengine_uk": + household_data = { + "people": [{"age": 30, "employment_income": 35000}], + "benunit": {}, + "household": {"region": "LONDON"}, + } + else: + household_data = { + "people": [{"age": 30, "employment_income": 50000}], + 
"tax_unit": {"state_code": "CA"}, + "family": {}, + "spm_unit": {}, + "marital_unit": {}, + "household": {"state_fips": 6}, + } + + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data=household_data, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def setup_uk_model_and_version( + session: Session, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create UK model and version for testing.""" + model = create_tax_benefit_model( + session, name="policyengine-uk", description="UK model" + ) + version = create_model_version(session, model.id) + return model, version + + +def setup_us_model_and_version( + session: Session, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create US model and version for testing.""" + model = create_tax_benefit_model( + session, name="policyengine-us", description="US model" + ) + version = create_model_version(session, model.id) + return model, version diff --git a/tests/test_analysis_household_impact.py b/tests/test_analysis_household_impact.py index e8a614b..23465c7 100644 --- a/tests/test_analysis_household_impact.py +++ b/tests/test_analysis_household_impact.py @@ -1,18 +1,247 @@ """Tests for household impact analysis endpoints.""" -from uuid import uuid4 +from datetime import date +from uuid import UUID, uuid4 import pytest -from test_fixtures.fixtures_analysis import ( +from test_fixtures.fixtures_household_analysis import ( + SAMPLE_UK_BASELINE_RESULT, + SAMPLE_UK_REFORM_RESULT, + SAMPLE_US_BASELINE_RESULT, + SAMPLE_US_REFORM_RESULT, create_household_for_analysis, create_policy, setup_uk_model_and_version, setup_us_model_and_version, ) +from policyengine_api.api.household_analysis import ( + UK_CONFIG, + US_CONFIG, + _ensure_list, + _extract_value, + _format_date, + compute_entity_diff, + compute_entity_list_diff, + compute_household_impact, + compute_variable_diff, + get_calculator, + get_country_config, +) 
from policyengine_api.models import Report, ReportStatus, Simulation, SimulationType +# --------------------------------------------------------------------------- +# Unit tests for helper functions +# --------------------------------------------------------------------------- + + +class TestEnsureList: + """Tests for _ensure_list helper.""" + + def test_none_returns_empty_list(self): + assert _ensure_list(None) == [] + + def test_list_returns_same_list(self): + input_list = [1, 2, 3] + assert _ensure_list(input_list) == input_list + + def test_dict_wrapped_in_list(self): + input_dict = {"key": "value"} + result = _ensure_list(input_dict) + assert result == [input_dict] + + def test_empty_list_returns_empty_list(self): + assert _ensure_list([]) == [] + + +class TestExtractValue: + """Tests for _extract_value helper.""" + + def test_dict_with_value_key(self): + assert _extract_value({"value": 100}) == 100 + + def test_dict_without_value_key(self): + assert _extract_value({"other": 100}) is None + + def test_non_dict_returns_as_is(self): + assert _extract_value(100) == 100 + assert _extract_value("string") == "string" + assert _extract_value([1, 2]) == [1, 2] + + +class TestFormatDate: + """Tests for _format_date helper.""" + + def test_none_returns_none(self): + assert _format_date(None) is None + + def test_date_object_formatted(self): + d = date(2024, 1, 15) + assert _format_date(d) == "2024-01-15" + + def test_string_returns_string(self): + assert _format_date("2024-01-15") == "2024-01-15" + + +class TestComputeVariableDiff: + """Tests for compute_variable_diff helper.""" + + def test_numeric_values_return_diff(self): + result = compute_variable_diff(100, 150) + assert result == {"baseline": 100, "reform": 150, "change": 50} + + def test_negative_change(self): + result = compute_variable_diff(150, 100) + assert result == {"baseline": 150, "reform": 100, "change": -50} + + def test_float_values(self): + result = compute_variable_diff(100.5, 200.5) + assert result 
== {"baseline": 100.5, "reform": 200.5, "change": 100.0} + + def test_non_numeric_baseline_returns_none(self): + assert compute_variable_diff("string", 100) is None + + def test_non_numeric_reform_returns_none(self): + assert compute_variable_diff(100, "string") is None + + def test_both_non_numeric_returns_none(self): + assert compute_variable_diff("a", "b") is None + + +class TestComputeEntityDiff: + """Tests for compute_entity_diff helper.""" + + def test_computes_diff_for_numeric_keys(self): + baseline = {"income": 1000, "tax": 200, "name": "John"} + reform = {"income": 1000, "tax": 150, "name": "John"} + result = compute_entity_diff(baseline, reform) + + assert "income" in result + assert result["income"]["change"] == 0 + assert "tax" in result + assert result["tax"]["change"] == -50 + assert "name" not in result + + def test_missing_key_in_reform_skipped(self): + baseline = {"income": 1000, "tax": 200} + reform = {"income": 1000} + result = compute_entity_diff(baseline, reform) + + assert "income" in result + assert "tax" not in result + + def test_empty_entities(self): + assert compute_entity_diff({}, {}) == {} + + +class TestComputeEntityListDiff: + """Tests for compute_entity_list_diff helper.""" + + def test_computes_diff_for_each_pair(self): + baseline_list = [{"income": 100}, {"income": 200}] + reform_list = [{"income": 120}, {"income": 180}] + result = compute_entity_list_diff(baseline_list, reform_list) + + assert len(result) == 2 + assert result[0]["income"]["change"] == 20 + assert result[1]["income"]["change"] == -20 + + def test_empty_lists(self): + assert compute_entity_list_diff([], []) == [] + + +class TestComputeHouseholdImpact: + """Tests for compute_household_impact helper.""" + + def test_uk_household_impact(self): + result = compute_household_impact( + SAMPLE_UK_BASELINE_RESULT, + SAMPLE_UK_REFORM_RESULT, + UK_CONFIG, + ) + + assert "person" in result + assert "benunit" in result + assert "household" in result + + # Check person income_tax 
changed + person_diff = result["person"][0] + assert "income_tax" in person_diff + assert person_diff["income_tax"]["baseline"] == 4500.0 + assert person_diff["income_tax"]["reform"] == 4000.0 + assert person_diff["income_tax"]["change"] == -500.0 + + def test_us_household_impact(self): + result = compute_household_impact( + SAMPLE_US_BASELINE_RESULT, + SAMPLE_US_REFORM_RESULT, + US_CONFIG, + ) + + assert "person" in result + assert "tax_unit" in result + assert "spm_unit" in result + assert "family" in result + assert "marital_unit" in result + assert "household" in result + + # Check person income_tax changed + person_diff = result["person"][0] + assert person_diff["income_tax"]["change"] == -500.0 + + def test_missing_entity_skipped(self): + baseline = {"person": [{"income": 100}]} + reform = {"person": [{"income": 120}]} + result = compute_household_impact(baseline, reform, UK_CONFIG) + + assert "person" in result + assert "benunit" not in result + assert "household" not in result + + +class TestGetCountryConfig: + """Tests for get_country_config helper.""" + + def test_uk_model_returns_uk_config(self): + config = get_country_config("policyengine_uk") + assert config == UK_CONFIG + assert config.name == "uk" + assert "benunit" in config.entity_types + + def test_us_model_returns_us_config(self): + config = get_country_config("policyengine_us") + assert config == US_CONFIG + assert config.name == "us" + assert "tax_unit" in config.entity_types + + def test_unknown_model_defaults_to_us(self): + config = get_country_config("unknown_model") + assert config == US_CONFIG + + +class TestGetCalculator: + """Tests for get_calculator helper.""" + + def test_uk_model_returns_uk_calculator(self): + from policyengine_api.api.household_analysis import calculate_uk_household + + calc = get_calculator("policyengine_uk") + assert calc == calculate_uk_household + + def test_us_model_returns_us_calculator(self): + from policyengine_api.api.household_analysis import 
calculate_us_household + + calc = get_calculator("policyengine_us") + assert calc == calculate_us_household + + def test_unknown_model_defaults_to_us(self): + from policyengine_api.api.household_analysis import calculate_us_household + + calc = get_calculator("unknown_model") + assert calc == calculate_us_household + + # --------------------------------------------------------------------------- # Validation tests (no database required beyond session fixture) # --------------------------------------------------------------------------- @@ -142,8 +371,8 @@ def test_simulation_type_is_household(self, client, session): ) data = response.json() - # Check simulation in database - sim_id = data["baseline_simulation"]["id"] + # Check simulation in database (convert string to UUID for query) + sim_id = UUID(data["baseline_simulation"]["id"]) sim = session.get(Simulation, sim_id) assert sim is not None assert sim.simulation_type == SimulationType.HOUSEHOLD @@ -165,11 +394,11 @@ def test_report_links_simulations(self, client, session): ) data = response.json() - # Check report in database - report = session.get(Report, data["report_id"]) + # Check report in database (convert string to UUID for query) + report = session.get(Report, UUID(data["report_id"])) assert report is not None - assert report.baseline_simulation_id == data["baseline_simulation"]["id"] - assert report.reform_simulation_id == data["reform_simulation"]["id"] + assert report.baseline_simulation_id == UUID(data["baseline_simulation"]["id"]) + assert report.reform_simulation_id == UUID(data["reform_simulation"]["id"]) assert report.report_type == "household_comparison" From 6f90fbef77b57c5b669fff0ec3517b4c8760ed6b Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 4 Feb 2026 03:12:48 +0300 Subject: [PATCH 06/19] feat: Use Alembic for db migrations --- .claude/skills/database-migrations.md | 301 +++++++++ CLAUDE.md | 16 +- alembic.ini | 145 +++++ alembic/README | 1 + alembic/env.py | 87 +++ 
alembic/script.py.mako | 28 + ...7ac554f4aa_add_parameter_values_indexes.py | 52 ++ .../20260204_d6e30d3b834d_initial_schema.py | 599 ++++++++++++++++++ pyproject.toml | 1 + scripts/init.py | 155 +++-- scripts/seed_nevada.py | 128 ++++ src/policyengine_api/config/settings.py | 17 +- supabase/.temp/cli-latest | 2 +- ...229000000_add_parameter_values_indexes.sql | 0 .../20260103000000_add_poverty_inequality.sql | 0 .../20260111000000_add_aggregate_status.sql | 0 .../20260203000000_create_households.sql | 0 ...001_create_user_household_associations.sql | 0 ...203000002_simulation_household_support.sql | 0 uv.lock | 28 + 20 files changed, 1514 insertions(+), 46 deletions(-) create mode 100644 .claude/skills/database-migrations.md create mode 100644 alembic.ini create mode 100644 alembic/README create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/20260204_a17ac554f4aa_add_parameter_values_indexes.py create mode 100644 alembic/versions/20260204_d6e30d3b834d_initial_schema.py create mode 100644 scripts/seed_nevada.py rename supabase/{migrations => migrations_archived}/20251229000000_add_parameter_values_indexes.sql (100%) rename supabase/{migrations => migrations_archived}/20260103000000_add_poverty_inequality.sql (100%) rename supabase/{migrations => migrations_archived}/20260111000000_add_aggregate_status.sql (100%) rename supabase/{migrations => migrations_archived}/20260203000000_create_households.sql (100%) rename supabase/{migrations => migrations_archived}/20260203000001_create_user_household_associations.sql (100%) rename supabase/{migrations => migrations_archived}/20260203000002_simulation_household_support.sql (100%) diff --git a/.claude/skills/database-migrations.md b/.claude/skills/database-migrations.md new file mode 100644 index 0000000..fedbef8 --- /dev/null +++ b/.claude/skills/database-migrations.md @@ -0,0 +1,301 @@ +# Database Migration Guidelines + +## Overview + +This project uses **Alembic** 
for database migrations with **SQLModel** models. Alembic is the industry-standard migration tool for SQLAlchemy/SQLModel projects. + +**CRITICAL**: SQL migrations are the single source of truth for database schema. All table creation and schema changes MUST go through Alembic migrations. + +## Architecture + +``` +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ SQLModel Models (src/policyengine_api/models/) │ +│ - Define Python classes │ +│ - Used for ORM queries │ +│ - NOT the source of truth for schema │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + │ alembic revision --autogenerate + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ Alembic Migrations (alembic/versions/) │ +│ - Create/alter tables │ +│ - Add indexes, constraints │ +│ - SOURCE OF TRUTH for schema │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ + │ + │ alembic upgrade head + ā–¼ +ā”Œā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā” +│ PostgreSQL Database (Supabase) │ +│ - Actual schema │ +│ - Tracked by alembic_version table │ +ā””ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜ +``` + +## Essential Rules + +### 1. 
NEVER use SQLModel.metadata.create_all() for schema creation + +The old pattern of using `SQLModel.metadata.create_all()` is deprecated. All tables are created via Alembic migrations. + +### 2. Every schema change requires a migration + +When you modify a SQLModel model (add column, change type, add index), you MUST: +1. Update the model in `src/policyengine_api/models/` +2. Generate a migration: `uv run alembic revision --autogenerate -m "Description"` +3. **Read and verify the generated migration** (see below) +4. Apply it: `uv run alembic upgrade head` + +### 3. ALWAYS verify auto-generated migrations before applying + +**This is critical for AI agents.** After running `alembic revision --autogenerate`, you MUST: + +1. **Read the generated migration file** in `alembic/versions/` +2. **Verify the `upgrade()` function** contains the expected changes: + - Correct table/column names + - Correct column types (e.g., `sa.String()`, `sa.Uuid()`, `sa.Integer()`) + - Proper foreign key references + - Appropriate nullable settings +3. **Verify the `downgrade()` function** properly reverses the changes +4. **Check for Alembic autogenerate limitations:** + - It may miss renamed columns (shows as drop + add instead) + - It may not detect some index changes + - It doesn't handle data migrations +5. **Edit the migration if needed** before applying + +Example verification: +```python +# Generated migration - verify this looks correct: +def upgrade() -> None: + op.add_column('users', sa.Column('phone', sa.String(), nullable=True)) + +def downgrade() -> None: + op.drop_column('users', 'phone') +``` + +**Never blindly apply a migration without reading it first.** + +### 4. Migrations must be self-contained + +Each migration should: +- Create tables it needs (never assume they exist from Python) +- Include both `upgrade()` and `downgrade()` functions +- Be idempotent where possible (use `IF NOT EXISTS` patterns) + +### 5. 
Never use conditional logic based on table existence + +Migrations should NOT check if tables exist. Instead: +- Ensure migrations run in the correct order (use `down_revision`) +- The initial migration creates all base tables +- Subsequent migrations build on that foundation + +## Common Commands + +```bash +# Apply all pending migrations +uv run alembic upgrade head + +# Generate migration from model changes +uv run alembic revision --autogenerate -m "Add users email index" + +# Create empty migration (for manual SQL) +uv run alembic revision -m "Add custom index" + +# Check current migration state +uv run alembic current + +# Show migration history +uv run alembic history + +# Downgrade one revision +uv run alembic downgrade -1 + +# Downgrade to specific revision +uv run alembic downgrade +``` + +## Local Development Workflow + +```bash +# 1. Start Supabase +supabase start + +# 2. Initialize database (runs migrations + applies RLS policies) +uv run python scripts/init.py + +# 3. Seed data +uv run python scripts/seed.py +``` + +### Reset database (DESTRUCTIVE) + +```bash +uv run python scripts/init.py --reset +``` + +## Adding a New Model + +1. Create the model in `src/policyengine_api/models/` + +```python +# src/policyengine_api/models/my_model.py +from sqlmodel import SQLModel, Field +from uuid import UUID, uuid4 + +class MyModel(SQLModel, table=True): + __tablename__ = "my_models" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + name: str +``` + +2. Export in `__init__.py`: + +```python +# src/policyengine_api/models/__init__.py +from .my_model import MyModel +``` + +3. Generate migration: + +```bash +uv run alembic revision --autogenerate -m "Add my_models table" +``` + +4. Review the generated migration in `alembic/versions/` + +5. Apply the migration: + +```bash +uv run alembic upgrade head +``` + +6. Update `scripts/init.py` to include the table in RLS policies if needed. + +## Adding an Index + +1. 
Generate a migration: + +```bash +uv run alembic revision -m "Add index on users.email" +``` + +2. Edit the migration: + +```python +def upgrade() -> None: + op.create_index("idx_users_email", "users", ["email"]) + +def downgrade() -> None: + op.drop_index("idx_users_email", "users") +``` + +3. Apply: + +```bash +uv run alembic upgrade head +``` + +## Production Considerations + +### Applying migrations to production + +1. Migrations are automatically applied when deploying +2. Always test migrations locally first +3. For data migrations, consider running during low-traffic periods + +### Transitioning production from old system to Alembic + +Production databases that were created before Alembic (using the old `SQLModel.metadata.create_all()` approach or raw Supabase migrations) need special handling. Running `alembic upgrade head` would fail because the tables already exist. + +**The solution: `alembic stamp`** + +The `alembic stamp` command marks a migration as "already applied" without actually running it. This tells Alembic "the database is already at this state, start tracking from here." + +**How it works:** + +1. `alembic stamp ` inserts a row into the `alembic_version` table with the specified revision ID +2. Alembic now thinks that migration (and all migrations before it) have been applied +3. Future migrations will run normally starting from that point + +**Step-by-step production transition:** + +```bash +# 1. Connect to production database +# (set SUPABASE_DB_URL or other connection env vars) + +# 2. Check if alembic_version table exists +# If not, Alembic will create it automatically + +# 3. Verify production schema matches the initial migration +# Compare tables/columns in production against alembic/versions/20260204_d6e30d3b834d_initial_schema.py + +# 4. Stamp the initial migration as applied +uv run alembic stamp d6e30d3b834d + +# 5. 
If production also has the indexes from the second migration, stamp that too +uv run alembic stamp a17ac554f4aa + +# 6. Verify the stamp worked +uv run alembic current +# Should show: a17ac554f4aa (head) + +# 7. From now on, new migrations will apply normally +uv run alembic upgrade head +``` + +**Handling partially applied migrations:** + +If production has some but not all changes from a migration: + +1. Manually apply the missing changes via SQL +2. Then stamp that migration as complete +3. Or: create a new migration that only adds the missing pieces + +**After stamping:** + +- All future schema changes go through Alembic migrations +- Developers generate migrations with `alembic revision --autogenerate` +- Deployments run `alembic upgrade head` to apply pending migrations +- The `alembic_version` table tracks what's been applied + +## File Structure + +``` +alembic/ +ā”œā”€ā”€ env.py # Alembic configuration (imports models, sets DB URL) +ā”œā”€ā”€ script.py.mako # Template for new migrations +ā”œā”€ā”€ versions/ # Migration files +│ ā”œā”€ā”€ 20260204_d6e30d3b834d_initial_schema.py +│ └── 20260204_a17ac554f4aa_add_parameter_values_indexes.py +alembic.ini # Alembic settings + +supabase/ +ā”œā”€ā”€ migrations/ # Supabase-specific migrations (storage only) +│ ā”œā”€ā”€ 20241119000000_storage_bucket.sql +│ └── 20241121000000_storage_policies.sql +└── migrations_archived/ # Old table migrations (now in Alembic) +``` + +## Troubleshooting + +### "Target database is not up to date" + +Run `alembic upgrade head` to apply pending migrations. + +### "Can't locate revision" + +The alembic_version table has a revision that doesn't exist in your migrations folder. This can happen if someone deleted a migration file. Fix by stamping to a known revision: + +```bash +alembic stamp head # If tables are current +alembic stamp d6e30d3b834d # If at initial schema +``` + +### "Table already exists" + +The migration is trying to create a table that already exists. Options: +1. 
If this is a fresh setup, drop and recreate: `uv run python scripts/init.py --reset` +2. If in production, stamp the migration as applied: `alembic stamp ` diff --git a/CLAUDE.md b/CLAUDE.md index 2df55fc..d6fb240 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -75,7 +75,21 @@ Use `gh` CLI for GitHub operations to ensure Actions run correctly. ## Database -`make init` resets tables and storage. `make seed` populates UK/US models with variables, parameters, and datasets. +This project uses **Alembic** for database migrations. See `.claude/skills/database-migrations.md` for detailed guidelines. + +**Key rules:** +- All schema changes go through Alembic migrations (never use `SQLModel.metadata.create_all()`) +- After modifying a model: `uv run alembic revision --autogenerate -m "Description"` +- Apply migrations: `uv run alembic upgrade head` + +**Local development:** +```bash +supabase start # Start local Supabase +uv run python scripts/init.py # Run migrations + apply RLS policies +uv run python scripts/seed.py # Seed data +``` + +`scripts/init.py --reset` drops and recreates everything (destructive). ## Modal sandbox + Claude Code CLI gotchas diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..ed54635 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,145 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s/alembic + +# template used to generate migration file names +# Prepend with date for easier chronological ordering +file_template = %%(year)d%%(month).2d%%(day).2d_%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . 
+ + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. 
+# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. +path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL - This is overridden by env.py which reads from application settings. +# The placeholder below is only used if env.py doesn't set it. +sqlalchemy.url = postgresql://placeholder:placeholder@localhost/placeholder + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# NOTE: ruff is in dev dependencies, so this hook only works when dev deps are installed +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. 
+[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..f930498 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,87 @@ +"""Alembic environment configuration for SQLModel migrations. + +This module configures Alembic to: +1. Use the database URL from application settings +2. Import all SQLModel models for autogenerate support +3. Run migrations in both offline and online modes +""" + +import sys +from logging.config import fileConfig +from pathlib import Path + +from sqlalchemy import engine_from_config, pool +from sqlmodel import SQLModel + +from alembic import context + +# Add src to path so we can import policyengine_api +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +# Import all models to register them with SQLModel.metadata +# This is required for autogenerate to detect model changes +from policyengine_api import models # noqa: F401 +from policyengine_api.config.settings import settings + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Override sqlalchemy.url with the actual database URL from settings +config.set_main_option("sqlalchemy.url", settings.database_url) + +# Interpret the config file for Python logging. 
+if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# SQLModel metadata for autogenerate support +# This allows Alembic to detect changes in your SQLModel models +target_metadata = SQLModel.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..1101630 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/20260204_a17ac554f4aa_add_parameter_values_indexes.py b/alembic/versions/20260204_a17ac554f4aa_add_parameter_values_indexes.py new file mode 100644 index 0000000..e1967c2 --- /dev/null +++ b/alembic/versions/20260204_a17ac554f4aa_add_parameter_values_indexes.py @@ -0,0 +1,52 @@ +"""Add parameter_values indexes + +Revision ID: a17ac554f4aa +Revises: d6e30d3b834d +Create Date: 2026-02-04 02:20:00.000000 + +This migration adds performance indexes to the parameter_values table +for optimizing common query patterns. +""" + +from typing import Sequence, Union + +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "a17ac554f4aa" +down_revision: Union[str, Sequence[str], None] = "d6e30d3b834d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add performance indexes to parameter_values.""" + # Composite index for the most common query pattern (filtering by both) + op.create_index( + "idx_parameter_values_parameter_policy", + "parameter_values", + ["parameter_id", "policy_id"], + ) + + # Single index on policy_id for filtering by policy alone + op.create_index( + "idx_parameter_values_policy", + "parameter_values", + ["policy_id"], + ) + + # Partial index for baseline values (policy_id IS NULL) + # This optimizes the common "get current law values" query + op.create_index( + "idx_parameter_values_baseline", + "parameter_values", + ["parameter_id"], + postgresql_where="policy_id IS NULL", + ) + + +def downgrade() -> None: + """Remove parameter_values indexes.""" + op.drop_index("idx_parameter_values_baseline", "parameter_values") + op.drop_index("idx_parameter_values_policy", "parameter_values") + op.drop_index("idx_parameter_values_parameter_policy", "parameter_values") diff --git a/alembic/versions/20260204_d6e30d3b834d_initial_schema.py b/alembic/versions/20260204_d6e30d3b834d_initial_schema.py new file mode 100644 index 0000000..d4de071 --- /dev/null +++ b/alembic/versions/20260204_d6e30d3b834d_initial_schema.py @@ -0,0 +1,599 @@ +"""Initial schema + +Revision ID: d6e30d3b834d +Revises: +Create Date: 2026-02-04 02:15:03.471607 + +This migration creates all base tables for the PolicyEngine API. +Tables are organized by dependency tier to ensure proper creation order. +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. 
+revision: str = "d6e30d3b834d" +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create all tables.""" + # ======================================================================== + # TIER 1: Tables with no foreign key dependencies + # ======================================================================== + + # Tax benefit models (e.g., "uk", "us") + op.create_table( + "tax_benefit_models", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + + # Users + op.create_table( + "users", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("first_name", sa.String(), nullable=False), + sa.Column("last_name", sa.String(), nullable=False), + sa.Column("email", sa.String(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("email"), + ) + op.create_index("ix_users_email", "users", ["email"]) + + # Policies (reform definitions) + op.create_table( + "policies", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + + # Dynamics (behavioral response definitions) + op.create_table( + "dynamics", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + 
sa.Column("description", sa.String(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + + # ======================================================================== + # TIER 2: Tables depending on tier 1 + # ======================================================================== + + # Tax benefit model versions + op.create_table( + "tax_benefit_model_versions", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("model_id", sa.Uuid(), nullable=False), + sa.Column("version", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["model_id"], ["tax_benefit_models.id"]), + ) + + # Datasets (h5 files in storage) + op.create_table( + "datasets", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("filepath", sa.String(), nullable=False), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("is_output_dataset", sa.Boolean(), nullable=False, default=False), + sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["tax_benefit_model_id"], ["tax_benefit_models.id"]), + ) + + # ======================================================================== + # TIER 3: Tables depending on tier 2 + # 
======================================================================== + + # Parameters (tax-benefit system parameters) + op.create_table( + "parameters", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("label", sa.String(), nullable=True), + sa.Column("description", sa.String(), nullable=True), + sa.Column("data_type", sa.String(), nullable=True), + sa.Column("unit", sa.String(), nullable=True), + sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["tax_benefit_model_version_id"], ["tax_benefit_model_versions.id"] + ), + ) + + # Variables (tax-benefit system variables) + op.create_table( + "variables", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("data_type", sa.String(), nullable=True), + sa.Column("possible_values", sa.JSON(), nullable=True), + sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["tax_benefit_model_version_id"], ["tax_benefit_model_versions.id"] + ), + ) + + # Dataset versions + op.create_table( + "dataset_versions", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=False), + sa.Column("dataset_id", sa.Uuid(), nullable=False), + sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + 
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"]), + sa.ForeignKeyConstraint(["tax_benefit_model_id"], ["tax_benefit_models.id"]), + ) + + # Households (stored household definitions) + op.create_table( + "households", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("tax_benefit_model_name", sa.String(), nullable=False), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("label", sa.String(), nullable=True), + sa.Column("household_data", sa.JSON(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "idx_households_model_name", "households", ["tax_benefit_model_name"] + ) + op.create_index("idx_households_year", "households", ["year"]) + + # ======================================================================== + # TIER 4: Tables depending on tier 3 + # ======================================================================== + + # Parameter values (policy/dynamic parameter modifications) + op.create_table( + "parameter_values", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("parameter_id", sa.Uuid(), nullable=False), + sa.Column("value_json", sa.JSON(), nullable=True), + sa.Column("start_date", sa.DateTime(timezone=True), nullable=False), + sa.Column("end_date", sa.DateTime(timezone=True), nullable=True), + sa.Column("policy_id", sa.Uuid(), nullable=True), + sa.Column("dynamic_id", sa.Uuid(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["parameter_id"], ["parameters.id"]), + sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), + sa.ForeignKeyConstraint(["dynamic_id"], ["dynamics.id"]), + ) + + # Simulations (economy or household 
calculations) + op.create_table( + "simulations", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("simulation_type", sa.String(), nullable=False, default="economy"), + sa.Column("dataset_id", sa.Uuid(), nullable=True), + sa.Column("household_id", sa.Uuid(), nullable=True), + sa.Column("policy_id", sa.Uuid(), nullable=True), + sa.Column("dynamic_id", sa.Uuid(), nullable=True), + sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), + sa.Column("output_dataset_id", sa.Uuid(), nullable=True), + sa.Column("status", sa.String(), nullable=False, default="pending"), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("household_result", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"]), + sa.ForeignKeyConstraint(["household_id"], ["households.id"]), + sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), + sa.ForeignKeyConstraint(["dynamic_id"], ["dynamics.id"]), + sa.ForeignKeyConstraint( + ["tax_benefit_model_version_id"], ["tax_benefit_model_versions.id"] + ), + sa.ForeignKeyConstraint(["output_dataset_id"], ["datasets.id"]), + ) + + # User-household associations + op.create_table( + "user_household_associations", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.Column("household_id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + 
sa.ForeignKeyConstraint(["household_id"], ["households.id"], ondelete="CASCADE"), + sa.UniqueConstraint("user_id", "household_id"), + ) + op.create_index( + "idx_user_household_user", "user_household_associations", ["user_id"] + ) + op.create_index( + "idx_user_household_household", "user_household_associations", ["household_id"] + ) + + # Household jobs (async household calculations) + op.create_table( + "household_jobs", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("tax_benefit_model_name", sa.String(), nullable=False), + sa.Column("request_data", sa.JSON(), nullable=False), + sa.Column("policy_id", sa.Uuid(), nullable=True), + sa.Column("dynamic_id", sa.Uuid(), nullable=True), + sa.Column("status", sa.String(), nullable=False, default="pending"), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("result", sa.JSON(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), + sa.ForeignKeyConstraint(["dynamic_id"], ["dynamics.id"]), + ) + + # ======================================================================== + # TIER 5: Tables depending on simulations + # ======================================================================== + + # Reports (analysis reports) + op.create_table( + "reports", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("label", sa.String(), nullable=False), + sa.Column("description", sa.String(), nullable=True), + sa.Column("report_type", sa.String(), nullable=True), + sa.Column("user_id", sa.Uuid(), nullable=True), + sa.Column("markdown", sa.Text(), nullable=True), + sa.Column("parent_report_id", sa.Uuid(), nullable=True), + sa.Column("status", sa.String(), nullable=False, default="pending"), + 
sa.Column("error_message", sa.String(), nullable=True), + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=True), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.ForeignKeyConstraint(["parent_report_id"], ["reports.id"]), + sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), + ) + + # Aggregates (single-simulation aggregate outputs) + op.create_table( + "aggregates", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("simulation_id", sa.Uuid(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=True), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("variable", sa.String(), nullable=False), + sa.Column("aggregate_type", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=True), + sa.Column("filter_config", sa.JSON(), nullable=False, default={}), + sa.Column("status", sa.String(), nullable=False, default="pending"), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("result", sa.Float(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), + ) + + # Change aggregates (baseline vs reform comparison) + op.create_table( + "change_aggregates", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=True), + sa.Column("report_id", sa.Uuid(), 
nullable=True), + sa.Column("variable", sa.String(), nullable=False), + sa.Column("aggregate_type", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=True), + sa.Column("filter_config", sa.JSON(), nullable=False, default={}), + sa.Column("change_geq", sa.Float(), nullable=True), + sa.Column("change_leq", sa.Float(), nullable=True), + sa.Column("status", sa.String(), nullable=False, default="pending"), + sa.Column("error_message", sa.String(), nullable=True), + sa.Column("result", sa.Float(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["user_id"], ["users.id"]), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), + ) + + # Decile impacts + op.create_table( + "decile_impacts", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("income_variable", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=True), + sa.Column("decile", sa.Integer(), nullable=False), + sa.Column("quantiles", sa.Integer(), nullable=False, default=10), + sa.Column("baseline_mean", sa.Float(), nullable=True), + sa.Column("reform_mean", sa.Float(), nullable=True), + sa.Column("absolute_change", sa.Float(), nullable=True), + sa.Column("relative_change", sa.Float(), nullable=True), + sa.Column("count_better_off", sa.Float(), nullable=True), + sa.Column("count_worse_off", sa.Float(), nullable=True), + sa.Column("count_no_change", sa.Float(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + 
sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), + ) + + # Program statistics + op.create_table( + "program_statistics", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), + sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("program_name", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=False), + sa.Column("is_tax", sa.Boolean(), nullable=False, default=False), + sa.Column("baseline_total", sa.Float(), nullable=True), + sa.Column("reform_total", sa.Float(), nullable=True), + sa.Column("change", sa.Float(), nullable=True), + sa.Column("baseline_count", sa.Float(), nullable=True), + sa.Column("reform_count", sa.Float(), nullable=True), + sa.Column("winners", sa.Float(), nullable=True), + sa.Column("losers", sa.Float(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), + ) + + # Poverty + op.create_table( + "poverty", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("poverty_type", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=False, default="person"), + sa.Column("filter_variable", sa.String(), nullable=True), + sa.Column("headcount", sa.Float(), nullable=True), + sa.Column("total_population", sa.Float(), nullable=True), + sa.Column("rate", sa.Float(), 
nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["simulation_id"], ["simulations.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"], ondelete="CASCADE"), + ) + op.create_index("idx_poverty_simulation_id", "poverty", ["simulation_id"]) + op.create_index("idx_poverty_report_id", "poverty", ["report_id"]) + + # Inequality + op.create_table( + "inequality", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("simulation_id", sa.Uuid(), nullable=False), + sa.Column("report_id", sa.Uuid(), nullable=True), + sa.Column("income_variable", sa.String(), nullable=False), + sa.Column("entity", sa.String(), nullable=False, default="household"), + sa.Column("gini", sa.Float(), nullable=True), + sa.Column("top_10_share", sa.Float(), nullable=True), + sa.Column("top_1_share", sa.Float(), nullable=True), + sa.Column("bottom_50_share", sa.Float(), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint( + ["simulation_id"], ["simulations.id"], ondelete="CASCADE" + ), + sa.ForeignKeyConstraint(["report_id"], ["reports.id"], ondelete="CASCADE"), + ) + op.create_index("idx_inequality_simulation_id", "inequality", ["simulation_id"]) + op.create_index("idx_inequality_report_id", "inequality", ["report_id"]) + + +def downgrade() -> None: + """Drop all tables in reverse order.""" + # Tier 5 + op.drop_index("idx_inequality_report_id", "inequality") + op.drop_index("idx_inequality_simulation_id", "inequality") + op.drop_table("inequality") + op.drop_index("idx_poverty_report_id", "poverty") + op.drop_index("idx_poverty_simulation_id", "poverty") + op.drop_table("poverty") + op.drop_table("program_statistics") + op.drop_table("decile_impacts") + op.drop_table("change_aggregates") + 
op.drop_table("aggregates") + op.drop_table("reports") + + # Tier 4 + op.drop_table("household_jobs") + op.drop_index("idx_user_household_household", "user_household_associations") + op.drop_index("idx_user_household_user", "user_household_associations") + op.drop_table("user_household_associations") + op.drop_table("simulations") + op.drop_table("parameter_values") + + # Tier 3 + op.drop_index("idx_households_year", "households") + op.drop_index("idx_households_model_name", "households") + op.drop_table("households") + op.drop_table("dataset_versions") + op.drop_table("variables") + op.drop_table("parameters") + + # Tier 2 + op.drop_table("datasets") + op.drop_table("tax_benefit_model_versions") + + # Tier 1 + op.drop_table("dynamics") + op.drop_table("policies") + op.drop_index("ix_users_email", "users") + op.drop_table("users") + op.drop_table("tax_benefit_models") diff --git a/pyproject.toml b/pyproject.toml index 27eb310..1fe9093 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "fastapi-mcp>=0.4.0", "modal>=0.68.0", "anthropic>=0.40.0", + "alembic>=1.13.0", ] [project.optional-dependencies] diff --git a/scripts/init.py b/scripts/init.py index cf7a04a..3aa925b 100644 --- a/scripts/init.py +++ b/scripts/init.py @@ -1,12 +1,19 @@ -"""Initialise Supabase: reset database, recreate tables, buckets, and permissions. +"""Initialise Supabase database with tables, buckets, and permissions. -This script performs a complete reset of the Supabase instance: -1. Drops and recreates the public schema (all tables) -2. Deletes and recreates the storage bucket -3. Creates all tables from SQLModel definitions -4. Applies RLS policies and storage permissions +This script can run in two modes: +1. Init mode (default): Creates tables via Alembic, applies RLS policies +2. 
Reset mode (--reset): Drops everything and recreates from scratch (DESTRUCTIVE) + +Usage: + uv run python scripts/init.py # Safe init (creates if not exists) + uv run python scripts/init.py --reset # Destructive reset (drops everything) + +For local development after `supabase start`, use init mode. +For production, use init mode to ensure tables and policies exist. +Reset mode should only be used when you need a completely fresh database. """ +import subprocess import sys from pathlib import Path @@ -14,16 +21,14 @@ from rich.console import Console from rich.panel import Panel -from sqlmodel import SQLModel, create_engine +from sqlmodel import create_engine -# Import all models to register them with SQLModel.metadata -from policyengine_api import models # noqa: F401 from policyengine_api.config.settings import settings from policyengine_api.services.storage import get_service_role_client console = Console() -MIGRATIONS_DIR = Path(__file__).parent.parent / "supabase" / "migrations" +PROJECT_ROOT = Path(__file__).parent.parent def reset_storage_bucket(): @@ -57,30 +62,61 @@ def reset_storage_bucket(): console.print(f"[yellow]⚠ Warning with storage bucket: {e}[/yellow]") +def ensure_storage_bucket(): + """Ensure storage bucket exists (non-destructive).""" + console.print("[bold blue]Ensuring storage bucket exists...") + + try: + supabase = get_service_role_client() + bucket_name = settings.storage_bucket + + # Try to get bucket info + try: + supabase.storage.get_bucket(bucket_name) + console.print(f"[green]āœ“[/green] Bucket '{bucket_name}' exists") + except Exception: + # Bucket doesn't exist, create it + supabase.storage.create_bucket(bucket_name, options={"public": True}) + console.print(f"[green]āœ“[/green] Created bucket '{bucket_name}'") + + except Exception as e: + console.print(f"[yellow]⚠ Warning with storage bucket: {e}[/yellow]") + + def reset_database(): - """Drop and recreate all tables.""" - console.print("[bold blue]Resetting database...") + """Drop 
and recreate the public schema (DESTRUCTIVE).""" + console.print("[bold red]Dropping database schema...") engine = create_engine(settings.database_url, echo=False) - # Drop and recreate public schema - console.print(" Dropping public schema...") with engine.begin() as conn: conn.exec_driver_sql("DROP SCHEMA public CASCADE") conn.exec_driver_sql("CREATE SCHEMA public") conn.exec_driver_sql("GRANT ALL ON SCHEMA public TO postgres") conn.exec_driver_sql("GRANT ALL ON SCHEMA public TO public") - # Create all tables from SQLModel - console.print(" Creating tables...") - SQLModel.metadata.create_all(engine) + console.print("[green]āœ“[/green] Schema dropped and recreated") + return engine - tables = list(SQLModel.metadata.tables.keys()) - console.print(f"[green]āœ“[/green] Created {len(tables)} tables:") - for table in sorted(tables): - console.print(f" {table}") - return engine +def run_alembic_migrations(): + """Run Alembic migrations to create/update tables.""" + console.print("[bold blue]Running Alembic migrations...") + + result = subprocess.run( + ["uv", "run", "alembic", "upgrade", "head"], + cwd=PROJECT_ROOT, + capture_output=True, + text=True, + ) + + if result.returncode != 0: + console.print(f"[red]āœ— Alembic migration failed:[/red]") + console.print(result.stderr) + raise RuntimeError("Alembic migration failed") + + console.print("[green]āœ“[/green] Alembic migrations complete") + console.print(result.stdout) def apply_storage_policies(engine): @@ -158,6 +194,10 @@ def apply_rls_policies(engine): "parameter_values", "users", "household_jobs", + "households", + "user_household_associations", + "poverty", + "inequality", ] # Read-only tables (public can read, only service role can write) @@ -178,6 +218,7 @@ def apply_rls_policies(engine): "dynamics", "reports", "household_jobs", + "households", ] # Read-only results tables @@ -186,6 +227,8 @@ def apply_rls_policies(engine): "change_aggregates", "decile_impacts", "program_statistics", + "poverty", + 
"inequality", ] sql_parts = [] @@ -230,6 +273,13 @@ def apply_rls_policies(engine): FOR SELECT TO anon, authenticated USING (true); """) + # User-household associations need special handling + sql_parts.append(""" + DROP POLICY IF EXISTS "Users can manage own associations" ON user_household_associations; + CREATE POLICY "Users can manage own associations" ON user_household_associations + FOR ALL TO anon, authenticated USING (true) WITH CHECK (true); + """) + sql = "\n".join(sql_parts) conn = engine.raw_connection() @@ -246,30 +296,53 @@ def apply_rls_policies(engine): def main(): - """Run full Supabase initialisation.""" - console.print( - Panel.fit( - "[bold red]⚠ WARNING: This will DELETE ALL DATA[/bold red]\n" - "This script resets the entire Supabase instance.", - title="Supabase init", + """Run Supabase initialisation.""" + reset_mode = "--reset" in sys.argv + + if reset_mode: + console.print( + Panel.fit( + "[bold red]⚠ WARNING: This will DELETE ALL DATA[/bold red]\n" + "This script will reset the entire Supabase instance.", + title="Supabase RESET", + ) ) - ) - # Confirm unless running non-interactively - if sys.stdin.isatty(): - response = console.input("\nType 'yes' to continue: ") - if response.lower() != "yes": - console.print("[yellow]Aborted[/yellow]") - return + # Confirm unless running non-interactively + if sys.stdin.isatty(): + response = console.input("\nType 'yes' to continue: ") + if response.lower() != "yes": + console.print("[yellow]Aborted[/yellow]") + return + + console.print() + + # Reset storage bucket + reset_storage_bucket() + console.print() + + # Drop database schema + engine = reset_database() + console.print() + else: + console.print( + Panel.fit( + "[bold blue]Initialising Supabase[/bold blue]\n" + "This will create tables if they don't exist (safe/idempotent).\n" + "Use [cyan]--reset[/cyan] flag to drop and recreate everything.", + title="Supabase init", + ) + ) + console.print() - console.print() + # Ensure storage bucket exists + 
ensure_storage_bucket() + console.print() - # Reset storage bucket - reset_storage_bucket() - console.print() + engine = create_engine(settings.database_url, echo=False) - # Reset database and create tables - engine = reset_database() + # Run Alembic migrations to create/update tables + run_alembic_migrations() console.print() # Apply storage policies diff --git a/scripts/seed_nevada.py b/scripts/seed_nevada.py new file mode 100644 index 0000000..0af2cb4 --- /dev/null +++ b/scripts/seed_nevada.py @@ -0,0 +1,128 @@ +"""Seed Nevada datasets into local Supabase. + +This script seeds pre-created Nevada state and congressional district datasets +into the local Supabase database for testing purposes. + +Usage: + uv run python scripts/seed_nevada.py +""" + +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from rich.console import Console +from sqlmodel import Session, create_engine, select + +from policyengine_api.config.settings import settings +from policyengine_api.models import Dataset, TaxBenefitModel +from policyengine_api.services.storage import upload_dataset_for_seeding + +console = Console() + +# Nevada datasets location +NEVADA_DATA_DIR = Path(__file__).parent.parent / "test_data" / "nevada_datasets" + + +def main(): + """Seed Nevada datasets.""" + console.print("[bold blue]Seeding Nevada datasets for testing...") + + engine = create_engine(settings.database_url, echo=False) + + with Session(engine) as session: + # Get or create US model + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not us_model: + console.print(" Creating US tax-benefit model...") + us_model = TaxBenefitModel( + name="policyengine-us", + description="US tax-benefit system model", + ) + session.add(us_model) + session.commit() + session.refresh(us_model) + console.print(" [green]āœ“[/green] Created policyengine-us model") + + # Seed state datasets + states_dir = 
NEVADA_DATA_DIR / "states" + if states_dir.exists(): + console.print("\n [bold]Nevada State Datasets:[/bold]") + for h5_file in sorted(states_dir.glob("*.h5")): + name = h5_file.stem # e.g., "NV_year_2024" + year = int(name.split("_")[-1]) + + # Check if already exists + existing = session.exec( + select(Dataset).where(Dataset.name == name) + ).first() + + if existing: + console.print(f" [yellow]ā­[/yellow] {name} (already exists)") + continue + + # Upload to storage + console.print(f" Uploading {name}...", end=" ") + try: + object_name = upload_dataset_for_seeding(str(h5_file)) + + # Create database record + db_dataset = Dataset( + name=name, + description=f"Nevada state dataset for year {year}", + filepath=object_name, + year=year, + tax_benefit_model_id=us_model.id, + ) + session.add(db_dataset) + session.commit() + console.print("[green]āœ“[/green]") + except Exception as e: + console.print(f"[red]āœ— {e}[/red]") + + # Seed district datasets + districts_dir = NEVADA_DATA_DIR / "districts" + if districts_dir.exists(): + console.print("\n [bold]Nevada Congressional District Datasets:[/bold]") + for h5_file in sorted(districts_dir.glob("*.h5")): + name = h5_file.stem # e.g., "NV-01_year_2024" + year = int(name.split("_")[-1]) + district = name.split("_")[0] # e.g., "NV-01" + + # Check if already exists + existing = session.exec( + select(Dataset).where(Dataset.name == name) + ).first() + + if existing: + console.print(f" [yellow]ā­[/yellow] {name} (already exists)") + continue + + # Upload to storage + console.print(f" Uploading {name}...", end=" ") + try: + object_name = upload_dataset_for_seeding(str(h5_file)) + + # Create database record + db_dataset = Dataset( + name=name, + description=f"{district} congressional district dataset for year {year}", + filepath=object_name, + year=year, + tax_benefit_model_id=us_model.id, + ) + session.add(db_dataset) + session.commit() + console.print("[green]āœ“[/green]") + except Exception as e: + console.print(f"[red]āœ— 
{e}[/red]") + + console.print("\n[bold green]āœ“ Nevada datasets seeded successfully![/bold green]") + + +if __name__ == "__main__": + main() diff --git a/src/policyengine_api/config/settings.py b/src/policyengine_api/config/settings.py index 76a1ab1..efba345 100644 --- a/src/policyengine_api/config/settings.py +++ b/src/policyengine_api/config/settings.py @@ -40,10 +40,21 @@ class Settings(BaseSettings): @property def database_url(self) -> str: - """Get database URL from Supabase.""" + """Get database URL from Supabase. + + For local development, the database runs on port 54322 (not 54321 which is the API). + Use supabase_db_url to override, or rely on the default local URL. + """ + if self.supabase_db_url: + return self.supabase_db_url + + # For local development, default to the standard Supabase local DB port + if "localhost" in self.supabase_url or "127.0.0.1" in self.supabase_url: + return "postgresql://postgres:postgres@127.0.0.1:54322/postgres" + + # For remote Supabase, construct URL from API URL (usually need supabase_db_url set) return ( - self.supabase_db_url - or self.supabase_url.replace( + self.supabase_url.replace( "http://", "postgresql://postgres:postgres@" ).replace("https://", "postgresql://postgres:postgres@") + "/postgres" diff --git a/supabase/.temp/cli-latest b/supabase/.temp/cli-latest index 8c68db7..1dd6178 100644 --- a/supabase/.temp/cli-latest +++ b/supabase/.temp/cli-latest @@ -1 +1 @@ -v2.67.1 \ No newline at end of file +v2.75.0 \ No newline at end of file diff --git a/supabase/migrations/20251229000000_add_parameter_values_indexes.sql b/supabase/migrations_archived/20251229000000_add_parameter_values_indexes.sql similarity index 100% rename from supabase/migrations/20251229000000_add_parameter_values_indexes.sql rename to supabase/migrations_archived/20251229000000_add_parameter_values_indexes.sql diff --git a/supabase/migrations/20260103000000_add_poverty_inequality.sql 
b/supabase/migrations_archived/20260103000000_add_poverty_inequality.sql similarity index 100% rename from supabase/migrations/20260103000000_add_poverty_inequality.sql rename to supabase/migrations_archived/20260103000000_add_poverty_inequality.sql diff --git a/supabase/migrations/20260111000000_add_aggregate_status.sql b/supabase/migrations_archived/20260111000000_add_aggregate_status.sql similarity index 100% rename from supabase/migrations/20260111000000_add_aggregate_status.sql rename to supabase/migrations_archived/20260111000000_add_aggregate_status.sql diff --git a/supabase/migrations/20260203000000_create_households.sql b/supabase/migrations_archived/20260203000000_create_households.sql similarity index 100% rename from supabase/migrations/20260203000000_create_households.sql rename to supabase/migrations_archived/20260203000000_create_households.sql diff --git a/supabase/migrations/20260203000001_create_user_household_associations.sql b/supabase/migrations_archived/20260203000001_create_user_household_associations.sql similarity index 100% rename from supabase/migrations/20260203000001_create_user_household_associations.sql rename to supabase/migrations_archived/20260203000001_create_user_household_associations.sql diff --git a/supabase/migrations/20260203000002_simulation_household_support.sql b/supabase/migrations_archived/20260203000002_simulation_household_support.sql similarity index 100% rename from supabase/migrations/20260203000002_simulation_household_support.sql rename to supabase/migrations_archived/20260203000002_simulation_household_support.sql diff --git a/uv.lock b/uv.lock index 094ebf8..466caf4 100644 --- a/uv.lock +++ b/uv.lock @@ -91,6 +91,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" 
}, ] +[[package]] +name = "alembic" +version = "1.18.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/79/41/ab8f624929847b49f84955c594b165855efd829b0c271e1a8cac694138e5/alembic-1.18.3.tar.gz", hash = "sha256:1212aa3778626f2b0f0aa6dd4e99a5f99b94bd25a0c1ac0bba3be65e081e50b0", size = 2052564, upload-time = "2026-01-29T20:24:15.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/8e/d79281f323e7469b060f15bd229e48d7cdd219559e67e71c013720a88340/alembic-1.18.3-py3-none-any.whl", hash = "sha256:12a0359bfc068a4ecbb9b3b02cf77856033abfdb59e4a5aca08b7eacd7b74ddd", size = 262282, upload-time = "2026-01-29T20:24:17.488Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -1057,6 +1071,18 @@ sqlalchemy = [ { name = "opentelemetry-instrumentation-sqlalchemy" }, ] +[[package]] +name = "mako" +version = "1.3.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, +] + [[package]] name = "markdown-it-py" version = "4.0.0" @@ -1757,6 +1783,7 @@ name = "policyengine-api-v2" version = "0.1.0" source = { editable = "." 
} dependencies = [ + { name = "alembic" }, { name = "anthropic" }, { name = "boto3" }, { name = "fastapi" }, @@ -1793,6 +1820,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "alembic", specifier = ">=1.13.0" }, { name = "anthropic", specifier = ">=0.40.0" }, { name = "boto3", specifier = ">=1.41.1" }, { name = "fastapi", specifier = ">=0.115.0" }, From 0fb609b523fb7cddc5a0893b7ce60e61711571f2 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 4 Feb 2026 03:53:48 +0300 Subject: [PATCH 07/19] refactor: Make household impact async pattern match economic impact MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace synchronous inline calculation with async trigger pattern - Add _trigger_household_impact() mirroring _trigger_economy_comparison() - Add _run_local_household_impact() for local execution (blocking) - Add _run_simulation_in_session() for running individual simulations - Update POST endpoint to trigger and return immediately - Add test script for manual end-to-end testing Note: Local execution blocks the request (same as economic impact). True async requires Modal functions (household_impact_uk/us). šŸ¤– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- scripts/test_household_impact.py | 135 +++++++++++++++ .../api/household_analysis.py | 161 ++++++++++++------ 2 files changed, 247 insertions(+), 49 deletions(-) create mode 100644 scripts/test_household_impact.py diff --git a/scripts/test_household_impact.py b/scripts/test_household_impact.py new file mode 100644 index 0000000..81c85b0 --- /dev/null +++ b/scripts/test_household_impact.py @@ -0,0 +1,135 @@ +"""Test household impact analysis end-to-end. + +This script tests the async household impact analysis workflow: +1. Create a stored household +2. Run household impact analysis (returns immediately with report_id) +3. Poll until completed +4. 
Verify results + +Usage: + uv run python scripts/test_household_impact.py +""" + +import sys +import time + +import requests + +BASE_URL = "http://127.0.0.1:8000" + + +def main(): + print("=" * 60) + print("Testing Household Impact Analysis (Async)") + print("=" * 60) + + # Step 1: Create a US household + print("\n1. Creating US household...") + household_data = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "label": "Test household for impact analysis", + "people": [ + { + "age": 35, + "employment_income": 50000, + } + ], + "tax_unit": {}, + "family": {}, + "spm_unit": {}, + "marital_unit": {}, + "household": {"state_code": "NV"}, + } + + resp = requests.post(f"{BASE_URL}/households/", json=household_data) + if resp.status_code != 201: + print(f" FAILED: {resp.status_code} - {resp.text}") + sys.exit(1) + + household = resp.json() + household_id = household["id"] + print(f" Created household: {household_id}") + + # Step 2: Run household impact analysis + print("\n2. Starting household impact analysis...") + impact_request = { + "household_id": household_id, + "policy_id": None, # Single run under current law + } + + resp = requests.post(f"{BASE_URL}/analysis/household-impact", json=impact_request) + if resp.status_code != 200: + print(f" FAILED: {resp.status_code} - {resp.text}") + sys.exit(1) + + result = resp.json() + report_id = result["report_id"] + status = result["status"] + print(f" Report ID: {report_id}") + print(f" Initial status: {status}") + + # Step 3: Poll until completed + print("\n3. 
Polling for results...") + max_attempts = 30 + for attempt in range(max_attempts): + resp = requests.get(f"{BASE_URL}/analysis/household-impact/{report_id}") + if resp.status_code != 200: + print(f" FAILED: {resp.status_code} - {resp.text}") + sys.exit(1) + + result = resp.json() + status = result["status"].upper() # Normalize to uppercase + print(f" Attempt {attempt + 1}: status={status}") + + if status == "COMPLETED": + break + elif status == "FAILED": + print(f" FAILED: {result.get('error_message', 'Unknown error')}") + sys.exit(1) + + time.sleep(0.5) + else: + print(f" FAILED: Timed out after {max_attempts} attempts") + sys.exit(1) + + # Step 4: Verify results + print("\n4. Verifying results...") + baseline_result = result.get("baseline_result") + if not baseline_result: + print(" FAILED: No baseline result") + sys.exit(1) + + print(f" Baseline result keys: {list(baseline_result.keys())}") + + # Check for expected entity types + expected_entities = ["person", "tax_unit", "spm_unit", "family", "marital_unit", "household"] + for entity in expected_entities: + if entity in baseline_result: + print(f" āœ“ {entity}: {len(baseline_result[entity])} entities") + else: + print(f" āœ— {entity}: missing") + + # Look for net_income in person output + if "person" in baseline_result and baseline_result["person"]: + person = baseline_result["person"][0] + if "household_net_income" in person: + print(f" household_net_income: ${person['household_net_income']:,.2f}") + elif "spm_unit_net_income" in person: + print(f" spm_unit_net_income: ${person['spm_unit_net_income']:,.2f}") + + # Step 5: Cleanup - delete household + print("\n5. 
Cleaning up...") + resp = requests.delete(f"{BASE_URL}/households/{household_id}") + if resp.status_code == 204: + print(f" Deleted household: {household_id}") + else: + print(f" Warning: Failed to delete household: {resp.status_code}") + + print("\n" + "=" * 60) + print("SUCCESS: Household impact analysis working correctly!") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py index 29ea89e..981d968 100644 --- a/src/policyengine_api/api/household_analysis.py +++ b/src/policyengine_api/api/household_analysis.py @@ -6,11 +6,11 @@ WORKFLOW: 1. Create a stored household: POST /households 2. Optionally create a reform policy: POST /policies -3. Run analysis: POST /analysis/household-impact -4. Results are synchronous - the response includes computed values +3. Run analysis: POST /analysis/household-impact (returns report_id) +4. Poll GET /analysis/household-impact/{report_id} until status="completed" +5. 
Results include baseline_result, reform_result (if comparison), and impact diff """ -from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime, timezone from typing import Any, Protocol @@ -18,6 +18,7 @@ import logfire from fastapi import APIRouter, Depends, HTTPException +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from pydantic import BaseModel, Field from sqlmodel import Session @@ -38,6 +39,14 @@ _get_or_create_simulation, ) + +def get_traceparent() -> str | None: + """Get the current W3C traceparent header for distributed tracing.""" + carrier: dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier) + return carrier.get("traceparent") + + router = APIRouter(prefix="/analysis", tags=["analysis"]) @@ -61,7 +70,14 @@ class CountryConfig: US_CONFIG = CountryConfig( name="us", - entity_types=("person", "tax_unit", "spm_unit", "family", "marital_unit", "household"), + entity_types=( + "person", + "tax_unit", + "spm_unit", + "family", + "marital_unit", + "household", + ), ) @@ -326,68 +342,109 @@ def _load_policy_data(policy_id: UUID | None, session: Session) -> dict | None: # ============================================================================= -# Report Orchestration +# Report Orchestration (Async) # ============================================================================= -def trigger_household_report(report_id: UUID, session: Session) -> None: - """Trigger household simulation(s) for a report.""" - report = _load_report(report_id, session) - _mark_report_running(report, session) - - try: - _run_report_simulations(report, session) - _mark_report_completed(report, session) - except Exception as e: - _mark_report_failed(report, e, session) - +def _run_local_household_impact(report_id: str, session: Session) -> None: + """Run household impact analysis locally. 
-def _load_report(report_id: UUID, session: Session) -> Report: - """Load report or raise error.""" + NOTE: This runs synchronously and blocks the HTTP request when running + locally (agent_use_modal=False). This mirrors the economic impact behavior. + True async execution requires Modal. + """ report = session.get(Report, report_id) if not report: - raise ValueError(f"Report {report_id} not found") - return report - + return -def _mark_report_running(report: Report, session: Session) -> None: - """Mark report as running.""" report.status = ReportStatus.RUNNING session.add(report) session.commit() + try: + # Run baseline simulation + if report.baseline_simulation_id: + _run_simulation_in_session(report.baseline_simulation_id, session) + + # Run reform simulation if present + if report.reform_simulation_id: + _run_simulation_in_session(report.reform_simulation_id, session) -def _mark_report_completed(report: Report, session: Session) -> None: - """Mark report as completed.""" - report.status = ReportStatus.COMPLETED - session.add(report) - session.commit() + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + except Exception as e: + report.status = ReportStatus.FAILED + report.error_message = str(e) + session.add(report) + session.commit() -def _mark_report_failed(report: Report, error: Exception, session: Session) -> None: - """Mark report as failed.""" - report.status = ReportStatus.FAILED - report.error_message = str(error) - session.add(report) +def _run_simulation_in_session(simulation_id: UUID, session: Session) -> None: + """Run a single household simulation within an existing session.""" + simulation = session.get(Simulation, simulation_id) + if not simulation or simulation.status != SimulationStatus.PENDING: + return + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError(f"Household {simulation.household_id} not found") + + policy_data = _load_policy_data(simulation.policy_id, 
session) + + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) session.commit() + try: + calculator = get_calculator(household.tax_benefit_model_name) + result = calculator(household.household_data, household.year, policy_data) -def _run_report_simulations(report: Report, session: Session) -> None: - """Run all pending simulations for a report.""" - _run_simulation_if_pending(report.baseline_simulation_id, session) + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise - if report.reform_simulation_id: - _run_simulation_if_pending(report.reform_simulation_id, session) +def _trigger_household_impact( + report_id: str, tax_benefit_model_name: str, session: Session | None = None +) -> None: + """Trigger household impact calculation (local or Modal based on settings).""" + from policyengine_api.config import settings + + traceparent = get_traceparent() -def _run_simulation_if_pending(simulation_id: UUID | None, session: Session) -> None: - """Run simulation if it exists and is pending.""" - if not simulation_id: - return + if not settings.agent_use_modal and session is not None: + # Run locally (blocking - see _run_local_household_impact docstring) + _run_local_household_impact(report_id, session) + else: + # Use Modal + import modal - simulation = session.get(Simulation, simulation_id) - if simulation and simulation.status == SimulationStatus.PENDING: - run_household_simulation(simulation.id, session) + if tax_benefit_model_name == "policyengine_uk": + fn = modal.Function.from_name("policyengine", "household_impact_uk") + else: + fn = 
modal.Function.from_name("policyengine", "household_impact_us") + + fn.spawn(report_id=report_id, traceparent=traceparent) + + +# Legacy functions kept for compatibility +def _load_report(report_id: UUID, session: Session) -> Report: + """Load report or raise error.""" + report = session.get(Report, report_id) + if not report: + raise ValueError(f"Report {report_id} not found") + return report # ============================================================================= @@ -436,7 +493,9 @@ class HouseholdImpactResponse(BaseModel): # ============================================================================= -def build_simulation_info(simulation: Simulation | None) -> HouseholdSimulationInfo | None: +def build_simulation_info( + simulation: Simulation | None, +) -> HouseholdSimulationInfo | None: """Build simulation info from a simulation.""" if not simulation: return None @@ -540,7 +599,9 @@ def household_impact( If policy_id is None: single run under current law. If policy_id is set: comparison (baseline vs reform). - This is a synchronous operation for household calculations. + This is an async operation. The endpoint returns immediately with a report_id + and status="pending". Poll GET /analysis/household-impact/{report_id} until + status="completed" to get results. 
""" household = validate_household_exists(request.household_id, session) validate_policy_exists(request.policy_id, session) @@ -564,8 +625,10 @@ def household_impact( ) if report.status == ReportStatus.PENDING: - with logfire.span("trigger_household_report", job_id=str(report.id)): - trigger_household_report(report.id, session) + with logfire.span("trigger_household_impact", job_id=str(report.id)): + _trigger_household_impact( + str(report.id), household.tax_benefit_model_name, session + ) return build_household_response(report, baseline_sim, reform_sim, session) From dbb65345be565170049d59e1b1941dd0a8834db8 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 4 Feb 2026 20:06:01 +0300 Subject: [PATCH 08/19] fix: Break up Alembic; add smaller seed scripts --- ...ema.py => 20260204_0001_initial_schema.py} | 84 +-- .../20260204_0002_add_household_support.py | 170 +++++ ...0204_0003_add_parameter_values_indexes.py} | 8 +- scripts/seed.py | 107 +-- scripts/seed_common.py | 362 +++++++++ scripts/seed_policies.py | 143 ++++ scripts/seed_uk_datasets.py | 113 +++ scripts/seed_uk_model.py | 33 + scripts/seed_us_datasets.py | 108 +++ scripts/seed_us_model.py | 33 + src/policyengine_api/modal_app.py | 689 +++++++++++++++++- 11 files changed, 1723 insertions(+), 127 deletions(-) rename alembic/versions/{20260204_d6e30d3b834d_initial_schema.py => 20260204_0001_initial_schema.py} (87%) create mode 100644 alembic/versions/20260204_0002_add_household_support.py rename alembic/versions/{20260204_a17ac554f4aa_add_parameter_values_indexes.py => 20260204_0003_add_parameter_values_indexes.py} (90%) create mode 100644 scripts/seed_common.py create mode 100644 scripts/seed_policies.py create mode 100644 scripts/seed_uk_datasets.py create mode 100644 scripts/seed_uk_model.py create mode 100644 scripts/seed_us_datasets.py create mode 100644 scripts/seed_us_model.py diff --git a/alembic/versions/20260204_d6e30d3b834d_initial_schema.py b/alembic/versions/20260204_0001_initial_schema.py 
similarity index 87% rename from alembic/versions/20260204_d6e30d3b834d_initial_schema.py rename to alembic/versions/20260204_0001_initial_schema.py index d4de071..273124a 100644 --- a/alembic/versions/20260204_d6e30d3b834d_initial_schema.py +++ b/alembic/versions/20260204_0001_initial_schema.py @@ -1,11 +1,11 @@ -"""Initial schema +"""Initial schema (main branch state) -Revision ID: d6e30d3b834d +Revision ID: 0001_initial Revises: -Create Date: 2026-02-04 02:15:03.471607 +Create Date: 2026-02-04 -This migration creates all base tables for the PolicyEngine API. -Tables are organized by dependency tier to ensure proper creation order. +This migration creates all base tables for the PolicyEngine API as they +exist on the main branch, BEFORE the household CRUD changes. """ from typing import Sequence, Union @@ -14,14 +14,14 @@ from alembic import op # revision identifiers, used by Alembic. -revision: str = "d6e30d3b834d" +revision: str = "0001_initial" down_revision: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: - """Create all tables.""" + """Create all tables as they exist on main branch.""" # ======================================================================== # TIER 1: Tables with no foreign key dependencies # ======================================================================== @@ -215,33 +215,6 @@ def upgrade() -> None: sa.ForeignKeyConstraint(["tax_benefit_model_id"], ["tax_benefit_models.id"]), ) - # Households (stored household definitions) - op.create_table( - "households", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("tax_benefit_model_name", sa.String(), nullable=False), - sa.Column("year", sa.Integer(), nullable=False), - sa.Column("label", sa.String(), nullable=True), - sa.Column("household_data", sa.JSON(), nullable=False), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - 
nullable=False, - ), - sa.Column( - "updated_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - "idx_households_model_name", "households", ["tax_benefit_model_name"] - ) - op.create_index("idx_households_year", "households", ["year"]) - # ======================================================================== # TIER 4: Tables depending on tier 3 # ======================================================================== @@ -268,13 +241,11 @@ def upgrade() -> None: sa.ForeignKeyConstraint(["dynamic_id"], ["dynamics.id"]), ) - # Simulations (economy or household calculations) + # Simulations (economy calculations) - NOTE: No household support yet op.create_table( "simulations", sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("simulation_type", sa.String(), nullable=False, default="economy"), - sa.Column("dataset_id", sa.Uuid(), nullable=True), - sa.Column("household_id", sa.Uuid(), nullable=True), + sa.Column("dataset_id", sa.Uuid(), nullable=False), # Required in main sa.Column("policy_id", sa.Uuid(), nullable=True), sa.Column("dynamic_id", sa.Uuid(), nullable=True), sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), @@ -295,10 +266,8 @@ def upgrade() -> None: ), sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("household_result", sa.JSON(), nullable=True), sa.PrimaryKeyConstraint("id"), sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"]), - sa.ForeignKeyConstraint(["household_id"], ["households.id"]), sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), sa.ForeignKeyConstraint(["dynamic_id"], ["dynamics.id"]), sa.ForeignKeyConstraint( @@ -307,31 +276,7 @@ def upgrade() -> None: sa.ForeignKeyConstraint(["output_dataset_id"], ["datasets.id"]), ) - # User-household associations - op.create_table( - "user_household_associations", - 
sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("user_id", sa.Uuid(), nullable=False), - sa.Column("household_id", sa.Uuid(), nullable=False), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), - sa.ForeignKeyConstraint(["household_id"], ["households.id"], ondelete="CASCADE"), - sa.UniqueConstraint("user_id", "household_id"), - ) - op.create_index( - "idx_user_household_user", "user_household_associations", ["user_id"] - ) - op.create_index( - "idx_user_household_household", "user_household_associations", ["household_id"] - ) - - # Household jobs (async household calculations) + # Household jobs (async household calculations) - legacy approach op.create_table( "household_jobs", sa.Column("id", sa.Uuid(), nullable=False), @@ -359,13 +304,12 @@ def upgrade() -> None: # TIER 5: Tables depending on simulations # ======================================================================== - # Reports (analysis reports) + # Reports (analysis reports) - NOTE: No report_type yet op.create_table( "reports", sa.Column("id", sa.Uuid(), nullable=False), sa.Column("label", sa.String(), nullable=False), sa.Column("description", sa.String(), nullable=True), - sa.Column("report_type", sa.String(), nullable=True), sa.Column("user_id", sa.Uuid(), nullable=True), sa.Column("markdown", sa.Text(), nullable=True), sa.Column("parent_report_id", sa.Uuid(), nullable=True), @@ -573,16 +517,10 @@ def downgrade() -> None: # Tier 4 op.drop_table("household_jobs") - op.drop_index("idx_user_household_household", "user_household_associations") - op.drop_index("idx_user_household_user", "user_household_associations") - op.drop_table("user_household_associations") op.drop_table("simulations") op.drop_table("parameter_values") # Tier 3 - op.drop_index("idx_households_year", "households") - 
op.drop_index("idx_households_model_name", "households") - op.drop_table("households") op.drop_table("dataset_versions") op.drop_table("variables") op.drop_table("parameters") diff --git a/alembic/versions/20260204_0002_add_household_support.py b/alembic/versions/20260204_0002_add_household_support.py new file mode 100644 index 0000000..beb00a0 --- /dev/null +++ b/alembic/versions/20260204_0002_add_household_support.py @@ -0,0 +1,170 @@ +"""Add household CRUD and impact analysis support + +Revision ID: 0002_household +Revises: 0001_initial +Create Date: 2026-02-04 + +This migration adds support for: +- Storing household definitions (households table) +- User-household associations for saved households +- Household-based simulations (adds household_id to simulations) +- Household impact reports (adds report_type to reports) +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "0002_household" +down_revision: Union[str, Sequence[str], None] = "0001_initial" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add household support.""" + # ======================================================================== + # NEW TABLES + # ======================================================================== + + # Households (stored household definitions) + op.create_table( + "households", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("tax_benefit_model_name", sa.String(), nullable=False), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("label", sa.String(), nullable=True), + sa.Column("household_data", sa.JSON(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + 
sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + "idx_households_model_name", "households", ["tax_benefit_model_name"] + ) + op.create_index("idx_households_year", "households", ["year"]) + + # User-household associations (many-to-many for saved households) + op.create_table( + "user_household_associations", + sa.Column("id", sa.Uuid(), nullable=False), + sa.Column("user_id", sa.Uuid(), nullable=False), + sa.Column("household_id", sa.Uuid(), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["household_id"], ["households.id"], ondelete="CASCADE"), + sa.UniqueConstraint("user_id", "household_id"), + ) + op.create_index( + "idx_user_household_user", "user_household_associations", ["user_id"] + ) + op.create_index( + "idx_user_household_household", "user_household_associations", ["household_id"] + ) + + # ======================================================================== + # MODIFY SIMULATIONS TABLE + # ======================================================================== + + # Add simulation_type column (economy vs household) + op.add_column( + "simulations", + sa.Column( + "simulation_type", + sa.String(), + nullable=False, + server_default="economy", + ), + ) + + # Add household_id column (for household simulations) + op.add_column( + "simulations", + sa.Column("household_id", sa.Uuid(), nullable=True), + ) + op.create_foreign_key( + "fk_simulations_household_id", + "simulations", + "households", + ["household_id"], + ["id"], + ) + + # Add household_result column (stores household calculation results) + op.add_column( + "simulations", + sa.Column("household_result", sa.JSON(), nullable=True), + ) + + # Make dataset_id nullable (household simulations don't need a dataset) + op.alter_column( + "simulations", + "dataset_id", + 
existing_type=sa.Uuid(), + nullable=True, + ) + + # ======================================================================== + # MODIFY REPORTS TABLE + # ======================================================================== + + # Add report_type column (economy_comparison, household_impact, etc.) + op.add_column( + "reports", + sa.Column("report_type", sa.String(), nullable=True), + ) + + +def downgrade() -> None: + """Remove household support.""" + # ======================================================================== + # REVERT REPORTS TABLE + # ======================================================================== + op.drop_column("reports", "report_type") + + # ======================================================================== + # REVERT SIMULATIONS TABLE + # ======================================================================== + + # Make dataset_id required again + op.alter_column( + "simulations", + "dataset_id", + existing_type=sa.Uuid(), + nullable=False, + ) + + # Remove household columns + op.drop_column("simulations", "household_result") + op.drop_constraint("fk_simulations_household_id", "simulations", type_="foreignkey") + op.drop_column("simulations", "household_id") + op.drop_column("simulations", "simulation_type") + + # ======================================================================== + # DROP NEW TABLES + # ======================================================================== + op.drop_index("idx_user_household_household", "user_household_associations") + op.drop_index("idx_user_household_user", "user_household_associations") + op.drop_table("user_household_associations") + + op.drop_index("idx_households_year", "households") + op.drop_index("idx_households_model_name", "households") + op.drop_table("households") diff --git a/alembic/versions/20260204_a17ac554f4aa_add_parameter_values_indexes.py b/alembic/versions/20260204_0003_add_parameter_values_indexes.py similarity index 90% rename from 
alembic/versions/20260204_a17ac554f4aa_add_parameter_values_indexes.py rename to alembic/versions/20260204_0003_add_parameter_values_indexes.py index e1967c2..53518cf 100644 --- a/alembic/versions/20260204_a17ac554f4aa_add_parameter_values_indexes.py +++ b/alembic/versions/20260204_0003_add_parameter_values_indexes.py @@ -1,7 +1,7 @@ """Add parameter_values indexes -Revision ID: a17ac554f4aa -Revises: d6e30d3b834d +Revision ID: 0003_param_idx +Revises: 0002_household Create Date: 2026-02-04 02:20:00.000000 This migration adds performance indexes to the parameter_values table @@ -13,8 +13,8 @@ from alembic import op # revision identifiers, used by Alembic. -revision: str = "a17ac554f4aa" -down_revision: Union[str, Sequence[str], None] = "d6e30d3b834d" +revision: str = "0003_param_idx" +down_revision: Union[str, Sequence[str], None] = "0002_household" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None diff --git a/scripts/seed.py b/scripts/seed.py index f3fbfa8..4274528 100644 --- a/scripts/seed.py +++ b/scripts/seed.py @@ -363,7 +363,7 @@ def seed_model(model_version, session, lite: bool = False) -> TaxBenefitModelVer return db_version -def seed_datasets(session, lite: bool = False): +def seed_datasets(session, lite: bool = False, skip_uk_datasets: bool = False): """Seed datasets and upload to S3.""" with logfire.span("seed_datasets"): mode_str = " (lite mode - 2026 only)" if lite else "" @@ -383,60 +383,64 @@ def seed_datasets(session, lite: bool = False): ) return - # UK datasets - console.print(" Creating UK datasets...") data_folder = str(Path(__file__).parent.parent / "data") - uk_datasets = ensure_uk_datasets(data_folder=data_folder) - - # In lite mode, only upload FRS 2026 - if lite: - uk_datasets = { - k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k - } - console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") + # UK datasets uk_created = 0 uk_skipped = 0 - with 
logfire.span("seed_uk_datasets", count=len(uk_datasets)): - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task("UK datasets", total=len(uk_datasets)) - for _, pe_dataset in uk_datasets.items(): - progress.update(task, description=f"UK: {pe_dataset.name}") - - # Check if dataset already exists - existing = session.exec( - select(Dataset).where(Dataset.name == pe_dataset.name) - ).first() - - if existing: - uk_skipped += 1 + if skip_uk_datasets: + console.print(" [yellow]Skipping UK datasets (--skip-uk-datasets)[/yellow]") + else: + console.print(" Creating UK datasets...") + uk_datasets = ensure_uk_datasets(data_folder=data_folder) + + # In lite mode, only upload FRS 2026 + if lite: + uk_datasets = { + k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k + } + console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") + + with logfire.span("seed_uk_datasets", count=len(uk_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("UK datasets", total=len(uk_datasets)) + for _, pe_dataset in uk_datasets.items(): + progress.update(task, description=f"UK: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + uk_skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=uk_model.id, + ) + session.add(db_dataset) + session.commit() + uk_created += 1 progress.advance(task) - continue - - # Upload to S3 - object_name = 
upload_dataset_for_seeding(pe_dataset.filepath) - - # Create database record - db_dataset = Dataset( - name=pe_dataset.name, - description=pe_dataset.description, - filepath=object_name, - year=pe_dataset.year, - tax_benefit_model_id=uk_model.id, - ) - session.add(db_dataset) - session.commit() - uk_created += 1 - progress.advance(task) - console.print( - f" [green]āœ“[/green] UK: {uk_created} created, {uk_skipped} skipped" - ) + console.print( + f" [green]āœ“[/green] UK: {uk_created} created, {uk_skipped} skipped" + ) # US datasets console.print(" Creating US datasets...") @@ -622,6 +626,11 @@ def main(): action="store_true", help="Lite mode: skip US state parameters, only seed FRS 2026 and CPS 2026 datasets", ) + parser.add_argument( + "--skip-uk-datasets", + action="store_true", + help="Skip UK datasets (useful when HuggingFace token is not available)", + ) args = parser.parse_args() with logfire.span("database_seeding"): @@ -638,7 +647,7 @@ def main(): console.print(f"[green]āœ“[/green] US model seeded: {us_version.id}\n") # Seed datasets - seed_datasets(session, lite=args.lite) + seed_datasets(session, lite=args.lite, skip_uk_datasets=args.skip_uk_datasets) # Seed example policies seed_example_policies(session) diff --git a/scripts/seed_common.py b/scripts/seed_common.py new file mode 100644 index 0000000..4e4e2ec --- /dev/null +++ b/scripts/seed_common.py @@ -0,0 +1,362 @@ +"""Shared utilities for seed scripts.""" + +import io +import json +import logging +import math +import sys +import warnings +from datetime import datetime, timezone +from pathlib import Path +from uuid import uuid4 + +import logfire +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn +from sqlmodel import Session, create_engine + +# Disable all SQLAlchemy and database logging BEFORE any imports +logging.basicConfig(level=logging.ERROR) +logging.getLogger("sqlalchemy").setLevel(logging.ERROR) +warnings.filterwarnings("ignore") + +# Add src to 
path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from policyengine_api.config.settings import settings # noqa: E402 + +# Configure logfire +if settings.logfire_token: + logfire.configure( + token=settings.logfire_token, + environment=settings.logfire_environment, + console=False, + ) + +console = Console() + + +def get_session(): + """Get database session with logging disabled.""" + engine = create_engine(settings.database_url, echo=False) + return Session(engine) + + +def bulk_insert(session, table: str, columns: list[str], rows: list[dict]): + """Fast bulk insert using PostgreSQL COPY via StringIO.""" + if not rows: + return + + # Get raw psycopg2 connection + connection = session.connection() + raw_conn = connection.connection.dbapi_connection + cursor = raw_conn.cursor() + + # Build CSV-like data in memory + output = io.StringIO() + for row in rows: + values = [] + for col in columns: + val = row[col] + if val is None: + values.append("\\N") + elif isinstance(val, str): + # Escape special characters for COPY + val = ( + val.replace("\\", "\\\\").replace("\t", "\\t").replace("\n", "\\n") + ) + values.append(val) + else: + values.append(str(val)) + output.write("\t".join(values) + "\n") + + output.seek(0) + + # COPY is the fastest way to bulk load PostgreSQL + cursor.copy_from(output, table, columns=columns, null="\\N") + session.commit() + + +def seed_model(model_version, session, lite: bool = False): + """Seed a tax-benefit model with its variables and parameters. + + Returns the TaxBenefitModelVersion that was created or found. 
+ """ + from policyengine_api.models import ( + TaxBenefitModel, + TaxBenefitModelVersion, + ) + from sqlmodel import select + + with logfire.span( + "seed_model", + model=model_version.model.id, + version=model_version.version, + ): + # Create or get the model + console.print(f"[bold blue]Seeding {model_version.model.id}...") + + existing_model = session.exec( + select(TaxBenefitModel).where( + TaxBenefitModel.name == model_version.model.id + ) + ).first() + + if existing_model: + db_model = existing_model + console.print(f" Using existing model: {db_model.id}") + else: + db_model = TaxBenefitModel( + name=model_version.model.id, + description=model_version.model.description, + ) + session.add(db_model) + session.commit() + session.refresh(db_model) + console.print(f" Created model: {db_model.id}") + + # Create model version + existing_version = session.exec( + select(TaxBenefitModelVersion).where( + TaxBenefitModelVersion.model_id == db_model.id, + TaxBenefitModelVersion.version == model_version.version, + ) + ).first() + + if existing_version: + console.print( + f" Model version {model_version.version} already exists, skipping" + ) + return existing_version + + db_version = TaxBenefitModelVersion( + model_id=db_model.id, + version=model_version.version, + description=f"Version {model_version.version}", + ) + session.add(db_version) + session.commit() + session.refresh(db_version) + console.print(f" Created version: {db_version.version}") + + # Add variables + with logfire.span("add_variables", count=len(model_version.variables)): + var_rows = [] + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(model_version.variables)} variables", + total=len(model_version.variables), + ) + for var in model_version.variables: + var_rows.append( + { + "id": uuid4(), + "name": var.name, + "entity": var.entity, + "description": var.description or "", + 
"data_type": var.data_type.__name__ + if hasattr(var.data_type, "__name__") + else str(var.data_type), + "possible_values": None, + "tax_benefit_model_version_id": db_version.id, + "created_at": datetime.now(timezone.utc), + } + ) + progress.advance(task) + + console.print(f" Inserting {len(var_rows)} variables...") + bulk_insert( + session, + "variables", + [ + "id", + "name", + "entity", + "description", + "data_type", + "possible_values", + "tax_benefit_model_version_id", + "created_at", + ], + var_rows, + ) + + console.print( + f" [green]āœ“[/green] Added {len(model_version.variables)} variables" + ) + + # Add parameters (only user-facing ones: those with labels) + # Deduplicate by name - keep first occurrence + # + # WHY DEDUPLICATION IS NEEDED: + # The policyengine package can provide multiple parameter entries with the same + # name. This happens because parameters can have multiple bracket entries or + # state-specific variants that share the same base name. We keep only the first + # occurrence to avoid database unique constraint violations and reduce redundancy. 
+ # + # In lite mode, exclude US state parameters (gov.states.*) + seen_names = set() + parameters_to_add = [] + skipped_state_params = 0 + skipped_no_label = 0 + skipped_duplicate = 0 + + for p in model_version.parameters: + if p.label is None: + skipped_no_label += 1 + continue + if p.name in seen_names: + skipped_duplicate += 1 + continue + # In lite mode, skip state-level parameters for faster seeding + if lite and p.name.startswith("gov.states."): + skipped_state_params += 1 + continue + parameters_to_add.append(p) + seen_names.add(p.name) + + console.print(f" Parameter filtering:") + console.print(f" - Total from source: {len(model_version.parameters)}") + console.print(f" - Skipped (no label): {skipped_no_label}") + console.print(f" - Skipped (duplicate name): {skipped_duplicate}") + if lite and skipped_state_params > 0: + console.print(f" - Skipped (state params, lite mode): {skipped_state_params}") + console.print(f" - To add: {len(parameters_to_add)}") + + with logfire.span("add_parameters", count=len(parameters_to_add)): + # Build list of parameter dicts for bulk insert + param_rows = [] + param_names = [] # Track (pe_id, name, generated_uuid) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(parameters_to_add)} parameters", + total=len(parameters_to_add), + ) + for param in parameters_to_add: + param_uuid = uuid4() + param_rows.append( + { + "id": param_uuid, + "name": param.name, + "label": param.label if hasattr(param, "label") else None, + "description": param.description or "", + "data_type": param.data_type.__name__ + if hasattr(param.data_type, "__name__") + else str(param.data_type), + "unit": param.unit, + "tax_benefit_model_version_id": db_version.id, + "created_at": datetime.now(timezone.utc), + } + ) + param_names.append((param.id, param.name, param_uuid)) + progress.advance(task) + + console.print(f" Inserting 
{len(param_rows)} parameters...") + bulk_insert( + session, + "parameters", + [ + "id", + "name", + "label", + "description", + "data_type", + "unit", + "tax_benefit_model_version_id", + "created_at", + ], + param_rows, + ) + + # Build param_id_map from pre-generated UUIDs + param_id_map = {pe_id: db_uuid for pe_id, name, db_uuid in param_names} + + console.print( + f" [green]āœ“[/green] Added {len(parameters_to_add)} parameters" + ) + + # Add parameter values + # Filter to only include values for parameters we added + parameter_values_to_add = [ + pv + for pv in model_version.parameter_values + if pv.parameter.id in param_id_map + ] + console.print(f" Found {len(parameter_values_to_add)} parameter values to add") + + with logfire.span("add_parameter_values", count=len(parameter_values_to_add)): + pv_rows = [] + skipped = 0 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(parameter_values_to_add)} parameter values", + total=len(parameter_values_to_add), + ) + for pv in parameter_values_to_add: + # Handle Infinity values - skip them as they can't be stored in JSON + if isinstance(pv.value, float) and ( + math.isinf(pv.value) or math.isnan(pv.value) + ): + skipped += 1 + progress.advance(task) + continue + + # Source data has dates swapped (start > end), fix ordering + # Only swap if both dates are set, otherwise keep original + if pv.start_date and pv.end_date: + start = pv.end_date # Swap: source end is our start + end = pv.start_date # Swap: source start is our end + else: + start = pv.start_date + end = pv.end_date + pv_rows.append( + { + "id": uuid4(), + "parameter_id": param_id_map[pv.parameter.id], + "value_json": json.dumps(pv.value), + "start_date": start, + "end_date": end, + "policy_id": None, + "dynamic_id": None, + "created_at": datetime.now(timezone.utc), + } + ) + progress.advance(task) + + console.print(f" Inserting 
{len(pv_rows)} parameter values...") + bulk_insert( + session, + "parameter_values", + [ + "id", + "parameter_id", + "value_json", + "start_date", + "end_date", + "policy_id", + "dynamic_id", + "created_at", + ], + pv_rows, + ) + + console.print( + f" [green]āœ“[/green] Added {len(pv_rows)} parameter values" + + (f" (skipped {skipped} invalid)" if skipped else "") + ) + + return db_version diff --git a/scripts/seed_policies.py b/scripts/seed_policies.py new file mode 100644 index 0000000..e57b964 --- /dev/null +++ b/scripts/seed_policies.py @@ -0,0 +1,143 @@ +"""Seed example policy reforms for UK and US.""" + +import time +from datetime import datetime, timezone + +import logfire +from sqlmodel import select + +from seed_common import console, get_session + + +def main(): + from policyengine_api.models import ( + Parameter, + ParameterValue, + Policy, + TaxBenefitModel, + TaxBenefitModelVersion, + ) + + console.print("[bold green]Seeding example policies...[/bold green]\n") + + start = time.time() + with get_session() as session: + with logfire.span("seed_example_policies"): + # Get model versions + uk_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") + ).first() + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not uk_model or not us_model: + console.print( + "[red]Error: UK or US model not found. 
Run seed_*_model.py first.[/red]" + ) + return + + uk_version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == uk_model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + + us_version = session.exec( + select(TaxBenefitModelVersion) + .where(TaxBenefitModelVersion.model_id == us_model.id) + .order_by(TaxBenefitModelVersion.created_at.desc()) + ).first() + + # UK example policy: raise basic rate to 22p + uk_policy_name = "UK basic rate 22p" + existing_uk_policy = session.exec( + select(Policy).where(Policy.name == uk_policy_name) + ).first() + + if existing_uk_policy: + console.print(f" Policy '{uk_policy_name}' already exists, skipping") + else: + # Find the basic rate parameter + uk_basic_rate_param = session.exec( + select(Parameter).where( + Parameter.name == "gov.hmrc.income_tax.rates.uk[0].rate", + Parameter.tax_benefit_model_version_id == uk_version.id, + ) + ).first() + + if uk_basic_rate_param: + uk_policy = Policy( + name=uk_policy_name, + description="Raise the UK income tax basic rate from 20p to 22p", + ) + session.add(uk_policy) + session.commit() + session.refresh(uk_policy) + + # Add parameter value (22% = 0.22) + uk_param_value = ParameterValue( + parameter_id=uk_basic_rate_param.id, + value_json={"value": 0.22}, + start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_date=None, + policy_id=uk_policy.id, + ) + session.add(uk_param_value) + session.commit() + console.print(f" [green]āœ“[/green] Created UK policy: {uk_policy_name}") + else: + console.print( + " [yellow]Warning: UK basic rate parameter not found[/yellow]" + ) + + # US example policy: raise first bracket rate to 12% + us_policy_name = "US 12% lowest bracket" + existing_us_policy = session.exec( + select(Policy).where(Policy.name == us_policy_name) + ).first() + + if existing_us_policy: + console.print(f" Policy '{us_policy_name}' already exists, skipping") + else: + # Find the first bracket rate parameter + 
us_first_bracket_param = session.exec( + select(Parameter).where( + Parameter.name == "gov.irs.income.bracket.rates.1", + Parameter.tax_benefit_model_version_id == us_version.id, + ) + ).first() + + if us_first_bracket_param: + us_policy = Policy( + name=us_policy_name, + description="Raise US federal income tax lowest bracket to 12%", + ) + session.add(us_policy) + session.commit() + session.refresh(us_policy) + + # Add parameter value (12% = 0.12) + us_param_value = ParameterValue( + parameter_id=us_first_bracket_param.id, + value_json={"value": 0.12}, + start_date=datetime(2024, 1, 1, tzinfo=timezone.utc), + end_date=None, + policy_id=us_policy.id, + ) + session.add(us_param_value) + session.commit() + console.print(f" [green]āœ“[/green] Created US policy: {us_policy_name}") + else: + console.print( + " [yellow]Warning: US first bracket parameter not found[/yellow]" + ) + + console.print("[green]āœ“[/green] Example policies seeded") + + elapsed = time.time() - start + console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_uk_datasets.py b/scripts/seed_uk_datasets.py new file mode 100644 index 0000000..1754454 --- /dev/null +++ b/scripts/seed_uk_datasets.py @@ -0,0 +1,113 @@ +"""Seed UK datasets (FRS) and upload to S3. + +NOTE: Requires HUGGING_FACE_TOKEN environment variable to be set, +as UK FRS datasets are hosted on a private HuggingFace repository. 
+""" + +import argparse +import time +from pathlib import Path + +import logfire +from rich.progress import Progress, SpinnerColumn, TextColumn +from sqlmodel import select + +from seed_common import console, get_session + + +def main(): + parser = argparse.ArgumentParser(description="Seed UK datasets") + parser.add_argument( + "--lite", + action="store_true", + help="Lite mode: only seed FRS 2026", + ) + args = parser.parse_args() + + # Import here to avoid slow import at module level + from policyengine.tax_benefit_models.uk.datasets import ( + ensure_datasets as ensure_uk_datasets, + ) + + from policyengine_api.models import Dataset, TaxBenefitModel + from policyengine_api.services.storage import upload_dataset_for_seeding + + console.print("[bold green]Seeding UK datasets...[/bold green]\n") + console.print("[yellow]Note: Requires HUGGING_FACE_TOKEN environment variable[/yellow]\n") + + start = time.time() + with get_session() as session: + # Get UK model + uk_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") + ).first() + + if not uk_model: + console.print("[red]Error: UK model not found. 
Run seed_uk_model.py first.[/red]") + return + + data_folder = str(Path(__file__).parent.parent / "data") + console.print(f" Data folder: {data_folder}") + + # Get datasets + console.print(" Loading UK datasets from policyengine package...") + ds_start = time.time() + uk_datasets = ensure_uk_datasets(data_folder=data_folder) + console.print(f" Loaded {len(uk_datasets)} datasets in {time.time() - ds_start:.1f}s") + + # In lite mode, only upload FRS 2026 + if args.lite: + uk_datasets = { + k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k + } + console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") + + created = 0 + skipped = 0 + + with logfire.span("seed_uk_datasets", count=len(uk_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("UK datasets", total=len(uk_datasets)) + for name, pe_dataset in uk_datasets.items(): + progress.update(task, description=f"UK: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + upload_start = time.time() + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + console.print(f" Uploaded {pe_dataset.name} in {time.time() - upload_start:.1f}s") + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=uk_model.id, + ) + session.add(db_dataset) + session.commit() + created += 1 + progress.advance(task) + + console.print(f"[green]āœ“[/green] UK datasets: {created} created, {skipped} skipped") + + elapsed = time.time() - start + console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_uk_model.py 
b/scripts/seed_uk_model.py new file mode 100644 index 0000000..07543bf --- /dev/null +++ b/scripts/seed_uk_model.py @@ -0,0 +1,33 @@ +"""Seed UK model (variables, parameters, parameter values).""" + +import argparse +import time + +from seed_common import console, get_session, seed_model + + +def main(): + parser = argparse.ArgumentParser(description="Seed UK model") + parser.add_argument( + "--lite", + action="store_true", + help="Lite mode: skip state parameters", + ) + args = parser.parse_args() + + # Import here to avoid slow import at module level + from policyengine.tax_benefit_models.uk import uk_latest + + console.print("[bold green]Seeding UK model...[/bold green]\n") + + start = time.time() + with get_session() as session: + version = seed_model(uk_latest, session, lite=args.lite) + console.print(f"[green]āœ“[/green] UK model seeded: {version.id}") + + elapsed = time.time() - start + console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_us_datasets.py b/scripts/seed_us_datasets.py new file mode 100644 index 0000000..abf1995 --- /dev/null +++ b/scripts/seed_us_datasets.py @@ -0,0 +1,108 @@ +"""Seed US datasets (CPS) and upload to S3.""" + +import argparse +import time +from pathlib import Path + +import logfire +from rich.progress import Progress, SpinnerColumn, TextColumn +from sqlmodel import select + +from seed_common import console, get_session + + +def main(): + parser = argparse.ArgumentParser(description="Seed US datasets") + parser.add_argument( + "--lite", + action="store_true", + help="Lite mode: only seed CPS 2026", + ) + args = parser.parse_args() + + # Import here to avoid slow import at module level + from policyengine.tax_benefit_models.us.datasets import ( + ensure_datasets as ensure_us_datasets, + ) + + from policyengine_api.models import Dataset, TaxBenefitModel + from policyengine_api.services.storage import upload_dataset_for_seeding + + console.print("[bold 
green]Seeding US datasets...[/bold green]\n") + + start = time.time() + with get_session() as session: + # Get US model + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not us_model: + console.print("[red]Error: US model not found. Run seed_us_model.py first.[/red]") + return + + data_folder = str(Path(__file__).parent.parent / "data") + console.print(f" Data folder: {data_folder}") + + # Get datasets + console.print(" Loading US datasets from policyengine package...") + ds_start = time.time() + us_datasets = ensure_us_datasets(data_folder=data_folder) + console.print(f" Loaded {len(us_datasets)} datasets in {time.time() - ds_start:.1f}s") + + # In lite mode, only upload CPS 2026 + if args.lite: + us_datasets = { + k: v for k, v in us_datasets.items() if v.year == 2026 and "cps" in k + } + console.print(f" Lite mode: filtered to {len(us_datasets)} dataset(s)") + + created = 0 + skipped = 0 + + with logfire.span("seed_us_datasets", count=len(us_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("US datasets", total=len(us_datasets)) + for name, pe_dataset in us_datasets.items(): + progress.update(task, description=f"US: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + upload_start = time.time() + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + console.print(f" Uploaded {pe_dataset.name} in {time.time() - upload_start:.1f}s") + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=us_model.id, + ) + session.add(db_dataset) + session.commit() + 
created += 1 + progress.advance(task) + + console.print(f"[green]āœ“[/green] US datasets: {created} created, {skipped} skipped") + + elapsed = time.time() - start + console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_us_model.py b/scripts/seed_us_model.py new file mode 100644 index 0000000..ce8a829 --- /dev/null +++ b/scripts/seed_us_model.py @@ -0,0 +1,33 @@ +"""Seed US model (variables, parameters, parameter values).""" + +import argparse +import time + +from seed_common import console, get_session, seed_model + + +def main(): + parser = argparse.ArgumentParser(description="Seed US model") + parser.add_argument( + "--lite", + action="store_true", + help="Lite mode: skip state parameters", + ) + args = parser.parse_args() + + # Import here to avoid slow import at module level + from policyengine.tax_benefit_models.us import us_latest + + console.print("[bold green]Seeding US model...[/bold green]\n") + + start = time.time() + with get_session() as session: + version = seed_model(us_latest, session, lite=args.lite) + console.print(f"[green]āœ“[/green] US model seeded: {version.id}") + + elapsed = time.time() - start + console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index 1aa8119..14083cf 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -7,7 +7,8 @@ Function naming follows the API hierarchy: - simulate_household_*: Single household calculation (/simulate/household) - simulate_economy_*: Single economy simulation (/simulate/economy) -- economy_comparison_*: Full economy comparison analysis (/analysis/compare/economy) +- economy_comparison_*: Full economy comparison analysis (/analysis/economic-impact) +- household_impact_*: Household impact analysis (/analysis/household-impact) Deploy with: modal deploy 
src/policyengine_api/modal_app.py """ @@ -2516,3 +2517,689 @@ def compute_change_aggregate_us( raise finally: logfire.force_flush() + + +# ============================================================================= +# Household Impact Functions +# ============================================================================= + + +@app.function( + image=uk_image, + secrets=[db_secrets, logfire_secrets], + memory=2048, + cpu=2, + timeout=300, +) +def household_impact_uk(report_id: str, traceparent: str | None = None) -> None: + """Run UK household impact analysis and write results to database.""" + import logfire + + configure_logfire("policyengine-modal-uk", traceparent) + + try: + with logfire.span("household_impact_uk", report_id=report_id): + from datetime import datetime, timezone + from uuid import UUID + + from sqlmodel import Session, create_engine + + database_url = get_database_url() + engine = create_engine(database_url) + + try: + from policyengine_api.models import ( + Household, + Report, + ReportStatus, + Simulation, + SimulationStatus, + ) + + with Session(engine) as session: + # Load report + report = session.get(Report, UUID(report_id)) + if not report: + raise ValueError(f"Report {report_id} not found") + + # Mark as running + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Run baseline simulation + if report.baseline_simulation_id: + _run_household_simulation_uk( + report.baseline_simulation_id, session + ) + + # Run reform simulation if present + if report.reform_simulation_id: + _run_household_simulation_uk( + report.reform_simulation_id, session + ) + + # Mark report as completed + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + + except Exception as e: + logfire.error( + "UK household impact failed", report_id=report_id, error=str(e) + ) + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 
'FAILED', error_message = :error " + "WHERE id = :report_id" + ), + {"report_id": report_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", error=str(db_error)) + raise + finally: + logfire.force_flush() + + +def _run_household_simulation_uk(simulation_id, session) -> None: + """Run a single UK household simulation.""" + from datetime import datetime, timezone + + from policyengine_api.models import ( + Household, + Simulation, + SimulationStatus, + ) + + simulation = session.get(Simulation, simulation_id) + if not simulation or simulation.status != SimulationStatus.PENDING: + return + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError(f"Household {simulation.household_id} not found") + + # Mark as running + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + # Get policy data if present + policy_data = _get_household_policy_data(simulation.policy_id, session) + + # Run calculation + result = _calculate_uk_household( + household.household_data, + household.year, + policy_data, + ) + + # Store result + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + +def _calculate_uk_household( + household_data: dict, year: int, policy_data: dict | None +) -> dict: + """Calculate UK household and return result dict.""" + import tempfile + from pathlib import Path + + import pandas as pd + from microdf import MicroDataFrame + from policyengine.core import Simulation + from policyengine.tax_benefit_models.uk 
import uk_latest + from policyengine.tax_benefit_models.uk.datasets import ( + PolicyEngineUKDataset, + UKYearData, + ) + + people = household_data.get("people", []) + benunit = household_data.get("benunit", []) + hh = household_data.get("household", []) + + # Ensure lists + if isinstance(benunit, dict): + benunit = [benunit] + if isinstance(hh, dict): + hh = [hh] + + n_people = len(people) + n_benunits = max(1, len(benunit) if benunit else 1) + n_households = max(1, len(hh) if hh else 1) + + # Build person data + person_data = { + "person_id": list(range(n_people)), + "person_benunit_id": [0] * n_people, + "person_household_id": [0] * n_people, + "person_weight": [1.0] * n_people, + } + for i, person in enumerate(people): + for key, value in person.items(): + if key not in person_data: + person_data[key] = [0.0] * n_people + person_data[key][i] = value + + # Build benunit data + benunit_data = { + "benunit_id": list(range(n_benunits)), + "benunit_weight": [1.0] * n_benunits, + } + for i, bu in enumerate(benunit if benunit else [{}]): + for key, value in bu.items(): + if key not in benunit_data: + benunit_data[key] = [0.0] * n_benunits + benunit_data[key][i] = value + + # Build household data + household_df_data = { + "household_id": list(range(n_households)), + "household_weight": [1.0] * n_households, + "region": ["LONDON"] * n_households, + "tenure_type": ["RENT_PRIVATELY"] * n_households, + "council_tax": [0.0] * n_households, + "rent": [0.0] * n_households, + } + for i, h in enumerate(hh if hh else [{}]): + for key, value in h.items(): + if key not in household_df_data: + household_df_data[key] = [0.0] * n_households + household_df_data[key][i] = value + + # Create MicroDataFrames + person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") + benunit_df = MicroDataFrame(pd.DataFrame(benunit_data), weights="benunit_weight") + household_df = MicroDataFrame( + pd.DataFrame(household_df_data), weights="household_weight" + ) + + # Create 
temporary dataset + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "household_calc.h5") + + dataset = PolicyEngineUKDataset( + name="Household calculation", + description="Household(s) for calculation", + person=person_df, + benunit=benunit_df, + household=household_df, + filepath=filepath, + year_data_class=UKYearData, + ) + dataset.save() + + # Build policy if provided + policy = None + if policy_data: + from policyengine.core.policy import ParameterValue, Policy + + pe_param_values = [] + param_lookup = {p.name: p for p in uk_latest.parameters} + for pv in policy_data.get("parameter_values", []): + param_name = pv.get("parameter_name") + if param_name and param_name in param_lookup: + pe_pv = ParameterValue( + parameter=param_lookup[param_name], + value=pv.get("value"), + start_date=pv.get("start_date"), + end_date=pv.get("end_date"), + ) + pe_param_values.append(pe_pv) + + if pe_param_values: + policy = Policy( + name=policy_data.get("name", "Reform"), + description=policy_data.get("description", ""), + parameter_values=pe_param_values, + ) + + # Run simulation + sim = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + policy=policy, + ) + sim.ensure() + + # Extract results + result = {"person": [], "benunit": [], "household": []} + + for i in range(n_people): + person_result = {} + for var in sim.output_dataset.person.columns: + val = sim.output_dataset.person[var].iloc[i] + person_result[var] = float(val) if hasattr(val, "item") else val + result["person"].append(person_result) + + for i in range(n_benunits): + benunit_result = {} + for var in sim.output_dataset.benunit.columns: + val = sim.output_dataset.benunit[var].iloc[i] + benunit_result[var] = float(val) if hasattr(val, "item") else val + result["benunit"].append(benunit_result) + + for i in range(n_households): + household_result = {} + for var in sim.output_dataset.household.columns: + val = sim.output_dataset.household[var].iloc[i] + household_result[var] = 
float(val) if hasattr(val, "item") else val + result["household"].append(household_result) + + return result + + +@app.function( + image=us_image, + secrets=[db_secrets, logfire_secrets], + memory=2048, + cpu=2, + timeout=300, +) +def household_impact_us(report_id: str, traceparent: str | None = None) -> None: + """Run US household impact analysis and write results to database.""" + import logfire + + configure_logfire("policyengine-modal-us", traceparent) + + try: + with logfire.span("household_impact_us", report_id=report_id): + from datetime import datetime, timezone + from uuid import UUID + + from sqlmodel import Session, create_engine + + database_url = get_database_url() + engine = create_engine(database_url) + + try: + from policyengine_api.models import ( + Household, + Report, + ReportStatus, + Simulation, + SimulationStatus, + ) + + with Session(engine) as session: + # Load report + report = session.get(Report, UUID(report_id)) + if not report: + raise ValueError(f"Report {report_id} not found") + + # Mark as running + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + # Run baseline simulation + if report.baseline_simulation_id: + _run_household_simulation_us( + report.baseline_simulation_id, session + ) + + # Run reform simulation if present + if report.reform_simulation_id: + _run_household_simulation_us( + report.reform_simulation_id, session + ) + + # Mark report as completed + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + + except Exception as e: + logfire.error( + "US household impact failed", report_id=report_id, error=str(e) + ) + try: + from sqlmodel import text + + with Session(engine) as session: + session.execute( + text( + "UPDATE reports SET status = 'FAILED', error_message = :error " + "WHERE id = :report_id" + ), + {"report_id": report_id, "error": str(e)[:1000]}, + ) + session.commit() + except Exception as db_error: + logfire.error("Failed to update DB", 
error=str(db_error)) + raise + finally: + logfire.force_flush() + + +def _run_household_simulation_us(simulation_id, session) -> None: + """Run a single US household simulation.""" + from datetime import datetime, timezone + + from policyengine_api.models import ( + Household, + Simulation, + SimulationStatus, + ) + + simulation = session.get(Simulation, simulation_id) + if not simulation or simulation.status != SimulationStatus.PENDING: + return + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError(f"Household {simulation.household_id} not found") + + # Mark as running + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + # Get policy data if present + policy_data = _get_household_policy_data(simulation.policy_id, session) + + # Run calculation + result = _calculate_us_household( + household.household_data, + household.year, + policy_data, + ) + + # Store result + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + +def _calculate_us_household( + household_data: dict, year: int, policy_data: dict | None +) -> dict: + """Calculate US household and return result dict.""" + import tempfile + from pathlib import Path + + import pandas as pd + from microdf import MicroDataFrame + from policyengine.core import Simulation + from policyengine.tax_benefit_models.us import us_latest + from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + USYearData, + ) + + people = household_data.get("people", []) + tax_unit = household_data.get("tax_unit", []) + 
family = household_data.get("family", []) + spm_unit = household_data.get("spm_unit", []) + marital_unit = household_data.get("marital_unit", []) + hh = household_data.get("household", []) + + # Ensure lists + if isinstance(tax_unit, dict): + tax_unit = [tax_unit] + if isinstance(family, dict): + family = [family] + if isinstance(spm_unit, dict): + spm_unit = [spm_unit] + if isinstance(marital_unit, dict): + marital_unit = [marital_unit] + if isinstance(hh, dict): + hh = [hh] + + n_people = len(people) + n_tax_units = max(1, len(tax_unit) if tax_unit else 1) + n_families = max(1, len(family) if family else 1) + n_spm_units = max(1, len(spm_unit) if spm_unit else 1) + n_marital_units = max(1, len(marital_unit) if marital_unit else 1) + n_households = max(1, len(hh) if hh else 1) + + # Build person data + person_data = { + "person_id": list(range(n_people)), + "person_tax_unit_id": [0] * n_people, + "person_family_id": [0] * n_people, + "person_spm_unit_id": [0] * n_people, + "person_marital_unit_id": [0] * n_people, + "person_household_id": [0] * n_people, + "person_weight": [1.0] * n_people, + } + for i, person in enumerate(people): + for key, value in person.items(): + if key not in person_data: + person_data[key] = [0.0] * n_people + person_data[key][i] = value + + # Build tax_unit data + tax_unit_data = { + "tax_unit_id": list(range(n_tax_units)), + "tax_unit_weight": [1.0] * n_tax_units, + } + for i, tu in enumerate(tax_unit if tax_unit else [{}]): + for key, value in tu.items(): + if key not in tax_unit_data: + tax_unit_data[key] = [0.0] * n_tax_units + tax_unit_data[key][i] = value + + # Build family data + family_data = { + "family_id": list(range(n_families)), + "family_weight": [1.0] * n_families, + } + for i, fam in enumerate(family if family else [{}]): + for key, value in fam.items(): + if key not in family_data: + family_data[key] = [0.0] * n_families + family_data[key][i] = value + + # Build spm_unit data + spm_unit_data = { + "spm_unit_id": 
list(range(n_spm_units)), + "spm_unit_weight": [1.0] * n_spm_units, + } + for i, spm in enumerate(spm_unit if spm_unit else [{}]): + for key, value in spm.items(): + if key not in spm_unit_data: + spm_unit_data[key] = [0.0] * n_spm_units + spm_unit_data[key][i] = value + + # Build marital_unit data + marital_unit_data = { + "marital_unit_id": list(range(n_marital_units)), + "marital_unit_weight": [1.0] * n_marital_units, + } + for i, mu in enumerate(marital_unit if marital_unit else [{}]): + for key, value in mu.items(): + if key not in marital_unit_data: + marital_unit_data[key] = [0.0] * n_marital_units + marital_unit_data[key][i] = value + + # Build household data + household_df_data = { + "household_id": list(range(n_households)), + "household_weight": [1.0] * n_households, + } + for i, h in enumerate(hh if hh else [{}]): + for key, value in h.items(): + if key not in household_df_data: + household_df_data[key] = [0.0] * n_households + household_df_data[key][i] = value + + # Create MicroDataFrames + person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") + tax_unit_df = MicroDataFrame( + pd.DataFrame(tax_unit_data), weights="tax_unit_weight" + ) + family_df = MicroDataFrame(pd.DataFrame(family_data), weights="family_weight") + spm_unit_df = MicroDataFrame( + pd.DataFrame(spm_unit_data), weights="spm_unit_weight" + ) + marital_unit_df = MicroDataFrame( + pd.DataFrame(marital_unit_data), weights="marital_unit_weight" + ) + household_df = MicroDataFrame( + pd.DataFrame(household_df_data), weights="household_weight" + ) + + # Create temporary dataset + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "household_calc.h5") + + dataset = PolicyEngineUSDataset( + name="Household calculation", + description="Household(s) for calculation", + person=person_df, + tax_unit=tax_unit_df, + family=family_df, + spm_unit=spm_unit_df, + marital_unit=marital_unit_df, + household=household_df, + filepath=filepath, + year_data_class=USYearData, + ) 
+ dataset.save() + + # Build policy if provided + policy = None + if policy_data: + from policyengine.core.policy import ParameterValue, Policy + + pe_param_values = [] + param_lookup = {p.name: p for p in us_latest.parameters} + for pv in policy_data.get("parameter_values", []): + param_name = pv.get("parameter_name") + if param_name and param_name in param_lookup: + pe_pv = ParameterValue( + parameter=param_lookup[param_name], + value=pv.get("value"), + start_date=pv.get("start_date"), + end_date=pv.get("end_date"), + ) + pe_param_values.append(pe_pv) + + if pe_param_values: + policy = Policy( + name=policy_data.get("name", "Reform"), + description=policy_data.get("description", ""), + parameter_values=pe_param_values, + ) + + # Run simulation + sim = Simulation( + dataset=dataset, + tax_benefit_model_version=us_latest, + policy=policy, + ) + sim.ensure() + + # Extract results + result = { + "person": [], + "tax_unit": [], + "family": [], + "spm_unit": [], + "marital_unit": [], + "household": [], + } + + for i in range(n_people): + person_result = {} + for var in sim.output_dataset.person.columns: + val = sim.output_dataset.person[var].iloc[i] + person_result[var] = float(val) if hasattr(val, "item") else val + result["person"].append(person_result) + + for i in range(n_tax_units): + tu_result = {} + for var in sim.output_dataset.tax_unit.columns: + val = sim.output_dataset.tax_unit[var].iloc[i] + tu_result[var] = float(val) if hasattr(val, "item") else val + result["tax_unit"].append(tu_result) + + for i in range(n_families): + fam_result = {} + for var in sim.output_dataset.family.columns: + val = sim.output_dataset.family[var].iloc[i] + fam_result[var] = float(val) if hasattr(val, "item") else val + result["family"].append(fam_result) + + for i in range(n_spm_units): + spm_result = {} + for var in sim.output_dataset.spm_unit.columns: + val = sim.output_dataset.spm_unit[var].iloc[i] + spm_result[var] = float(val) if hasattr(val, "item") else val + 
result["spm_unit"].append(spm_result) + + for i in range(n_marital_units): + mu_result = {} + for var in sim.output_dataset.marital_unit.columns: + val = sim.output_dataset.marital_unit[var].iloc[i] + mu_result[var] = float(val) if hasattr(val, "item") else val + result["marital_unit"].append(mu_result) + + for i in range(n_households): + hh_result = {} + for var in sim.output_dataset.household.columns: + val = sim.output_dataset.household[var].iloc[i] + hh_result[var] = float(val) if hasattr(val, "item") else val + result["household"].append(hh_result) + + return result + + +def _get_household_policy_data(policy_id, session) -> dict | None: + """Get policy data for household calculation.""" + if policy_id is None: + return None + + from policyengine_api.models import Policy + + db_policy = session.get(Policy, policy_id) + if not db_policy: + return None + + return { + "name": db_policy.name, + "description": db_policy.description, + "parameter_values": [ + { + "parameter_name": pv.parameter.name if pv.parameter else None, + "value": pv.value_json.get("value") + if isinstance(pv.value_json, dict) + else pv.value_json, + "start_date": pv.start_date.isoformat() if pv.start_date else None, + "end_date": pv.end_date.isoformat() if pv.end_date else None, + } + for pv in db_policy.parameter_values + if pv.parameter + ], + } From cf8511f31f526e0e46c6baaf0ba841bcef8886c0 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Thu, 5 Feb 2026 17:28:24 +0300 Subject: [PATCH 09/19] test: Tests --- scripts/seed_common.py | 11 +- scripts/test_economy_simulation.py | 277 +++++++++++ scripts/test_household_scenarios.py | 344 +++++++++++++ src/policyengine_api/api/household.py | 388 +++++++++------ .../api/household_analysis.py | 33 +- src/policyengine_api/modal_app.py | 464 ++++++++++++++++-- test_fixtures/fixtures_policy_reform.py | 282 +++++++++++ tests/test_policy_reform.py | 327 ++++++++++++ 8 files changed, 1929 insertions(+), 197 deletions(-) create mode 100644 
scripts/test_economy_simulation.py create mode 100644 scripts/test_household_scenarios.py create mode 100644 test_fixtures/fixtures_policy_reform.py create mode 100644 tests/test_policy_reform.py diff --git a/scripts/seed_common.py b/scripts/seed_common.py index 4e4e2ec..f6d7ab6 100644 --- a/scripts/seed_common.py +++ b/scripts/seed_common.py @@ -189,8 +189,7 @@ def seed_model(model_version, session, lite: bool = False): f" [green]āœ“[/green] Added {len(model_version.variables)} variables" ) - # Add parameters (only user-facing ones: those with labels) - # Deduplicate by name - keep first occurrence + # Add parameters - deduplicate by name (keep first occurrence) # # WHY DEDUPLICATION IS NEEDED: # The policyengine package can provide multiple parameter entries with the same @@ -198,17 +197,16 @@ def seed_model(model_version, session, lite: bool = False): # state-specific variants that share the same base name. We keep only the first # occurrence to avoid database unique constraint violations and reduce redundancy. # + # NOTE: We do NOT filter by label. Parameters without labels (bracket params, + # breakdown params) are still valid and needed for policy analysis. 
+ # # In lite mode, exclude US state parameters (gov.states.*) seen_names = set() parameters_to_add = [] skipped_state_params = 0 - skipped_no_label = 0 skipped_duplicate = 0 for p in model_version.parameters: - if p.label is None: - skipped_no_label += 1 - continue if p.name in seen_names: skipped_duplicate += 1 continue @@ -221,7 +219,6 @@ def seed_model(model_version, session, lite: bool = False): console.print(f" Parameter filtering:") console.print(f" - Total from source: {len(model_version.parameters)}") - console.print(f" - Skipped (no label): {skipped_no_label}") console.print(f" - Skipped (duplicate name): {skipped_duplicate}") if lite and skipped_state_params > 0: console.print(f" - Skipped (state params, lite mode): {skipped_state_params}") diff --git a/scripts/test_economy_simulation.py b/scripts/test_economy_simulation.py new file mode 100644 index 0000000..3845fc4 --- /dev/null +++ b/scripts/test_economy_simulation.py @@ -0,0 +1,277 @@ +"""Test economy-wide simulation following the exact flow from modal_app.py. + +This script mimics the economy-wide simulation code path as closely as possible +to verify whether policy reforms are being applied correctly. +""" + +import tempfile +from datetime import datetime +from pathlib import Path + +import pandas as pd +from microdf import MicroDataFrame + +# Import exactly as modal_app.py does +from policyengine.core import Simulation as PESimulation +from policyengine.core.policy import ParameterValue as PEParameterValue +from policyengine.core.policy import Policy as PEPolicy +from policyengine.tax_benefit_models.us import us_latest +from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset, USYearData + + +def create_test_dataset(year: int) -> PolicyEngineUSDataset: + """Create a small test dataset similar to what would be loaded from storage. + + Uses the same structure as economy-wide datasets but with just a few households. 
+ """ + # Create 3 test households with different income levels + # Each household has 2 adults + 2 children (to test CTC) + n_households = 3 + n_people = n_households * 4 # 4 people per household + + # Person data + person_data = { + "person_id": list(range(n_people)), + "person_household_id": [i // 4 for i in range(n_people)], + "person_marital_unit_id": [], + "person_family_id": [i // 4 for i in range(n_people)], + "person_spm_unit_id": [i // 4 for i in range(n_people)], + "person_tax_unit_id": [i // 4 for i in range(n_people)], + "person_weight": [1000.0] * n_people, # Weight for population scaling + "age": [], + "employment_income": [], + } + + # Build person details + marital_unit_counter = 0 + for hh in range(n_households): + base_income = 10000 + (hh * 20000) # 10k, 30k, 50k + # Adult 1 + person_data["age"].append(35) + person_data["employment_income"].append(base_income) + person_data["person_marital_unit_id"].append(marital_unit_counter) + # Adult 2 + person_data["age"].append(33) + person_data["employment_income"].append(0) + person_data["person_marital_unit_id"].append(marital_unit_counter) + marital_unit_counter += 1 + # Child 1 + person_data["age"].append(5) + person_data["employment_income"].append(0) + person_data["person_marital_unit_id"].append(marital_unit_counter) + marital_unit_counter += 1 + # Child 2 + person_data["age"].append(3) + person_data["employment_income"].append(0) + person_data["person_marital_unit_id"].append(marital_unit_counter) + marital_unit_counter += 1 + + n_marital_units = marital_unit_counter + + # Entity data + household_data = { + "household_id": list(range(n_households)), + "household_weight": [1000.0] * n_households, + "state_fips": [48] * n_households, # Texas + } + + marital_unit_data = { + "marital_unit_id": list(range(n_marital_units)), + "marital_unit_weight": [1000.0] * n_marital_units, + } + + family_data = { + "family_id": list(range(n_households)), + "family_weight": [1000.0] * n_households, + } + + 
spm_unit_data = { + "spm_unit_id": list(range(n_households)), + "spm_unit_weight": [1000.0] * n_households, + } + + tax_unit_data = { + "tax_unit_id": list(range(n_households)), + "tax_unit_weight": [1000.0] * n_households, + } + + # Create MicroDataFrames (same as economy datasets) + person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") + household_df = MicroDataFrame(pd.DataFrame(household_data), weights="household_weight") + marital_unit_df = MicroDataFrame(pd.DataFrame(marital_unit_data), weights="marital_unit_weight") + family_df = MicroDataFrame(pd.DataFrame(family_data), weights="family_weight") + spm_unit_df = MicroDataFrame(pd.DataFrame(spm_unit_data), weights="spm_unit_weight") + tax_unit_df = MicroDataFrame(pd.DataFrame(tax_unit_data), weights="tax_unit_weight") + + # Create dataset file + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "test_economy.h5") + + return PolicyEngineUSDataset( + name="Test Economy Dataset", + description="Small test dataset for economy simulation", + filepath=filepath, + year=year, + data=USYearData( + person=person_df, + household=household_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + tax_unit=tax_unit_df, + ), + ) + + +def create_policy_like_modal_app(model_version) -> PEPolicy: + """Create a policy exactly like _get_pe_policy_us does in modal_app.py. + + This mimics the exact flow: + 1. Look up parameter by name from model_version.parameters + 2. Create PEParameterValue with the parameter, value, start_date, end_date + 3. 
Create PEPolicy with the parameter values + """ + param_lookup = {p.name: p for p in model_version.parameters} + + # This is exactly what _get_pe_policy_us does + pe_param = param_lookup.get("gov.irs.credits.ctc.refundable.fully_refundable") + if not pe_param: + raise ValueError("Parameter not found!") + + pe_pv = PEParameterValue( + parameter=pe_param, + value=True, # Make CTC fully refundable + start_date=datetime(2024, 1, 1), + end_date=None, + ) + + return PEPolicy( + name="CTC Fully Refundable", + description="Makes CTC fully refundable", + parameter_values=[pe_pv], + ) + + +def run_economy_simulation(dataset: PolicyEngineUSDataset, policy: PEPolicy | None, label: str) -> dict: + """Run an economy simulation exactly like modal_app.py does. + + This follows the exact flow from simulate_economy_us: + 1. Create PESimulation with dataset, model version, policy, dynamic + 2. Call pe_sim.ensure() (which calls run() internally) + 3. Access output via pe_sim.output_dataset + """ + print(f"\n=== {label} ===") + print(f" Policy: {policy.name if policy else 'None (baseline)'}") + if policy: + print(f" Policy parameter_values: {len(policy.parameter_values)}") + for pv in policy.parameter_values: + print(f" - {pv.parameter.name}: {pv.value} (start: {pv.start_date})") + + pe_model_version = us_latest + + # Create and run simulation - EXACTLY like modal_app.py lines 1006-1012 + pe_sim = PESimulation( + dataset=dataset, + tax_benefit_model_version=pe_model_version, + policy=policy, + dynamic=None, + ) + pe_sim.ensure() + + # Extract results from output dataset + output_data = pe_sim.output_dataset.data + + # Sum up key metrics across all tax units (weighted) + tax_unit_df = pd.DataFrame(output_data.tax_unit) + + # Get the variables we care about + total_ctc = 0 + total_income_tax = 0 + total_eitc = 0 + + for var in ["ctc", "income_tax", "eitc"]: + if var in tax_unit_df.columns: + # Weighted sum + weights = tax_unit_df.get("tax_unit_weight", pd.Series([1.0] * 
len(tax_unit_df))) + if var == "ctc": + total_ctc = (tax_unit_df[var] * weights).sum() + elif var == "income_tax": + total_income_tax = (tax_unit_df[var] * weights).sum() + elif var == "eitc": + total_eitc = (tax_unit_df[var] * weights).sum() + + print(f" Results (weighted totals across {len(tax_unit_df)} tax units):") + print(f" Total CTC: ${total_ctc:,.0f}") + print(f" Total Income Tax: ${total_income_tax:,.0f}") + print(f" Total EITC: ${total_eitc:,.0f}") + + # Also show per-household breakdown + print(f" Per tax unit breakdown:") + for i in range(len(tax_unit_df)): + ctc = tax_unit_df["ctc"].iloc[i] if "ctc" in tax_unit_df.columns else 0 + income_tax = tax_unit_df["income_tax"].iloc[i] if "income_tax" in tax_unit_df.columns else 0 + print(f" Tax Unit {i}: CTC=${ctc:,.0f}, Income Tax=${income_tax:,.0f}") + + return { + "total_ctc": total_ctc, + "total_income_tax": total_income_tax, + "total_eitc": total_eitc, + "tax_unit_df": tax_unit_df, + } + + +def main(): + print("=" * 60) + print("ECONOMY-WIDE SIMULATION TEST") + print("Following the exact code path from modal_app.py") + print("=" * 60) + + year = 2024 + + # Create test dataset (same for both simulations) + print("\nCreating test dataset...") + + # Run baseline simulation + baseline_dataset = create_test_dataset(year) + baseline_results = run_economy_simulation(baseline_dataset, None, "BASELINE (no policy)") + + # Create policy exactly like modal_app.py does + policy = create_policy_like_modal_app(us_latest) + + # Run reform simulation + reform_dataset = create_test_dataset(year) + reform_results = run_economy_simulation(reform_dataset, policy, "REFORM (CTC fully refundable)") + + # Compare results + print("\n" + "=" * 60) + print("COMPARISON") + print("=" * 60) + + ctc_diff = reform_results["total_ctc"] - baseline_results["total_ctc"] + tax_diff = reform_results["total_income_tax"] - baseline_results["total_income_tax"] + + print(f"\nTotal CTC:") + print(f" Baseline: ${baseline_results['total_ctc']:,.0f}") 
+ print(f" Reform: ${reform_results['total_ctc']:,.0f}") + print(f" Change: ${ctc_diff:,.0f}") + + print(f"\nTotal Income Tax:") + print(f" Baseline: ${baseline_results['total_income_tax']:,.0f}") + print(f" Reform: ${reform_results['total_income_tax']:,.0f}") + print(f" Change: ${tax_diff:,.0f}") + + # Verdict + print("\n" + "=" * 60) + print("VERDICT") + print("=" * 60) + + if baseline_results["total_income_tax"] == reform_results["total_income_tax"]: + print("\nāŒ BUG CONFIRMED: Results are IDENTICAL!") + print(" The policy reform is NOT being applied to economy simulations.") + else: + print("\nāœ“ NO BUG: Results differ as expected!") + print(f" The fully refundable CTC reform changed income tax by ${tax_diff:,.0f}") + + +if __name__ == "__main__": + main() diff --git a/scripts/test_household_scenarios.py b/scripts/test_household_scenarios.py new file mode 100644 index 0000000..fb418a4 --- /dev/null +++ b/scripts/test_household_scenarios.py @@ -0,0 +1,344 @@ +"""Test household calculation scenarios. + +Tests: +1. US California household under current law +2. Scotland household under current law +3. 
US household: current law vs CTC fully refundable reform +""" + +import sys +import time +import requests + +BASE_URL = "http://127.0.0.1:8000" + + +def poll_for_completion(report_id: str, max_attempts: int = 60) -> dict: + """Poll until report is completed or failed.""" + for attempt in range(max_attempts): + resp = requests.get(f"{BASE_URL}/analysis/household-impact/{report_id}") + if resp.status_code != 200: + raise Exception(f"Failed to get report: {resp.status_code} - {resp.text}") + + result = resp.json() + status = result["status"].upper() + + if status == "COMPLETED": + return result + elif status == "FAILED": + raise Exception(f"Report failed: {result.get('error_message', 'Unknown error')}") + + time.sleep(0.5) + + raise Exception(f"Timed out after {max_attempts} attempts") + + +def print_household_summary(result: dict, label: str): + """Print summary of household calculation result.""" + print(f"\n {label}:") + + baseline = result.get("baseline_result", {}) + reform = result.get("reform_result", {}) + + # Get key metrics from person/household + if "person" in baseline and baseline["person"]: + person = baseline["person"][0] + if "household_net_income" in person: + baseline_income = person["household_net_income"] + print(f" Baseline net income: ${baseline_income:,.2f}") + + if reform and "person" in reform and reform["person"]: + reform_income = reform["person"][0].get("household_net_income", 0) + print(f" Reform net income: ${reform_income:,.2f}") + print(f" Difference: ${reform_income - baseline_income:,.2f}") + + # Show some tax/benefit info if available + for key in ["income_tax", "federal_income_tax", "state_income_tax", "ctc", "refundable_ctc"]: + if key in person: + print(f" {key}: ${person[key]:,.2f}") + + +def test_us_california(): + """Test 1: US California household under current law.""" + print("\n" + "=" * 60) + print("TEST 1: US California Household - Current Law") + print("=" * 60) + + # Create California household + household_data = { + 
"tax_benefit_model_name": "policyengine_us", + "year": 2024, + "label": "California test household", + "people": [ + {"age": 35, "employment_income": 75000}, + {"age": 33, "employment_income": 45000}, + {"age": 8}, # Child + ], + "tax_unit": {}, + "family": {}, + "spm_unit": {}, + "marital_unit": {}, + "household": {"state_code": "CA"}, + } + + print("\n Creating household...") + resp = requests.post(f"{BASE_URL}/households/", json=household_data) + if resp.status_code != 201: + print(f" FAILED: {resp.status_code} - {resp.text}") + return None + + household = resp.json() + household_id = household["id"] + print(f" Household ID: {household_id}") + + # Run analysis under current law (no policy_id) + print(" Running analysis...") + resp = requests.post(f"{BASE_URL}/analysis/household-impact", json={ + "household_id": household_id, + "policy_id": None, + }) + + if resp.status_code != 200: + print(f" FAILED: {resp.status_code} - {resp.text}") + return household_id + + report_id = resp.json()["report_id"] + print(f" Report ID: {report_id}") + + # Poll for results + try: + result = poll_for_completion(report_id) + print(" Status: COMPLETED") + print_household_summary(result, "Results") + except Exception as e: + print(f" FAILED: {e}") + + return household_id + + +def test_scotland(): + """Test 2: Scotland household under current law.""" + print("\n" + "=" * 60) + print("TEST 2: Scotland Household - Current Law") + print("=" * 60) + + # Create Scotland household + household_data = { + "tax_benefit_model_name": "policyengine_uk", + "year": 2024, + "label": "Scotland test household", + "people": [ + {"age": 40, "employment_income": 45000}, + ], + "benunit": {}, + "household": {"region": "SCOTLAND"}, + } + + print("\n Creating household...") + resp = requests.post(f"{BASE_URL}/households/", json=household_data) + if resp.status_code != 201: + print(f" FAILED: {resp.status_code} - {resp.text}") + return None + + household = resp.json() + household_id = household["id"] + 
print(f" Household ID: {household_id}") + + # Run analysis under current law + print(" Running analysis...") + resp = requests.post(f"{BASE_URL}/analysis/household-impact", json={ + "household_id": household_id, + "policy_id": None, + }) + + if resp.status_code != 200: + print(f" FAILED: {resp.status_code} - {resp.text}") + return household_id + + report_id = resp.json()["report_id"] + print(f" Report ID: {report_id}") + + # Poll for results + try: + result = poll_for_completion(report_id) + print(" Status: COMPLETED") + print_household_summary(result, "Results") + except Exception as e: + print(f" FAILED: {e}") + + return household_id + + +def test_us_ctc_reform(): + """Test 3: US household - current law vs CTC fully refundable.""" + print("\n" + "=" * 60) + print("TEST 3: US Household - Current Law vs CTC Fully Refundable") + print("=" * 60) + + # First, find the CTC refundability parameter + print("\n Finding CTC refundability parameter...") + resp = requests.get(f"{BASE_URL}/parameters", params={"search": "ctc", "limit": 50}) + if resp.status_code != 200: + print(f" FAILED to search parameters: {resp.status_code}") + return None, None + + params = resp.json() + ctc_param = None + for p in params: + # Look for the refundable portion parameter + if "refundable" in p["name"].lower() and "ctc" in p["name"].lower(): + print(f" Found: {p['name']} (label: {p.get('label')})") + ctc_param = p + break + + if not ctc_param: + # Try searching for child tax credit parameters + print(" Searching for child_tax_credit parameters...") + resp = requests.get(f"{BASE_URL}/parameters", params={"search": "child_tax_credit", "limit": 50}) + params = resp.json() + for p in params: + print(f" - {p['name']}") + if "refundable" in p["name"].lower(): + ctc_param = p + break + + if not ctc_param: + print(" Could not find CTC refundability parameter") + print(" Continuing with household creation anyway...") + + # Create household with children (needed for CTC) + household_data = { + 
"tax_benefit_model_name": "policyengine_us", + "year": 2024, + "label": "CTC test household", + "people": [ + {"age": 35, "employment_income": 30000}, # Lower income to see CTC effect + {"age": 33, "employment_income": 0}, + {"age": 5}, # Child 1 + {"age": 3}, # Child 2 + ], + "tax_unit": {}, + "family": {}, + "spm_unit": {}, + "marital_unit": {}, + "household": {"state_code": "TX"}, # Texas - no state income tax + } + + print("\n Creating household...") + resp = requests.post(f"{BASE_URL}/households/", json=household_data) + if resp.status_code != 201: + print(f" FAILED: {resp.status_code} - {resp.text}") + return None, None + + household = resp.json() + household_id = household["id"] + print(f" Household ID: {household_id}") + + # Create a policy that makes CTC fully refundable + policy_id = None + if ctc_param: + print("\n Creating CTC fully refundable policy...") + policy_data = { + "name": "CTC Fully Refundable", + "description": "Makes the Child Tax Credit fully refundable", + } + resp = requests.post(f"{BASE_URL}/policies/", json=policy_data) + if resp.status_code == 201: + policy = resp.json() + policy_id = policy["id"] + print(f" Policy ID: {policy_id}") + + # Add parameter value to make CTC fully refundable + # The parameter should set refundable portion to 100% or max amount + pv_data = { + "parameter_id": ctc_param["id"], + "value_json": 1.0, # 100% refundable + "start_date": "2024-01-01T00:00:00Z", + "end_date": None, + "policy_id": policy_id, + } + resp = requests.post(f"{BASE_URL}/parameter-values/", json=pv_data) + if resp.status_code == 201: + print(" Added parameter value for full refundability") + else: + print(f" Warning: Failed to add parameter value: {resp.status_code} - {resp.text}") + else: + print(f" Warning: Failed to create policy: {resp.status_code}") + + # Run analysis with reform policy + print("\n Running analysis (baseline vs reform)...") + resp = requests.post(f"{BASE_URL}/analysis/household-impact", json={ + "household_id": 
household_id, + "policy_id": policy_id, + }) + + if resp.status_code != 200: + print(f" FAILED: {resp.status_code} - {resp.text}") + return household_id, policy_id + + report_id = resp.json()["report_id"] + print(f" Report ID: {report_id}") + + # Poll for results + try: + result = poll_for_completion(report_id) + print(" Status: COMPLETED") + print_household_summary(result, "Results") + except Exception as e: + print(f" FAILED: {e}") + + return household_id, policy_id + + +def main(): + print("=" * 60) + print("HOUSEHOLD CALCULATION SCENARIO TESTS") + print("=" * 60) + + # Track created resources for cleanup + households = [] + policies = [] + + # Test 1: US California + hh_id = test_us_california() + if hh_id: + households.append(hh_id) + + # Test 2: Scotland + hh_id = test_scotland() + if hh_id: + households.append(hh_id) + + # Test 3: CTC Reform + hh_id, policy_id = test_us_ctc_reform() + if hh_id: + households.append(hh_id) + if policy_id: + policies.append(policy_id) + + # Cleanup + print("\n" + "=" * 60) + print("CLEANUP") + print("=" * 60) + + for hh_id in households: + resp = requests.delete(f"{BASE_URL}/households/{hh_id}") + if resp.status_code == 204: + print(f" Deleted household: {hh_id}") + else: + print(f" Warning: Failed to delete household {hh_id}: {resp.status_code}") + + for policy_id in policies: + resp = requests.delete(f"{BASE_URL}/policies/{policy_id}") + if resp.status_code == 204: + print(f" Deleted policy: {policy_id}") + else: + print(f" Warning: Failed to delete policy {policy_id}: {resp.status_code}") + + print("\n" + "=" * 60) + print("TESTS COMPLETE") + print("=" * 60) + + +if __name__ == "__main__": + main() diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index 0e89b5e..adb6ac9 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -294,17 +294,16 @@ def _calculate_household_uk( Supports multiple households via entity relational dataframes. 
If entity IDs are not provided, defaults to single household with all people in it. - """ - import tempfile - from datetime import datetime - from pathlib import Path + Uses policyengine-uk Microsimulation directly with reform dict to ensure + policy changes are applied correctly. + """ + import numpy as np import pandas as pd - from policyengine.core import Simulation - from microdf import MicroDataFrame from policyengine.tax_benefit_models.uk import uk_latest - from policyengine.tax_benefit_models.uk.datasets import PolicyEngineUKDataset - from policyengine.tax_benefit_models.uk.datasets import UKYearData + from policyengine_core.simulations.simulation_builder import SimulationBuilder + from policyengine_uk import Microsimulation + from policyengine_uk.system import system n_people = len(people) n_benunits = max(1, len(benunit)) @@ -350,68 +349,88 @@ def _calculate_household_uk( household_data[key] = [0.0] * n_households household_data[key][i] = value - # Create MicroDataFrames - person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") - benunit_df = MicroDataFrame(pd.DataFrame(benunit_data), weights="benunit_weight") - household_df = MicroDataFrame( - pd.DataFrame(household_data), weights="household_weight" + # Convert policy_data to policyengine-uk reform dict format + # Format: {"param.name": {"YYYY-MM-DD": value}} + reform = None + if policy_data and policy_data.get("parameter_values"): + reform = {} + for pv in policy_data["parameter_values"]: + param_name = pv.get("parameter_name") + value = pv.get("value") + start_date = pv.get("start_date") + + if param_name and start_date: + # Parse ISO date string to get just the date part + if "T" in start_date: + date_str = start_date.split("T")[0] + else: + date_str = start_date + + if param_name not in reform: + reform[param_name] = {} + reform[param_name][date_str] = value + + # Create Microsimulation with reform applied at construction time + microsim = Microsimulation(reform=reform) + + # 
Build simulation from entity data using SimulationBuilder + person_df = pd.DataFrame(person_data) + + # Determine column naming convention + benunit_id_col = ( + "person_benunit_id" + if "person_benunit_id" in person_df.columns + else "benunit_id" ) - - # Create temporary dataset - tmpdir = tempfile.mkdtemp() - filepath = str(Path(tmpdir) / "household_calc.h5") - - dataset = PolicyEngineUKDataset( - name="Household calculation", - description="Household(s) for calculation", - filepath=filepath, - year=year, - data=UKYearData( - person=person_df, - benunit=benunit_df, - household=household_df, - ), + household_id_col = ( + "person_household_id" + if "person_household_id" in person_df.columns + else "household_id" ) - # Build policy if provided - policy = None - if policy_data: - from policyengine.core.policy import ParameterValue as PEParameterValue - from policyengine.core.policy import Policy as PEPolicy - - pe_param_values = [] - param_lookup = {p.name: p for p in uk_latest.parameters} - for pv in policy_data.get("parameter_values", []): - pe_param = param_lookup.get(pv["parameter_name"]) - if pe_param: - pe_pv = PEParameterValue( - parameter=pe_param, - value=pv["value"], - start_date=datetime.fromisoformat(pv["start_date"]) - if pv.get("start_date") - else None, - end_date=datetime.fromisoformat(pv["end_date"]) - if pv.get("end_date") - else None, - ) - pe_param_values.append(pe_pv) - policy = PEPolicy( - name=policy_data.get("name", ""), - description=policy_data.get("description", ""), - parameter_values=pe_param_values, - ) + # Declare entities using SimulationBuilder + builder = SimulationBuilder() + builder.populations = system.instantiate_entities() + + builder.declare_person_entity("person", person_df["person_id"].values) + builder.declare_entity("benunit", np.unique(person_df[benunit_id_col].values)) + builder.declare_entity("household", np.unique(person_df[household_id_col].values)) - # Run simulation - simulation = Simulation( - dataset=dataset, - 
tax_benefit_model_version=uk_latest, - policy=policy, + # Join persons to group entities + builder.join_with_persons( + builder.populations["benunit"], + person_df[benunit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["household"], + person_df[household_id_col].values, + np.array(["member"] * len(person_df)), ) - simulation.run() - # Extract outputs - output_data = simulation.output_dataset.data + # Build simulation from populations + microsim.build_from_populations(builder.populations) + # Set input variables for each entity + id_columns = { + "person_id", + "benunit_id", + "person_benunit_id", + "household_id", + "person_household_id", + } + + for entity_name, entity_df in [ + ("person", person_data), + ("benunit", benunit_data), + ("household", household_data), + ]: + df = pd.DataFrame(entity_df) + for column in df.columns: + if column not in id_columns and column in system.variables: + microsim.set_input(column, year, df[column].values) + + # Calculate output variables def safe_convert(value): try: return float(value) @@ -422,21 +441,24 @@ def safe_convert(value): for i in range(n_people): person_dict = {} for var in uk_latest.entity_variables["person"]: - person_dict[var] = safe_convert(output_data.person[var].iloc[i]) + val = microsim.calculate(var, period=year, map_to="person") + person_dict[var] = safe_convert(val.values[i]) person_outputs.append(person_dict) benunit_outputs = [] - for i in range(len(output_data.benunit)): + for i in range(n_benunits): benunit_dict = {} for var in uk_latest.entity_variables["benunit"]: - benunit_dict[var] = safe_convert(output_data.benunit[var].iloc[i]) + val = microsim.calculate(var, period=year, map_to="benunit") + benunit_dict[var] = safe_convert(val.values[i]) benunit_outputs.append(benunit_dict) household_outputs = [] - for i in range(len(output_data.household)): + for i in range(n_households): household_dict = {} for var in 
uk_latest.entity_variables["household"]: - household_dict[var] = safe_convert(output_data.household[var].iloc[i]) + val = microsim.calculate(var, period=year, map_to="household") + household_dict[var] = safe_convert(val.values[i]) household_outputs.append(household_dict) return { @@ -466,7 +488,14 @@ def _run_local_household_us( try: result = _calculate_household_us( - people, marital_unit, family, spm_unit, tax_unit, household, year, policy_data + people, + marital_unit, + family, + spm_unit, + tax_unit, + household, + year, + policy_data, ) # Update job with result @@ -506,17 +535,16 @@ def _calculate_household_us( Supports multiple households via entity relational dataframes. If entity IDs are not provided, defaults to single household with all people in it. - """ - import tempfile - from datetime import datetime - from pathlib import Path + Uses policyengine-us Microsimulation directly with reform dict to ensure + policy changes are applied correctly. + """ + import numpy as np import pandas as pd - from policyengine.core import Simulation - from microdf import MicroDataFrame from policyengine.tax_benefit_models.us import us_latest - from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset - from policyengine.tax_benefit_models.us.datasets import USYearData + from policyengine_core.simulations.simulation_builder import SimulationBuilder + from policyengine_us import Microsimulation + from policyengine_us.system import system n_people = len(people) n_households = max(1, len(household)) @@ -596,108 +624,158 @@ def _calculate_household_us( tax_unit_data[key] = [0.0] * n_tax_units tax_unit_data[key][i] = value - # Create MicroDataFrames - person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") - household_df = MicroDataFrame( - pd.DataFrame(household_data), weights="household_weight" + # Convert policy_data to policyengine-us reform dict format + # Format: {"param.name": {"YYYY-MM-DD": value}} + reform = None + if 
policy_data and policy_data.get("parameter_values"): + reform = {} + for pv in policy_data["parameter_values"]: + param_name = pv.get("parameter_name") + value = pv.get("value") + start_date = pv.get("start_date") + + if param_name and start_date: + # Parse ISO date string to get just the date part + if "T" in start_date: + date_str = start_date.split("T")[0] + else: + date_str = start_date + + if param_name not in reform: + reform[param_name] = {} + reform[param_name][date_str] = value + + # Create Microsimulation with reform applied at construction time + # This ensures the reform is properly integrated into the tax benefit system + microsim = Microsimulation(reform=reform) + + # Build simulation from entity data using SimulationBuilder + person_df = pd.DataFrame(person_data) + + # Determine column naming convention + household_id_col = ( + "person_household_id" + if "person_household_id" in person_df.columns + else "household_id" ) - marital_unit_df = MicroDataFrame( - pd.DataFrame(marital_unit_data), weights="marital_unit_weight" + marital_unit_id_col = ( + "person_marital_unit_id" + if "person_marital_unit_id" in person_df.columns + else "marital_unit_id" ) - family_df = MicroDataFrame(pd.DataFrame(family_data), weights="family_weight") - spm_unit_df = MicroDataFrame(pd.DataFrame(spm_unit_data), weights="spm_unit_weight") - tax_unit_df = MicroDataFrame(pd.DataFrame(tax_unit_data), weights="tax_unit_weight") - - # Create temporary dataset - tmpdir = tempfile.mkdtemp() - filepath = str(Path(tmpdir) / "household_calc.h5") - - dataset = PolicyEngineUSDataset( - name="Household calculation", - description="Household(s) for calculation", - filepath=filepath, - year=year, - data=USYearData( - person=person_df, - household=household_df, - marital_unit=marital_unit_df, - family=family_df, - spm_unit=spm_unit_df, - tax_unit=tax_unit_df, - ), + family_id_col = ( + "person_family_id" if "person_family_id" in person_df.columns else "family_id" + ) + spm_unit_id_col = ( + 
"person_spm_unit_id" + if "person_spm_unit_id" in person_df.columns + else "spm_unit_id" + ) + tax_unit_id_col = ( + "person_tax_unit_id" + if "person_tax_unit_id" in person_df.columns + else "tax_unit_id" ) - # Build policy if provided - policy = None - if policy_data: - from policyengine.core.policy import ParameterValue as PEParameterValue - from policyengine.core.policy import Policy as PEPolicy - - pe_param_values = [] - param_lookup = {p.name: p for p in us_latest.parameters} - for pv in policy_data.get("parameter_values", []): - pe_param = param_lookup.get(pv["parameter_name"]) - if pe_param: - pe_pv = PEParameterValue( - parameter=pe_param, - value=pv["value"], - start_date=datetime.fromisoformat(pv["start_date"]) - if pv.get("start_date") - else None, - end_date=datetime.fromisoformat(pv["end_date"]) - if pv.get("end_date") - else None, - ) - pe_param_values.append(pe_pv) - policy = PEPolicy( - name=policy_data.get("name", ""), - description=policy_data.get("description", ""), - parameter_values=pe_param_values, - ) + # Declare entities using SimulationBuilder + builder = SimulationBuilder() + builder.populations = system.instantiate_entities() + + builder.declare_person_entity("person", person_df["person_id"].values) + builder.declare_entity("household", np.unique(person_df[household_id_col].values)) + builder.declare_entity("spm_unit", np.unique(person_df[spm_unit_id_col].values)) + builder.declare_entity("family", np.unique(person_df[family_id_col].values)) + builder.declare_entity("tax_unit", np.unique(person_df[tax_unit_id_col].values)) + builder.declare_entity( + "marital_unit", np.unique(person_df[marital_unit_id_col].values) + ) - # Run simulation - simulation = Simulation( - dataset=dataset, - tax_benefit_model_version=us_latest, - policy=policy, + # Join persons to group entities + builder.join_with_persons( + builder.populations["household"], + person_df[household_id_col].values, + np.array(["member"] * len(person_df)), + ) + 
builder.join_with_persons( + builder.populations["spm_unit"], + person_df[spm_unit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["family"], + person_df[family_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["tax_unit"], + person_df[tax_unit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["marital_unit"], + person_df[marital_unit_id_col].values, + np.array(["member"] * len(person_df)), ) - simulation.run() - # Extract outputs - output_data = simulation.output_dataset.data + # Build simulation from populations + microsim.build_from_populations(builder.populations) + + # Set input variables for each entity + id_columns = { + "person_id", + "household_id", + "person_household_id", + "spm_unit_id", + "person_spm_unit_id", + "family_id", + "person_family_id", + "tax_unit_id", + "person_tax_unit_id", + "marital_unit_id", + "person_marital_unit_id", + } + for entity_name, entity_df in [ + ("person", person_data), + ("household", household_data), + ("spm_unit", spm_unit_data), + ("family", family_data), + ("tax_unit", tax_unit_data), + ("marital_unit", marital_unit_data), + ]: + df = pd.DataFrame(entity_df) + for column in df.columns: + if column not in id_columns and column in system.variables: + microsim.set_input(column, year, df[column].values) + + # Calculate output variables def safe_convert(value): try: return float(value) except (ValueError, TypeError): return str(value) - def extract_entity_outputs(entity_name: str, entity_data, n_rows: int) -> list[dict]: + def extract_entity_outputs( + entity_name: str, n_rows: int, map_to: str + ) -> list[dict]: outputs = [] for i in range(n_rows): row_dict = {} for var in us_latest.entity_variables[entity_name]: - row_dict[var] = safe_convert(entity_data[var].iloc[i]) + val = microsim.calculate(var, period=year, map_to=map_to) + 
row_dict[var] = safe_convert(val.values[i]) outputs.append(row_dict) return outputs return { - "person": extract_entity_outputs("person", output_data.person, n_people), + "person": extract_entity_outputs("person", n_people, "person"), "marital_unit": extract_entity_outputs( - "marital_unit", output_data.marital_unit, len(output_data.marital_unit) - ), - "family": extract_entity_outputs( - "family", output_data.family, len(output_data.family) - ), - "spm_unit": extract_entity_outputs( - "spm_unit", output_data.spm_unit, len(output_data.spm_unit) - ), - "tax_unit": extract_entity_outputs( - "tax_unit", output_data.tax_unit, len(output_data.tax_unit) - ), - "household": extract_entity_outputs( - "household", output_data.household, len(output_data.household) + "marital_unit", n_marital_units, "marital_unit" ), + "family": extract_entity_outputs("family", n_families, "family"), + "spm_unit": extract_entity_outputs("spm_unit", n_spm_units, "spm_unit"), + "tax_unit": extract_entity_outputs("tax_unit", n_tax_units, "tax_unit"), + "household": extract_entity_outputs("household", n_households, "household"), } diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py index 981d968..5f36fda 100644 --- a/src/policyengine_api/api/household_analysis.py +++ b/src/policyengine_api/api/household_analysis.py @@ -158,22 +158,45 @@ def _ensure_list(value: Any) -> list: def _extract_policy_data(policy: Policy | None) -> dict | None: - """Extract policy data from a Policy model into calculation format.""" + """Extract policy data from a Policy model into calculation format. 
+ + Returns format expected by _calculate_household_us/_calculate_household_uk: + { + "name": "policy name", + "description": "policy description", + "parameter_values": [ + { + "parameter_name": "gov.irs.credits.ctc...", + "value": 0.16, + "start_date": "2024-01-01T00:00:00+00:00", + "end_date": null + } + ] + } + """ if not policy or not policy.parameter_values: return None - policy_data = {} + parameter_values = [] for pv in policy.parameter_values: if not pv.parameter: continue - policy_data[pv.parameter.name] = { + parameter_values.append({ + "parameter_name": pv.parameter.name, "value": _extract_value(pv.value_json), "start_date": _format_date(pv.start_date), "end_date": _format_date(pv.end_date), - } + }) + + if not parameter_values: + return None - return policy_data if policy_data else None + return { + "name": policy.name, + "description": policy.description or "", + "parameter_values": parameter_values, + } def _extract_value(value_json: Any) -> Any: diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index 14083cf..2b486f3 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -807,7 +807,6 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N raise ValueError(f"Dataset {simulation.dataset_id} not found") # Import policyengine - from policyengine.core import Simulation as PESimulation from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import ( PolicyEngineUKDataset, @@ -815,7 +814,7 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N pe_model_version = uk_latest - # Get policy and dynamic + # Get policy and dynamic as PEPolicy/PEDynamic objects policy = _get_pe_policy_uk( simulation.policy_id, pe_model_version, session ) @@ -823,6 +822,13 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N simulation.dynamic_id, pe_model_version, session ) + 
# Convert to reform dict format for Microsimulation + # This is necessary because policyengine.core.Simulation applies + # reforms AFTER creating Microsimulation, which doesn't work + policy_reform = _pe_policy_to_reform_dict(policy) + dynamic_reform = _pe_policy_to_reform_dict(dynamic) + reform = _merge_reform_dicts(policy_reform, dynamic_reform) + # Download dataset local_path = download_dataset( dataset.filepath, supabase_url, supabase_key, storage_bucket @@ -835,15 +841,12 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N year=dataset.year, ) - # Create and run simulation + # Run simulation using Microsimulation directly with reform + # This ensures reforms are applied at construction time with logfire.span("run_simulation"): - pe_sim = PESimulation( - dataset=pe_dataset, - tax_benefit_model_version=pe_model_version, - policy=policy, - dynamic=dynamic, + pe_output_dataset = _run_uk_economy_simulation( + pe_dataset, reform, pe_model_version, simulation_id ) - pe_sim.ensure() # Save output dataset with logfire.span("save_output_dataset"): @@ -853,8 +856,8 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N output_path = f"/tmp/{output_filename}" # Set filepath and save - pe_sim.output_dataset.filepath = output_path - pe_sim.output_dataset.save() + pe_output_dataset.filepath = output_path + pe_output_dataset.save() # Upload to Supabase storage supabase = create_client(supabase_url, supabase_key) @@ -869,7 +872,7 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N ) # Create output dataset record - output_dataset = Dataset( + output_dataset_record = Dataset( name=f"Output: {dataset.name}", description=f"Output from simulation {simulation_id}", filepath=output_filename, @@ -877,12 +880,12 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, ) - 
session.add(output_dataset) + session.add(output_dataset_record) session.commit() - session.refresh(output_dataset) + session.refresh(output_dataset_record) # Link to simulation - simulation.output_dataset_id = output_dataset.id + simulation.output_dataset_id = output_dataset_record.id # Mark as completed simulation.status = SimulationStatus.COMPLETED @@ -973,15 +976,15 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N raise ValueError(f"Dataset {simulation.dataset_id} not found") # Import policyengine - from policyengine.core import Simulation as PESimulation from policyengine.tax_benefit_models.us import us_latest from policyengine.tax_benefit_models.us.datasets import ( PolicyEngineUSDataset, + USYearData, ) pe_model_version = us_latest - # Get policy and dynamic + # Get policy and dynamic as PEPolicy/PEDynamic objects policy = _get_pe_policy_us( simulation.policy_id, pe_model_version, session ) @@ -989,6 +992,13 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N simulation.dynamic_id, pe_model_version, session ) + # Convert to reform dict format for Microsimulation + # This is necessary because policyengine.core.Simulation applies + # reforms AFTER creating Microsimulation, which doesn't work + policy_reform = _pe_policy_to_reform_dict(policy) + dynamic_reform = _pe_policy_to_reform_dict(dynamic) + reform = _merge_reform_dicts(policy_reform, dynamic_reform) + # Download dataset local_path = download_dataset( dataset.filepath, supabase_url, supabase_key, storage_bucket @@ -1001,15 +1011,12 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N year=dataset.year, ) - # Create and run simulation + # Run simulation using Microsimulation directly with reform + # This ensures reforms are applied at construction time with logfire.span("run_simulation"): - pe_sim = PESimulation( - dataset=pe_dataset, - tax_benefit_model_version=pe_model_version, - policy=policy, - 
dynamic=dynamic, + pe_output_dataset = _run_us_economy_simulation( + pe_dataset, reform, pe_model_version, simulation_id ) - pe_sim.ensure() # Save output dataset with logfire.span("save_output_dataset"): @@ -1019,8 +1026,8 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N output_path = f"/tmp/{output_filename}" # Set filepath and save - pe_sim.output_dataset.filepath = output_path - pe_sim.output_dataset.save() + pe_output_dataset.filepath = output_path + pe_output_dataset.save() # Upload to Supabase storage supabase = create_client(supabase_url, supabase_key) @@ -1035,7 +1042,7 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N ) # Create output dataset record - output_dataset = Dataset( + output_dataset_record = Dataset( name=f"Output: {dataset.name}", description=f"Output from simulation {simulation_id}", filepath=output_filename, @@ -1043,12 +1050,12 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, ) - session.add(output_dataset) + session.add(output_dataset_record) session.commit() - session.refresh(output_dataset) + session.refresh(output_dataset_record) # Link to simulation - simulation.output_dataset_id = output_dataset.id + simulation.output_dataset_id = output_dataset_record.id # Mark as completed simulation.status = SimulationStatus.COMPLETED @@ -1816,6 +1823,403 @@ def _get_pe_dynamic_us(dynamic_id, model_version, session): return _get_pe_dynamic_uk(dynamic_id, model_version, session) +def _pe_policy_to_reform_dict(policy) -> dict | None: + """Convert a policyengine.core.policy.Policy to reform dict format. + + The policyengine-us/uk Microsimulation expects reforms in the format: + {"parameter.name": {"YYYY-MM-DD": value}} + + This is necessary because the policyengine.core.Simulation applies reforms + AFTER creating the Microsimulation, which doesn't work due to caching. 
+ We need to pass the reform at Microsimulation construction time. + """ + if policy is None: + return None + + if not policy.parameter_values: + return None + + reform = {} + for pv in policy.parameter_values: + if not pv.parameter: + continue + param_name = pv.parameter.name + value = pv.value + start_date = pv.start_date + + if param_name and start_date: + # Format date as YYYY-MM-DD string + if hasattr(start_date, "strftime"): + date_str = start_date.strftime("%Y-%m-%d") + else: + date_str = str(start_date).split("T")[0] + + if param_name not in reform: + reform[param_name] = {} + reform[param_name][date_str] = value + + return reform if reform else None + + +def _merge_reform_dicts(reform1: dict | None, reform2: dict | None) -> dict | None: + """Merge two reform dicts, with reform2 taking precedence.""" + if reform1 is None and reform2 is None: + return None + if reform1 is None: + return reform2 + if reform2 is None: + return reform1 + + merged = dict(reform1) + for param_name, dates in reform2.items(): + if param_name not in merged: + merged[param_name] = {} + merged[param_name].update(dates) + return merged + + +def _run_us_economy_simulation(pe_dataset, reform, pe_model_version, simulation_id): + """Run US economy simulation using Microsimulation directly. + + This bypasses policyengine.core.Simulation which has a bug where reforms + are applied AFTER creating Microsimulation (when it's too late). + Instead, we pass the reform dict at Microsimulation construction time. 
+ """ + from pathlib import Path + + import numpy as np + import pandas as pd + from microdf import MicroDataFrame + from policyengine.tax_benefit_models.us.datasets import ( + PolicyEngineUSDataset, + USYearData, + ) + from policyengine_core.simulations.simulation_builder import SimulationBuilder + from policyengine_us import Microsimulation + from policyengine_us.system import system + + # Load dataset + pe_dataset.load() + year = pe_dataset.year + + # Create Microsimulation with reform applied at construction time + microsim = Microsimulation(reform=reform) + + # Build simulation from dataset using SimulationBuilder + person_df = pd.DataFrame(pe_dataset.data.person) + + # Determine column naming convention + household_id_col = ( + "person_household_id" + if "person_household_id" in person_df.columns + else "household_id" + ) + marital_unit_id_col = ( + "person_marital_unit_id" + if "person_marital_unit_id" in person_df.columns + else "marital_unit_id" + ) + family_id_col = ( + "person_family_id" if "person_family_id" in person_df.columns else "family_id" + ) + spm_unit_id_col = ( + "person_spm_unit_id" + if "person_spm_unit_id" in person_df.columns + else "spm_unit_id" + ) + tax_unit_id_col = ( + "person_tax_unit_id" + if "person_tax_unit_id" in person_df.columns + else "tax_unit_id" + ) + + # Declare entities + builder = SimulationBuilder() + builder.populations = system.instantiate_entities() + + builder.declare_person_entity("person", person_df["person_id"].values) + builder.declare_entity("household", np.unique(person_df[household_id_col].values)) + builder.declare_entity("spm_unit", np.unique(person_df[spm_unit_id_col].values)) + builder.declare_entity("family", np.unique(person_df[family_id_col].values)) + builder.declare_entity("tax_unit", np.unique(person_df[tax_unit_id_col].values)) + builder.declare_entity( + "marital_unit", np.unique(person_df[marital_unit_id_col].values) + ) + + # Join persons to entities + builder.join_with_persons( + 
builder.populations["household"], + person_df[household_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["spm_unit"], + person_df[spm_unit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["family"], + person_df[family_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["tax_unit"], + person_df[tax_unit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["marital_unit"], + person_df[marital_unit_id_col].values, + np.array(["member"] * len(person_df)), + ) + + microsim.build_from_populations(builder.populations) + + # Set input variables + id_columns = { + "person_id", + "household_id", + "person_household_id", + "spm_unit_id", + "person_spm_unit_id", + "family_id", + "person_family_id", + "tax_unit_id", + "person_tax_unit_id", + "marital_unit_id", + "person_marital_unit_id", + } + + for entity_name, entity_data in [ + ("person", pe_dataset.data.person), + ("household", pe_dataset.data.household), + ("spm_unit", pe_dataset.data.spm_unit), + ("family", pe_dataset.data.family), + ("tax_unit", pe_dataset.data.tax_unit), + ("marital_unit", pe_dataset.data.marital_unit), + ]: + df = pd.DataFrame(entity_data) + for column in df.columns: + if column not in id_columns and column in system.variables: + microsim.set_input(column, year, df[column].values) + + # Calculate output variables and build output dataset + data = { + "person": pd.DataFrame(), + "marital_unit": pd.DataFrame(), + "family": pd.DataFrame(), + "spm_unit": pd.DataFrame(), + "tax_unit": pd.DataFrame(), + "household": pd.DataFrame(), + } + + weight_columns = { + "person_weight", + "household_weight", + "marital_unit_weight", + "family_weight", + "spm_unit_weight", + "tax_unit_weight", + } + + # Copy ID and weight columns from input dataset + for entity in data.keys(): + 
input_df = pd.DataFrame(getattr(pe_dataset.data, entity)) + entity_id_col = f"{entity}_id" + entity_weight_col = f"{entity}_weight" + + if entity_id_col in input_df.columns: + data[entity][entity_id_col] = input_df[entity_id_col].values + if entity_weight_col in input_df.columns: + data[entity][entity_weight_col] = input_df[entity_weight_col].values + + # Copy person-level group ID columns + for col in person_df.columns: + if col.startswith("person_") and col.endswith("_id"): + target_col = col.replace("person_", "") + if target_col in id_columns: + data["person"][target_col] = person_df[col].values + + # Calculate non-ID, non-weight variables + for entity, variables in pe_model_version.entity_variables.items(): + for var in variables: + if var not in id_columns and var not in weight_columns: + data[entity][var] = microsim.calculate( + var, period=year, map_to=entity + ).values + + # Convert to MicroDataFrames + data["person"] = MicroDataFrame(data["person"], weights="person_weight") + data["marital_unit"] = MicroDataFrame( + data["marital_unit"], weights="marital_unit_weight" + ) + data["family"] = MicroDataFrame(data["family"], weights="family_weight") + data["spm_unit"] = MicroDataFrame(data["spm_unit"], weights="spm_unit_weight") + data["tax_unit"] = MicroDataFrame(data["tax_unit"], weights="tax_unit_weight") + data["household"] = MicroDataFrame(data["household"], weights="household_weight") + + # Create output dataset + return PolicyEngineUSDataset( + id=simulation_id, + name=pe_dataset.name, + description=pe_dataset.description, + filepath=str(Path(pe_dataset.filepath).parent / (simulation_id + ".h5")), + year=year, + is_output_dataset=True, + data=USYearData( + person=data["person"], + marital_unit=data["marital_unit"], + family=data["family"], + spm_unit=data["spm_unit"], + tax_unit=data["tax_unit"], + household=data["household"], + ), + ) + + +def _run_uk_economy_simulation(pe_dataset, reform, pe_model_version, simulation_id): + """Run UK economy 
simulation using Microsimulation directly. + + This bypasses policyengine.core.Simulation which has a bug where reforms + are applied AFTER creating Microsimulation (when it's too late). + Instead, we pass the reform dict at Microsimulation construction time. + """ + from pathlib import Path + + import numpy as np + import pandas as pd + from microdf import MicroDataFrame + from policyengine.tax_benefit_models.uk.datasets import ( + PolicyEngineUKDataset, + UKYearData, + ) + from policyengine_core.simulations.simulation_builder import SimulationBuilder + from policyengine_uk import Microsimulation + from policyengine_uk.system import system + + # Load dataset + pe_dataset.load() + year = pe_dataset.year + + # Create Microsimulation with reform applied at construction time + microsim = Microsimulation(reform=reform) + + # Build simulation from dataset using SimulationBuilder + person_df = pd.DataFrame(pe_dataset.data.person) + + # Determine column naming convention + benunit_id_col = ( + "person_benunit_id" + if "person_benunit_id" in person_df.columns + else "benunit_id" + ) + household_id_col = ( + "person_household_id" + if "person_household_id" in person_df.columns + else "household_id" + ) + + # Declare entities + builder = SimulationBuilder() + builder.populations = system.instantiate_entities() + + builder.declare_person_entity("person", person_df["person_id"].values) + builder.declare_entity("benunit", np.unique(person_df[benunit_id_col].values)) + builder.declare_entity("household", np.unique(person_df[household_id_col].values)) + + # Join persons to entities + builder.join_with_persons( + builder.populations["benunit"], + person_df[benunit_id_col].values, + np.array(["member"] * len(person_df)), + ) + builder.join_with_persons( + builder.populations["household"], + person_df[household_id_col].values, + np.array(["member"] * len(person_df)), + ) + + microsim.build_from_populations(builder.populations) + + # Set input variables + id_columns = { + 
"person_id", + "benunit_id", + "person_benunit_id", + "household_id", + "person_household_id", + } + + for entity_name, entity_data in [ + ("person", pe_dataset.data.person), + ("benunit", pe_dataset.data.benunit), + ("household", pe_dataset.data.household), + ]: + df = pd.DataFrame(entity_data) + for column in df.columns: + if column not in id_columns and column in system.variables: + microsim.set_input(column, year, df[column].values) + + # Calculate output variables and build output dataset + data = { + "person": pd.DataFrame(), + "benunit": pd.DataFrame(), + "household": pd.DataFrame(), + } + + weight_columns = { + "person_weight", + "benunit_weight", + "household_weight", + } + + # Copy ID and weight columns from input dataset + for entity in data.keys(): + input_df = pd.DataFrame(getattr(pe_dataset.data, entity)) + entity_id_col = f"{entity}_id" + entity_weight_col = f"{entity}_weight" + + if entity_id_col in input_df.columns: + data[entity][entity_id_col] = input_df[entity_id_col].values + if entity_weight_col in input_df.columns: + data[entity][entity_weight_col] = input_df[entity_weight_col].values + + # Copy person-level group ID columns + for col in person_df.columns: + if col.startswith("person_") and col.endswith("_id"): + target_col = col.replace("person_", "") + if target_col in id_columns: + data["person"][target_col] = person_df[col].values + + # Calculate non-ID, non-weight variables + for entity, variables in pe_model_version.entity_variables.items(): + for var in variables: + if var not in id_columns and var not in weight_columns: + data[entity][var] = microsim.calculate( + var, period=year, map_to=entity + ).values + + # Convert to MicroDataFrames + data["person"] = MicroDataFrame(data["person"], weights="person_weight") + data["benunit"] = MicroDataFrame(data["benunit"], weights="benunit_weight") + data["household"] = MicroDataFrame(data["household"], weights="household_weight") + + # Create output dataset + return PolicyEngineUKDataset( + 
id=simulation_id, + name=pe_dataset.name, + description=pe_dataset.description, + filepath=str(Path(pe_dataset.filepath).parent / (simulation_id + ".h5")), + year=year, + is_output_dataset=True, + data=UKYearData( + person=data["person"], + benunit=data["benunit"], + household=data["household"], + ), + ) + + @app.function( image=uk_image, secrets=[db_secrets, logfire_secrets], diff --git a/test_fixtures/fixtures_policy_reform.py b/test_fixtures/fixtures_policy_reform.py new file mode 100644 index 0000000..f7534a5 --- /dev/null +++ b/test_fixtures/fixtures_policy_reform.py @@ -0,0 +1,282 @@ +"""Fixtures for policy reform conversion tests.""" + +from dataclasses import dataclass +from datetime import date, datetime +from typing import Any + + +# ============================================================================= +# Mock objects for testing _pe_policy_to_reform_dict +# ============================================================================= + + +@dataclass +class MockParameter: + """Mock policyengine.core.models.parameter.Parameter.""" + + name: str + + +@dataclass +class MockParameterValue: + """Mock policyengine.core.models.parameter_value.ParameterValue.""" + + parameter: MockParameter | None + value: Any + start_date: date | datetime | str | None + + +@dataclass +class MockPolicy: + """Mock policyengine.core.policy.Policy.""" + + parameter_values: list[MockParameterValue] | None + + +# ============================================================================= +# Test data constants +# ============================================================================= + +# Simple policy with single parameter change +SIMPLE_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date=date(2024, 1, 1), + ) + ] +) + +SIMPLE_POLICY_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000} +} + +# Policy with multiple parameter changes 
+MULTI_PARAM_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date=date(2024, 1, 1), + ), + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.refundable.fully_refundable"), + value=True, + start_date=date(2024, 1, 1), + ), + MockParameterValue( + parameter=MockParameter(name="gov.irs.income.bracket.rates.1"), + value=0.12, + start_date=date(2024, 1, 1), + ), + ] +) + +MULTI_PARAM_POLICY_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, + "gov.irs.credits.ctc.refundable.fully_refundable": {"2024-01-01": True}, + "gov.irs.income.bracket.rates.1": {"2024-01-01": 0.12}, +} + +# Policy with same parameter at different dates +MULTI_DATE_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=2500, + start_date=date(2024, 1, 1), + ), + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date=date(2025, 1, 1), + ), + ] +) + +MULTI_DATE_POLICY_EXPECTED = { + "gov.irs.credits.ctc.amount.base": { + "2024-01-01": 2500, + "2025-01-01": 3000, + } +} + +# Policy with datetime start_date (has time component) +DATETIME_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date=datetime(2024, 1, 1, 12, 30, 45), + ) + ] +) + +DATETIME_POLICY_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000} +} + +# Policy with ISO string start_date +ISO_STRING_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date="2024-01-01T00:00:00", + ) + ] +) + +ISO_STRING_POLICY_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000} +} + +# Empty policy (no parameter values) +EMPTY_POLICY = 
MockPolicy(parameter_values=[]) + +# None policy +NONE_POLICY = None + +# Policy with None parameter_values +NONE_PARAM_VALUES_POLICY = MockPolicy(parameter_values=None) + +# Policy with invalid entries (missing parameter or start_date) +INVALID_ENTRIES_POLICY = MockPolicy( + parameter_values=[ + MockParameterValue( + parameter=None, # Missing parameter + value=3000, + start_date=date(2024, 1, 1), + ), + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), + value=3000, + start_date=None, # Missing start_date + ), + MockParameterValue( + parameter=MockParameter(name="gov.irs.credits.eitc.max.0"), + value=600, + start_date=date(2024, 1, 1), # This one is valid + ), + ] +) + +INVALID_ENTRIES_POLICY_EXPECTED = { + "gov.irs.credits.eitc.max.0": {"2024-01-01": 600} +} + + +# ============================================================================= +# Test data for _merge_reform_dicts +# ============================================================================= + +REFORM_DICT_1 = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 2000}, + "gov.irs.income.bracket.rates.1": {"2024-01-01": 0.10}, +} + +REFORM_DICT_2 = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, # Overwrites + "gov.irs.credits.eitc.max.0": {"2024-01-01": 600}, # New param +} + +MERGED_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, # From reform2 + "gov.irs.income.bracket.rates.1": {"2024-01-01": 0.10}, # From reform1 + "gov.irs.credits.eitc.max.0": {"2024-01-01": 600}, # From reform2 +} + +REFORM_DICT_3 = { + "gov.irs.credits.ctc.amount.base": { + "2024-01-01": 2500, + "2025-01-01": 2700, + }, +} + +REFORM_DICT_4 = { + "gov.irs.credits.ctc.amount.base": { + "2025-01-01": 3000, # Overwrites 2025 date + "2026-01-01": 3500, # New date + }, +} + +MERGED_MULTI_DATE_EXPECTED = { + "gov.irs.credits.ctc.amount.base": { + "2024-01-01": 2500, # From reform3 + "2025-01-01": 3000, # From reform4 (overwrites) + "2026-01-01": 3500, # 
From reform4 (new) + }, +} + + +# ============================================================================= +# Test data for household calculation policy conversion +# ============================================================================= + +# Policy data as it comes from the API (stored in database) +HOUSEHOLD_POLICY_DATA = { + "parameter_values": [ + { + "parameter_name": "gov.irs.credits.ctc.amount.base", + "value": 3000, + "start_date": "2024-01-01", + }, + { + "parameter_name": "gov.irs.credits.ctc.refundable.fully_refundable", + "value": True, + "start_date": "2024-01-01", + }, + ] +} + +HOUSEHOLD_POLICY_DATA_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, + "gov.irs.credits.ctc.refundable.fully_refundable": {"2024-01-01": True}, +} + +# Policy data with ISO datetime strings +HOUSEHOLD_POLICY_DATA_DATETIME = { + "parameter_values": [ + { + "parameter_name": "gov.irs.credits.ctc.amount.base", + "value": 3000, + "start_date": "2024-01-01T00:00:00.000Z", + }, + ] +} + +HOUSEHOLD_POLICY_DATA_DATETIME_EXPECTED = { + "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, +} + +# Empty policy data +HOUSEHOLD_EMPTY_POLICY_DATA = {"parameter_values": []} + +# None policy data +HOUSEHOLD_NONE_POLICY_DATA = None + +# Policy data with missing fields +HOUSEHOLD_INCOMPLETE_POLICY_DATA = { + "parameter_values": [ + { + "parameter_name": None, # Missing + "value": 3000, + "start_date": "2024-01-01", + }, + { + "parameter_name": "gov.irs.credits.ctc.amount.base", + "value": 3000, + "start_date": None, # Missing + }, + { + "parameter_name": "gov.irs.credits.eitc.max.0", + "value": 600, + "start_date": "2024-01-01", # Valid + }, + ] +} + +HOUSEHOLD_INCOMPLETE_POLICY_DATA_EXPECTED = { + "gov.irs.credits.eitc.max.0": {"2024-01-01": 600}, +} diff --git a/tests/test_policy_reform.py b/tests/test_policy_reform.py new file mode 100644 index 0000000..cfee3b8 --- /dev/null +++ b/tests/test_policy_reform.py @@ -0,0 +1,327 @@ +"""Tests for policy 
reform conversion logic. + +Tests the helper functions that convert policy objects to reform dict format +for use with Microsimulation. These are critical for fixing the bug where +reforms weren't being applied to economy-wide and household simulations. +""" + +import sys +from unittest.mock import MagicMock + +import pytest + +# Mock modal before importing modal_app +sys.modules["modal"] = MagicMock() + +from test_fixtures.fixtures_policy_reform import ( + DATETIME_POLICY, + DATETIME_POLICY_EXPECTED, + EMPTY_POLICY, + HOUSEHOLD_EMPTY_POLICY_DATA, + HOUSEHOLD_INCOMPLETE_POLICY_DATA, + HOUSEHOLD_INCOMPLETE_POLICY_DATA_EXPECTED, + HOUSEHOLD_NONE_POLICY_DATA, + HOUSEHOLD_POLICY_DATA, + HOUSEHOLD_POLICY_DATA_DATETIME, + HOUSEHOLD_POLICY_DATA_DATETIME_EXPECTED, + HOUSEHOLD_POLICY_DATA_EXPECTED, + INVALID_ENTRIES_POLICY, + INVALID_ENTRIES_POLICY_EXPECTED, + ISO_STRING_POLICY, + ISO_STRING_POLICY_EXPECTED, + MERGED_EXPECTED, + MERGED_MULTI_DATE_EXPECTED, + MULTI_DATE_POLICY, + MULTI_DATE_POLICY_EXPECTED, + MULTI_PARAM_POLICY, + MULTI_PARAM_POLICY_EXPECTED, + NONE_PARAM_VALUES_POLICY, + NONE_POLICY, + REFORM_DICT_1, + REFORM_DICT_2, + REFORM_DICT_3, + REFORM_DICT_4, + SIMPLE_POLICY, + SIMPLE_POLICY_EXPECTED, +) + +# Import after mocking modal +from policyengine_api.modal_app import _merge_reform_dicts, _pe_policy_to_reform_dict + + +class TestPePolicyToReformDict: + """Tests for _pe_policy_to_reform_dict function.""" + + # ========================================================================= + # Given: Valid policy with single parameter + # ========================================================================= + + def test__given_simple_policy_with_date_object__then_returns_correct_reform_dict( + self, + ): + """Given a policy with a single parameter using date object, + then returns correctly formatted reform dict.""" + # When + result = _pe_policy_to_reform_dict(SIMPLE_POLICY) + + # Then + assert result == SIMPLE_POLICY_EXPECTED + + def 
test__given_policy_with_datetime_object__then_extracts_date_correctly(self): + """Given a policy with datetime start_date (has time component), + then extracts just the date part for the reform dict.""" + # When + result = _pe_policy_to_reform_dict(DATETIME_POLICY) + + # Then + assert result == DATETIME_POLICY_EXPECTED + + def test__given_policy_with_iso_string_date__then_parses_date_correctly(self): + """Given a policy with ISO string start_date, + then parses and extracts the date correctly.""" + # When + result = _pe_policy_to_reform_dict(ISO_STRING_POLICY) + + # Then + assert result == ISO_STRING_POLICY_EXPECTED + + # ========================================================================= + # Given: Policy with multiple parameters + # ========================================================================= + + def test__given_policy_with_multiple_parameters__then_includes_all_in_dict(self): + """Given a policy with multiple parameter changes, + then includes all parameters in the reform dict.""" + # When + result = _pe_policy_to_reform_dict(MULTI_PARAM_POLICY) + + # Then + assert result == MULTI_PARAM_POLICY_EXPECTED + + def test__given_policy_with_same_param_multiple_dates__then_includes_all_dates( + self, + ): + """Given a policy with the same parameter changed at different dates, + then includes all date entries for that parameter.""" + # When + result = _pe_policy_to_reform_dict(MULTI_DATE_POLICY) + + # Then + assert result == MULTI_DATE_POLICY_EXPECTED + + # ========================================================================= + # Given: Empty or None policy + # ========================================================================= + + def test__given_none_policy__then_returns_none(self): + """Given None as policy, + then returns None.""" + # When + result = _pe_policy_to_reform_dict(NONE_POLICY) + + # Then + assert result is None + + def test__given_policy_with_empty_parameter_values__then_returns_none(self): + """Given a policy with empty 
parameter_values list, + then returns None.""" + # When + result = _pe_policy_to_reform_dict(EMPTY_POLICY) + + # Then + assert result is None + + def test__given_policy_with_none_parameter_values__then_returns_none(self): + """Given a policy with parameter_values=None, + then returns None.""" + # When + result = _pe_policy_to_reform_dict(NONE_PARAM_VALUES_POLICY) + + # Then + assert result is None + + # ========================================================================= + # Given: Policy with invalid entries + # ========================================================================= + + def test__given_policy_with_invalid_entries__then_skips_invalid_keeps_valid(self): + """Given a policy with some invalid entries (missing parameter or date), + then skips invalid entries and keeps valid ones.""" + # When + result = _pe_policy_to_reform_dict(INVALID_ENTRIES_POLICY) + + # Then + assert result == INVALID_ENTRIES_POLICY_EXPECTED + + +class TestMergeReformDicts: + """Tests for _merge_reform_dicts function.""" + + # ========================================================================= + # Given: Two valid reform dicts + # ========================================================================= + + def test__given_two_reform_dicts__then_merges_with_second_taking_precedence(self): + """Given two reform dicts with overlapping parameters, + then merges them with the second dict taking precedence.""" + # When + result = _merge_reform_dicts(REFORM_DICT_1, REFORM_DICT_2) + + # Then + assert result == MERGED_EXPECTED + + def test__given_dicts_with_multiple_dates__then_merges_date_entries_correctly(self): + """Given reform dicts with same parameter at multiple dates, + then merges date entries correctly with second taking precedence.""" + # When + result = _merge_reform_dicts(REFORM_DICT_3, REFORM_DICT_4) + + # Then + assert result == MERGED_MULTI_DATE_EXPECTED + + # ========================================================================= + # Given: None values + # 
========================================================================= + + def test__given_both_none__then_returns_none(self): + """Given both reform dicts are None, + then returns None.""" + # When + result = _merge_reform_dicts(None, None) + + # Then + assert result is None + + def test__given_first_none__then_returns_second(self): + """Given first reform dict is None, + then returns the second dict.""" + # When + result = _merge_reform_dicts(None, REFORM_DICT_1) + + # Then + assert result == REFORM_DICT_1 + + def test__given_second_none__then_returns_first(self): + """Given second reform dict is None, + then returns the first dict.""" + # When + result = _merge_reform_dicts(REFORM_DICT_1, None) + + # Then + assert result == REFORM_DICT_1 + + # ========================================================================= + # Given: Original dict should not be mutated + # ========================================================================= + + def test__given_two_dicts__then_does_not_mutate_original_dicts(self): + """Given two reform dicts, + then merging does not mutate the original dicts.""" + # Given + original_dict1 = {"param.a": {"2024-01-01": 100}} + original_dict2 = {"param.b": {"2024-01-01": 200}} + dict1_copy = dict(original_dict1) + dict2_copy = dict(original_dict2) + + # When + _merge_reform_dicts(original_dict1, original_dict2) + + # Then + assert original_dict1 == dict1_copy + assert original_dict2 == dict2_copy + + +class TestHouseholdPolicyDataConversion: + """Tests for the policy data conversion logic used in household calculations. + + This tests the conversion logic as it appears in _calculate_household_us + and _calculate_household_uk functions. + """ + + def _convert_policy_data_to_reform(self, policy_data: dict | None) -> dict | None: + """Convert policy_data (from API) to reform dict format. + + This mirrors the conversion logic in _calculate_household_us. 
+ """ + if not policy_data or not policy_data.get("parameter_values"): + return None + + reform = {} + for pv in policy_data["parameter_values"]: + param_name = pv.get("parameter_name") + value = pv.get("value") + start_date = pv.get("start_date") + + if param_name and start_date: + # Parse ISO date string to get just the date part + if "T" in start_date: + date_str = start_date.split("T")[0] + else: + date_str = start_date + + if param_name not in reform: + reform[param_name] = {} + reform[param_name][date_str] = value + + return reform if reform else None + + # ========================================================================= + # Given: Valid policy data from API + # ========================================================================= + + def test__given_valid_policy_data__then_converts_to_reform_dict(self): + """Given valid policy data from the API, + then converts it to the correct reform dict format.""" + # When + result = self._convert_policy_data_to_reform(HOUSEHOLD_POLICY_DATA) + + # Then + assert result == HOUSEHOLD_POLICY_DATA_EXPECTED + + def test__given_policy_data_with_datetime_strings__then_extracts_date_part(self): + """Given policy data with ISO datetime strings (with T and timezone), + then extracts just the date part.""" + # When + result = self._convert_policy_data_to_reform(HOUSEHOLD_POLICY_DATA_DATETIME) + + # Then + assert result == HOUSEHOLD_POLICY_DATA_DATETIME_EXPECTED + + # ========================================================================= + # Given: Empty or None policy data + # ========================================================================= + + def test__given_none_policy_data__then_returns_none(self): + """Given None policy data, + then returns None.""" + # When + result = self._convert_policy_data_to_reform(HOUSEHOLD_NONE_POLICY_DATA) + + # Then + assert result is None + + def test__given_empty_parameter_values__then_returns_none(self): + """Given policy data with empty parameter_values list, + then 
returns None.""" + # When + result = self._convert_policy_data_to_reform(HOUSEHOLD_EMPTY_POLICY_DATA) + + # Then + assert result is None + + # ========================================================================= + # Given: Incomplete policy data + # ========================================================================= + + def test__given_incomplete_entries__then_skips_invalid_keeps_valid(self): + """Given policy data with some entries missing required fields, + then skips invalid entries and keeps valid ones.""" + # When + result = self._convert_policy_data_to_reform(HOUSEHOLD_INCOMPLETE_POLICY_DATA) + + # Then + assert result == HOUSEHOLD_INCOMPLETE_POLICY_DATA_EXPECTED + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 34c34c47a57f87d733c4f847ff565d69d2589749 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 6 Feb 2026 20:10:08 +0300 Subject: [PATCH 10/19] fix: Fix household user models; add variable default values --- .../20260204_0002_add_household_support.py | 10 +++++- ...0260206_0004_add_variable_default_value.py | 34 +++++++++++++++++++ docker-compose.yml | 14 ++++---- scripts/seed_common.py | 13 +++++++ .../models/user_household_association.py | 3 +- src/policyengine_api/models/variable.py | 3 ++ 6 files changed, 68 insertions(+), 9 deletions(-) create mode 100644 alembic/versions/20260206_0004_add_variable_default_value.py diff --git a/alembic/versions/20260204_0002_add_household_support.py b/alembic/versions/20260204_0002_add_household_support.py index beb00a0..186db37 100644 --- a/alembic/versions/20260204_0002_add_household_support.py +++ b/alembic/versions/20260204_0002_add_household_support.py @@ -57,19 +57,27 @@ def upgrade() -> None: op.create_index("idx_households_year", "households", ["year"]) # User-household associations (many-to-many for saved households) + # Note: user_id is a client-generated UUID stored in localStorage, not a foreign key op.create_table( "user_household_associations", sa.Column("id", 
sa.Uuid(), nullable=False), sa.Column("user_id", sa.Uuid(), nullable=False), sa.Column("household_id", sa.Uuid(), nullable=False), + sa.Column("country_id", sa.String(), nullable=False), + sa.Column("label", sa.String(), nullable=True), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.func.now(), + nullable=False, + ), sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), sa.ForeignKeyConstraint(["household_id"], ["households.id"], ondelete="CASCADE"), sa.UniqueConstraint("user_id", "household_id"), ) diff --git a/alembic/versions/20260206_0004_add_variable_default_value.py b/alembic/versions/20260206_0004_add_variable_default_value.py new file mode 100644 index 0000000..2471491 --- /dev/null +++ b/alembic/versions/20260206_0004_add_variable_default_value.py @@ -0,0 +1,34 @@ +"""Add default_value to variables + +Revision ID: 0004_var_default +Revises: 0003_param_idx +Create Date: 2026-02-06 03:30:00.000000 + +This migration adds a default_value column to the variables table. +The default_value is stored as JSON to handle different types (int, float, bool, str, etc.). +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import JSON + +# revision identifiers, used by Alembic. 
+revision: str = "0004_var_default" +down_revision: Union[str, Sequence[str], None] = "0003_param_idx" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Add default_value column to variables table.""" + op.add_column( + "variables", + sa.Column("default_value", JSON, nullable=True), + ) + + +def downgrade() -> None: + """Remove default_value column from variables table.""" + op.drop_column("variables", "default_value") diff --git a/docker-compose.yml b/docker-compose.yml index 60e8645..60aa598 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,10 +5,10 @@ services: ports: - "${API_PORT:-8000}:${API_PORT:-8000}" environment: - SUPABASE_URL: http://supabase_kong_policyengine-api-v2:8000 + SUPABASE_URL: http://supabase_kong_policyengine-api-v2-alpha:8000 SUPABASE_KEY: ${SUPABASE_KEY} SUPABASE_SERVICE_KEY: ${SUPABASE_SERVICE_KEY} - SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2:5432/postgres + SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2-alpha:5432/postgres LOGFIRE_TOKEN: ${LOGFIRE_TOKEN} DEBUG: "false" API_PORT: ${API_PORT:-8000} @@ -19,7 +19,7 @@ services: - ./src:/app/src - ./docs/out:/app/docs/out networks: - - supabase_network_policyengine-api-v2 + - supabase_network_policyengine-api-v2-alpha healthcheck: test: ["CMD", "python", "-c", "import httpx; exit(0 if httpx.get('http://localhost:${API_PORT:-8000}/health', timeout=2).status_code == 200 else 1)"] interval: 5s @@ -31,16 +31,16 @@ services: build: . 
command: pytest tests/ -v environment: - SUPABASE_URL: http://supabase_kong_policyengine-api-v2:8000 + SUPABASE_URL: http://supabase_kong_policyengine-api-v2-alpha:8000 SUPABASE_KEY: ${SUPABASE_KEY} SUPABASE_SERVICE_KEY: ${SUPABASE_SERVICE_KEY} - SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2:5432/postgres + SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2-alpha:5432/postgres LOGFIRE_TOKEN: ${LOGFIRE_TOKEN} volumes: - ./src:/app/src - ./tests:/app/tests networks: - - supabase_network_policyengine-api-v2 + - supabase_network_policyengine-api-v2-alpha depends_on: api: condition: service_healthy @@ -48,5 +48,5 @@ services: - test networks: - supabase_network_policyengine-api-v2: + supabase_network_policyengine-api-v2-alpha: external: true diff --git a/scripts/seed_common.py b/scripts/seed_common.py index f6d7ab6..db248a6 100644 --- a/scripts/seed_common.py +++ b/scripts/seed_common.py @@ -7,6 +7,7 @@ import sys import warnings from datetime import datetime, timezone +from enum import Enum from pathlib import Path from uuid import uuid4 @@ -152,6 +153,16 @@ def seed_model(model_version, session, lite: bool = False): total=len(model_version.variables), ) for var in model_version.variables: + # Serialize default_value for JSON storage + default_val = var.default_value + if var.value_type is Enum: + # Enum variables: extract the member name (e.g., "SINGLE") + # NOTE: This may need to change when we determine how to properly + # add possible_values (the list of enum members) into the database. 
+ default_val = default_val.name + elif hasattr(default_val, "isoformat"): # datetime.date + default_val = default_val.isoformat() + var_rows.append( { "id": uuid4(), @@ -162,6 +173,7 @@ def seed_model(model_version, session, lite: bool = False): if hasattr(var.data_type, "__name__") else str(var.data_type), "possible_values": None, + "default_value": json.dumps(default_val), "tax_benefit_model_version_id": db_version.id, "created_at": datetime.now(timezone.utc), } @@ -179,6 +191,7 @@ def seed_model(model_version, session, lite: bool = False): "description", "data_type", "possible_values", + "default_value", "tax_benefit_model_version_id", "created_at", ], diff --git a/src/policyengine_api/models/user_household_association.py b/src/policyengine_api/models/user_household_association.py index 208279a..9a961cc 100644 --- a/src/policyengine_api/models/user_household_association.py +++ b/src/policyengine_api/models/user_household_association.py @@ -9,7 +9,8 @@ class UserHouseholdAssociationBase(SQLModel): """Base association fields.""" - user_id: UUID = Field(foreign_key="users.id", index=True) + # user_id is a client-generated UUID stored in localStorage, not a foreign key + user_id: UUID = Field(index=True) household_id: UUID = Field(foreign_key="households.id", index=True) country_id: str label: str | None = None diff --git a/src/policyengine_api/models/variable.py b/src/policyengine_api/models/variable.py index f163577..16c83b0 100644 --- a/src/policyengine_api/models/variable.py +++ b/src/policyengine_api/models/variable.py @@ -18,6 +18,9 @@ class VariableBase(SQLModel): possible_values: str | None = Field( default=None, sa_column=Column(JSON) ) # Store as JSON list + default_value: str | None = Field( + default=None, sa_column=Column(JSON) + ) # Store as JSON (handles int, float, bool, str, etc.) 
tax_benefit_model_version_id: UUID = Field( foreign_key="tax_benefit_model_versions.id" ) From 013304f2047d50c94f60d1f328d34c94ee1a9e39 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Fri, 6 Feb 2026 21:47:51 +0300 Subject: [PATCH 11/19] refactor: Simplify seed scripts to use policyengine.py for default_value serialization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove Enum/date serialization from seed_common.py since policyengine.py now pre-serializes default_value for JSON compatibility - Change default_value type from `str | None` to `Any` in Variable model since it stores JSON values (bool, int, float, str) Depends on policyengine.py feat/add-variable-default-value branch šŸ¤– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- scripts/seed_common.py | 22 ++++++++++------------ src/policyengine_api/models/variable.py | 4 ++-- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/scripts/seed_common.py b/scripts/seed_common.py index db248a6..49797cb 100644 --- a/scripts/seed_common.py +++ b/scripts/seed_common.py @@ -7,7 +7,6 @@ import sys import warnings from datetime import datetime, timezone -from enum import Enum from pathlib import Path from uuid import uuid4 @@ -81,6 +80,11 @@ def bulk_insert(session, table: str, columns: list[str], rows: list[dict]): def seed_model(model_version, session, lite: bool = False): """Seed a tax-benefit model with its variables and parameters. + Args: + model_version: The policyengine package model version + session: Database session + lite: If True, skip state-level parameters + Returns the TaxBenefitModelVersion that was created or found. 
""" from policyengine_api.models import ( @@ -153,16 +157,10 @@ def seed_model(model_version, session, lite: bool = False): total=len(model_version.variables), ) for var in model_version.variables: - # Serialize default_value for JSON storage - default_val = var.default_value - if var.value_type is Enum: - # Enum variables: extract the member name (e.g., "SINGLE") - # NOTE: This may need to change when we determine how to properly - # add possible_values (the list of enum members) into the database. - default_val = default_val.name - elif hasattr(default_val, "isoformat"): # datetime.date - default_val = default_val.isoformat() - + # default_value is pre-serialized by policyengine.py: + # - Enum values are converted to their name (e.g., "SINGLE") + # - datetime.date values are converted to ISO format + # - Primitives (bool, int, float, str) are kept as-is var_rows.append( { "id": uuid4(), @@ -173,7 +171,7 @@ def seed_model(model_version, session, lite: bool = False): if hasattr(var.data_type, "__name__") else str(var.data_type), "possible_values": None, - "default_value": json.dumps(default_val), + "default_value": json.dumps(var.default_value), "tax_benefit_model_version_id": db_version.id, "created_at": datetime.now(timezone.utc), } diff --git a/src/policyengine_api/models/variable.py b/src/policyengine_api/models/variable.py index 16c83b0..eeebddc 100644 --- a/src/policyengine_api/models/variable.py +++ b/src/policyengine_api/models/variable.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from uuid import UUID, uuid4 from sqlmodel import JSON, Column, Field, Relationship, SQLModel @@ -18,7 +18,7 @@ class VariableBase(SQLModel): possible_values: str | None = Field( default=None, sa_column=Column(JSON) ) # Store as JSON list - default_value: str | None = Field( + default_value: Any = Field( default=None, sa_column=Column(JSON) ) # Store as JSON (handles int, float, bool, str, etc.) 
tax_benefit_model_version_id: UUID = Field( From a97f4736947955099c63615b66a2c9e1cf7940ca Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 7 Feb 2026 02:07:23 +0300 Subject: [PATCH 12/19] fix: FINALLY use the ACTUAL Alembic script to generate migrations --- .../versions/20260204_0001_initial_schema.py | 537 ------------------ .../20260204_0002_add_household_support.py | 178 ------ ...60204_0003_add_parameter_values_indexes.py | 52 -- ...0260206_0004_add_variable_default_value.py | 34 -- .../20260207_36f9d434e95b_initial_schema.py | 321 +++++++++++ ...0207_f419b5f4acba_add_household_support.py | 81 +++ 6 files changed, 402 insertions(+), 801 deletions(-) delete mode 100644 alembic/versions/20260204_0001_initial_schema.py delete mode 100644 alembic/versions/20260204_0002_add_household_support.py delete mode 100644 alembic/versions/20260204_0003_add_parameter_values_indexes.py delete mode 100644 alembic/versions/20260206_0004_add_variable_default_value.py create mode 100644 alembic/versions/20260207_36f9d434e95b_initial_schema.py create mode 100644 alembic/versions/20260207_f419b5f4acba_add_household_support.py diff --git a/alembic/versions/20260204_0001_initial_schema.py b/alembic/versions/20260204_0001_initial_schema.py deleted file mode 100644 index 273124a..0000000 --- a/alembic/versions/20260204_0001_initial_schema.py +++ /dev/null @@ -1,537 +0,0 @@ -"""Initial schema (main branch state) - -Revision ID: 0001_initial -Revises: -Create Date: 2026-02-04 - -This migration creates all base tables for the PolicyEngine API as they -exist on the main branch, BEFORE the household CRUD changes. -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. 
-revision: str = "0001_initial" -down_revision: Union[str, Sequence[str], None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Create all tables as they exist on main branch.""" - # ======================================================================== - # TIER 1: Tables with no foreign key dependencies - # ======================================================================== - - # Tax benefit models (e.g., "uk", "us") - op.create_table( - "tax_benefit_models", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("name", sa.String(), nullable=False), - sa.Column("description", sa.String(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - ) - - # Users - op.create_table( - "users", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("first_name", sa.String(), nullable=False), - sa.Column("last_name", sa.String(), nullable=False), - sa.Column("email", sa.String(), nullable=False), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("email"), - ) - op.create_index("ix_users_email", "users", ["email"]) - - # Policies (reform definitions) - op.create_table( - "policies", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("name", sa.String(), nullable=False), - sa.Column("description", sa.String(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.Column( - "updated_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - ) - - # Dynamics (behavioral response definitions) - op.create_table( - "dynamics", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("name", 
sa.String(), nullable=False), - sa.Column("description", sa.String(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.Column( - "updated_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - ) - - # ======================================================================== - # TIER 2: Tables depending on tier 1 - # ======================================================================== - - # Tax benefit model versions - op.create_table( - "tax_benefit_model_versions", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("model_id", sa.Uuid(), nullable=False), - sa.Column("version", sa.String(), nullable=False), - sa.Column("description", sa.String(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["model_id"], ["tax_benefit_models.id"]), - ) - - # Datasets (h5 files in storage) - op.create_table( - "datasets", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("name", sa.String(), nullable=False), - sa.Column("description", sa.String(), nullable=True), - sa.Column("filepath", sa.String(), nullable=False), - sa.Column("year", sa.Integer(), nullable=False), - sa.Column("is_output_dataset", sa.Boolean(), nullable=False, default=False), - sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.Column( - "updated_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["tax_benefit_model_id"], ["tax_benefit_models.id"]), - ) - - # ======================================================================== - # TIER 3: Tables depending on tier 2 - # 
======================================================================== - - # Parameters (tax-benefit system parameters) - op.create_table( - "parameters", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("name", sa.String(), nullable=False), - sa.Column("label", sa.String(), nullable=True), - sa.Column("description", sa.String(), nullable=True), - sa.Column("data_type", sa.String(), nullable=True), - sa.Column("unit", sa.String(), nullable=True), - sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint( - ["tax_benefit_model_version_id"], ["tax_benefit_model_versions.id"] - ), - ) - - # Variables (tax-benefit system variables) - op.create_table( - "variables", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("name", sa.String(), nullable=False), - sa.Column("entity", sa.String(), nullable=False), - sa.Column("description", sa.String(), nullable=True), - sa.Column("data_type", sa.String(), nullable=True), - sa.Column("possible_values", sa.JSON(), nullable=True), - sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint( - ["tax_benefit_model_version_id"], ["tax_benefit_model_versions.id"] - ), - ) - - # Dataset versions - op.create_table( - "dataset_versions", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("name", sa.String(), nullable=False), - sa.Column("description", sa.String(), nullable=False), - sa.Column("dataset_id", sa.Uuid(), nullable=False), - sa.Column("tax_benefit_model_id", sa.Uuid(), nullable=False), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - 
sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"]), - sa.ForeignKeyConstraint(["tax_benefit_model_id"], ["tax_benefit_models.id"]), - ) - - # ======================================================================== - # TIER 4: Tables depending on tier 3 - # ======================================================================== - - # Parameter values (policy/dynamic parameter modifications) - op.create_table( - "parameter_values", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("parameter_id", sa.Uuid(), nullable=False), - sa.Column("value_json", sa.JSON(), nullable=True), - sa.Column("start_date", sa.DateTime(timezone=True), nullable=False), - sa.Column("end_date", sa.DateTime(timezone=True), nullable=True), - sa.Column("policy_id", sa.Uuid(), nullable=True), - sa.Column("dynamic_id", sa.Uuid(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["parameter_id"], ["parameters.id"]), - sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), - sa.ForeignKeyConstraint(["dynamic_id"], ["dynamics.id"]), - ) - - # Simulations (economy calculations) - NOTE: No household support yet - op.create_table( - "simulations", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("dataset_id", sa.Uuid(), nullable=False), # Required in main - sa.Column("policy_id", sa.Uuid(), nullable=True), - sa.Column("dynamic_id", sa.Uuid(), nullable=True), - sa.Column("tax_benefit_model_version_id", sa.Uuid(), nullable=False), - sa.Column("output_dataset_id", sa.Uuid(), nullable=True), - sa.Column("status", sa.String(), nullable=False, default="pending"), - sa.Column("error_message", sa.String(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.Column( - "updated_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - 
sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["dataset_id"], ["datasets.id"]), - sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), - sa.ForeignKeyConstraint(["dynamic_id"], ["dynamics.id"]), - sa.ForeignKeyConstraint( - ["tax_benefit_model_version_id"], ["tax_benefit_model_versions.id"] - ), - sa.ForeignKeyConstraint(["output_dataset_id"], ["datasets.id"]), - ) - - # Household jobs (async household calculations) - legacy approach - op.create_table( - "household_jobs", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("tax_benefit_model_name", sa.String(), nullable=False), - sa.Column("request_data", sa.JSON(), nullable=False), - sa.Column("policy_id", sa.Uuid(), nullable=True), - sa.Column("dynamic_id", sa.Uuid(), nullable=True), - sa.Column("status", sa.String(), nullable=False, default="pending"), - sa.Column("error_message", sa.String(), nullable=True), - sa.Column("result", sa.JSON(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.Column("started_at", sa.DateTime(timezone=True), nullable=True), - sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["policy_id"], ["policies.id"]), - sa.ForeignKeyConstraint(["dynamic_id"], ["dynamics.id"]), - ) - - # ======================================================================== - # TIER 5: Tables depending on simulations - # ======================================================================== - - # Reports (analysis reports) - NOTE: No report_type yet - op.create_table( - "reports", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("label", sa.String(), nullable=False), - sa.Column("description", sa.String(), nullable=True), - sa.Column("user_id", sa.Uuid(), nullable=True), - 
sa.Column("markdown", sa.Text(), nullable=True), - sa.Column("parent_report_id", sa.Uuid(), nullable=True), - sa.Column("status", sa.String(), nullable=False, default="pending"), - sa.Column("error_message", sa.String(), nullable=True), - sa.Column("baseline_simulation_id", sa.Uuid(), nullable=True), - sa.Column("reform_simulation_id", sa.Uuid(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["user_id"], ["users.id"]), - sa.ForeignKeyConstraint(["parent_report_id"], ["reports.id"]), - sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), - sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), - ) - - # Aggregates (single-simulation aggregate outputs) - op.create_table( - "aggregates", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("simulation_id", sa.Uuid(), nullable=False), - sa.Column("user_id", sa.Uuid(), nullable=True), - sa.Column("report_id", sa.Uuid(), nullable=True), - sa.Column("variable", sa.String(), nullable=False), - sa.Column("aggregate_type", sa.String(), nullable=False), - sa.Column("entity", sa.String(), nullable=True), - sa.Column("filter_config", sa.JSON(), nullable=False, default={}), - sa.Column("status", sa.String(), nullable=False, default="pending"), - sa.Column("error_message", sa.String(), nullable=True), - sa.Column("result", sa.Float(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["simulation_id"], ["simulations.id"]), - sa.ForeignKeyConstraint(["user_id"], ["users.id"]), - sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), - ) - - # Change aggregates (baseline vs reform comparison) - op.create_table( - "change_aggregates", - sa.Column("id", sa.Uuid(), nullable=False), - 
sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), - sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), - sa.Column("user_id", sa.Uuid(), nullable=True), - sa.Column("report_id", sa.Uuid(), nullable=True), - sa.Column("variable", sa.String(), nullable=False), - sa.Column("aggregate_type", sa.String(), nullable=False), - sa.Column("entity", sa.String(), nullable=True), - sa.Column("filter_config", sa.JSON(), nullable=False, default={}), - sa.Column("change_geq", sa.Float(), nullable=True), - sa.Column("change_leq", sa.Float(), nullable=True), - sa.Column("status", sa.String(), nullable=False, default="pending"), - sa.Column("error_message", sa.String(), nullable=True), - sa.Column("result", sa.Float(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), - sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), - sa.ForeignKeyConstraint(["user_id"], ["users.id"]), - sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), - ) - - # Decile impacts - op.create_table( - "decile_impacts", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), - sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), - sa.Column("report_id", sa.Uuid(), nullable=True), - sa.Column("income_variable", sa.String(), nullable=False), - sa.Column("entity", sa.String(), nullable=True), - sa.Column("decile", sa.Integer(), nullable=False), - sa.Column("quantiles", sa.Integer(), nullable=False, default=10), - sa.Column("baseline_mean", sa.Float(), nullable=True), - sa.Column("reform_mean", sa.Float(), nullable=True), - sa.Column("absolute_change", sa.Float(), nullable=True), - sa.Column("relative_change", sa.Float(), nullable=True), - sa.Column("count_better_off", sa.Float(), nullable=True), - 
sa.Column("count_worse_off", sa.Float(), nullable=True), - sa.Column("count_no_change", sa.Float(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), - sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), - sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), - ) - - # Program statistics - op.create_table( - "program_statistics", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("baseline_simulation_id", sa.Uuid(), nullable=False), - sa.Column("reform_simulation_id", sa.Uuid(), nullable=False), - sa.Column("report_id", sa.Uuid(), nullable=True), - sa.Column("program_name", sa.String(), nullable=False), - sa.Column("entity", sa.String(), nullable=False), - sa.Column("is_tax", sa.Boolean(), nullable=False, default=False), - sa.Column("baseline_total", sa.Float(), nullable=True), - sa.Column("reform_total", sa.Float(), nullable=True), - sa.Column("change", sa.Float(), nullable=True), - sa.Column("baseline_count", sa.Float(), nullable=True), - sa.Column("reform_count", sa.Float(), nullable=True), - sa.Column("winners", sa.Float(), nullable=True), - sa.Column("losers", sa.Float(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["baseline_simulation_id"], ["simulations.id"]), - sa.ForeignKeyConstraint(["reform_simulation_id"], ["simulations.id"]), - sa.ForeignKeyConstraint(["report_id"], ["reports.id"]), - ) - - # Poverty - op.create_table( - "poverty", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("simulation_id", sa.Uuid(), nullable=False), - sa.Column("report_id", sa.Uuid(), nullable=True), - sa.Column("poverty_type", sa.String(), nullable=False), - sa.Column("entity", sa.String(), nullable=False, 
default="person"), - sa.Column("filter_variable", sa.String(), nullable=True), - sa.Column("headcount", sa.Float(), nullable=True), - sa.Column("total_population", sa.Float(), nullable=True), - sa.Column("rate", sa.Float(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint( - ["simulation_id"], ["simulations.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["report_id"], ["reports.id"], ondelete="CASCADE"), - ) - op.create_index("idx_poverty_simulation_id", "poverty", ["simulation_id"]) - op.create_index("idx_poverty_report_id", "poverty", ["report_id"]) - - # Inequality - op.create_table( - "inequality", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("simulation_id", sa.Uuid(), nullable=False), - sa.Column("report_id", sa.Uuid(), nullable=True), - sa.Column("income_variable", sa.String(), nullable=False), - sa.Column("entity", sa.String(), nullable=False, default="household"), - sa.Column("gini", sa.Float(), nullable=True), - sa.Column("top_10_share", sa.Float(), nullable=True), - sa.Column("top_1_share", sa.Float(), nullable=True), - sa.Column("bottom_50_share", sa.Float(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint( - ["simulation_id"], ["simulations.id"], ondelete="CASCADE" - ), - sa.ForeignKeyConstraint(["report_id"], ["reports.id"], ondelete="CASCADE"), - ) - op.create_index("idx_inequality_simulation_id", "inequality", ["simulation_id"]) - op.create_index("idx_inequality_report_id", "inequality", ["report_id"]) - - -def downgrade() -> None: - """Drop all tables in reverse order.""" - # Tier 5 - op.drop_index("idx_inequality_report_id", "inequality") - op.drop_index("idx_inequality_simulation_id", "inequality") - op.drop_table("inequality") - 
op.drop_index("idx_poverty_report_id", "poverty") - op.drop_index("idx_poverty_simulation_id", "poverty") - op.drop_table("poverty") - op.drop_table("program_statistics") - op.drop_table("decile_impacts") - op.drop_table("change_aggregates") - op.drop_table("aggregates") - op.drop_table("reports") - - # Tier 4 - op.drop_table("household_jobs") - op.drop_table("simulations") - op.drop_table("parameter_values") - - # Tier 3 - op.drop_table("dataset_versions") - op.drop_table("variables") - op.drop_table("parameters") - - # Tier 2 - op.drop_table("datasets") - op.drop_table("tax_benefit_model_versions") - - # Tier 1 - op.drop_table("dynamics") - op.drop_table("policies") - op.drop_index("ix_users_email", "users") - op.drop_table("users") - op.drop_table("tax_benefit_models") diff --git a/alembic/versions/20260204_0002_add_household_support.py b/alembic/versions/20260204_0002_add_household_support.py deleted file mode 100644 index 186db37..0000000 --- a/alembic/versions/20260204_0002_add_household_support.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Add household CRUD and impact analysis support - -Revision ID: 0002_household -Revises: 0001_initial -Create Date: 2026-02-04 - -This migration adds support for: -- Storing household definitions (households table) -- User-household associations for saved households -- Household-based simulations (adds household_id to simulations) -- Household impact reports (adds report_type to reports) -""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op - -# revision identifiers, used by Alembic. 
-revision: str = "0002_household" -down_revision: Union[str, Sequence[str], None] = "0001_initial" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Add household support.""" - # ======================================================================== - # NEW TABLES - # ======================================================================== - - # Households (stored household definitions) - op.create_table( - "households", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("tax_benefit_model_name", sa.String(), nullable=False), - sa.Column("year", sa.Integer(), nullable=False), - sa.Column("label", sa.String(), nullable=True), - sa.Column("household_data", sa.JSON(), nullable=False), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.Column( - "updated_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - ) - op.create_index( - "idx_households_model_name", "households", ["tax_benefit_model_name"] - ) - op.create_index("idx_households_year", "households", ["year"]) - - # User-household associations (many-to-many for saved households) - # Note: user_id is a client-generated UUID stored in localStorage, not a foreign key - op.create_table( - "user_household_associations", - sa.Column("id", sa.Uuid(), nullable=False), - sa.Column("user_id", sa.Uuid(), nullable=False), - sa.Column("household_id", sa.Uuid(), nullable=False), - sa.Column("country_id", sa.String(), nullable=False), - sa.Column("label", sa.String(), nullable=True), - sa.Column( - "created_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.Column( - "updated_at", - sa.DateTime(timezone=True), - server_default=sa.func.now(), - nullable=False, - ), - sa.PrimaryKeyConstraint("id"), - sa.ForeignKeyConstraint(["household_id"], 
["households.id"], ondelete="CASCADE"), - sa.UniqueConstraint("user_id", "household_id"), - ) - op.create_index( - "idx_user_household_user", "user_household_associations", ["user_id"] - ) - op.create_index( - "idx_user_household_household", "user_household_associations", ["household_id"] - ) - - # ======================================================================== - # MODIFY SIMULATIONS TABLE - # ======================================================================== - - # Add simulation_type column (economy vs household) - op.add_column( - "simulations", - sa.Column( - "simulation_type", - sa.String(), - nullable=False, - server_default="economy", - ), - ) - - # Add household_id column (for household simulations) - op.add_column( - "simulations", - sa.Column("household_id", sa.Uuid(), nullable=True), - ) - op.create_foreign_key( - "fk_simulations_household_id", - "simulations", - "households", - ["household_id"], - ["id"], - ) - - # Add household_result column (stores household calculation results) - op.add_column( - "simulations", - sa.Column("household_result", sa.JSON(), nullable=True), - ) - - # Make dataset_id nullable (household simulations don't need a dataset) - op.alter_column( - "simulations", - "dataset_id", - existing_type=sa.Uuid(), - nullable=True, - ) - - # ======================================================================== - # MODIFY REPORTS TABLE - # ======================================================================== - - # Add report_type column (economy_comparison, household_impact, etc.) 
- op.add_column( - "reports", - sa.Column("report_type", sa.String(), nullable=True), - ) - - -def downgrade() -> None: - """Remove household support.""" - # ======================================================================== - # REVERT REPORTS TABLE - # ======================================================================== - op.drop_column("reports", "report_type") - - # ======================================================================== - # REVERT SIMULATIONS TABLE - # ======================================================================== - - # Make dataset_id required again - op.alter_column( - "simulations", - "dataset_id", - existing_type=sa.Uuid(), - nullable=False, - ) - - # Remove household columns - op.drop_column("simulations", "household_result") - op.drop_constraint("fk_simulations_household_id", "simulations", type_="foreignkey") - op.drop_column("simulations", "household_id") - op.drop_column("simulations", "simulation_type") - - # ======================================================================== - # DROP NEW TABLES - # ======================================================================== - op.drop_index("idx_user_household_household", "user_household_associations") - op.drop_index("idx_user_household_user", "user_household_associations") - op.drop_table("user_household_associations") - - op.drop_index("idx_households_year", "households") - op.drop_index("idx_households_model_name", "households") - op.drop_table("households") diff --git a/alembic/versions/20260204_0003_add_parameter_values_indexes.py b/alembic/versions/20260204_0003_add_parameter_values_indexes.py deleted file mode 100644 index 53518cf..0000000 --- a/alembic/versions/20260204_0003_add_parameter_values_indexes.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Add parameter_values indexes - -Revision ID: 0003_param_idx -Revises: 0002_household -Create Date: 2026-02-04 02:20:00.000000 - -This migration adds performance indexes to the parameter_values table -for optimizing 
common query patterns. -""" - -from typing import Sequence, Union - -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "0003_param_idx" -down_revision: Union[str, Sequence[str], None] = "0002_household" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Add performance indexes to parameter_values.""" - # Composite index for the most common query pattern (filtering by both) - op.create_index( - "idx_parameter_values_parameter_policy", - "parameter_values", - ["parameter_id", "policy_id"], - ) - - # Single index on policy_id for filtering by policy alone - op.create_index( - "idx_parameter_values_policy", - "parameter_values", - ["policy_id"], - ) - - # Partial index for baseline values (policy_id IS NULL) - # This optimizes the common "get current law values" query - op.create_index( - "idx_parameter_values_baseline", - "parameter_values", - ["parameter_id"], - postgresql_where="policy_id IS NULL", - ) - - -def downgrade() -> None: - """Remove parameter_values indexes.""" - op.drop_index("idx_parameter_values_baseline", "parameter_values") - op.drop_index("idx_parameter_values_policy", "parameter_values") - op.drop_index("idx_parameter_values_parameter_policy", "parameter_values") diff --git a/alembic/versions/20260206_0004_add_variable_default_value.py b/alembic/versions/20260206_0004_add_variable_default_value.py deleted file mode 100644 index 2471491..0000000 --- a/alembic/versions/20260206_0004_add_variable_default_value.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Add default_value to variables - -Revision ID: 0004_var_default -Revises: 0003_param_idx -Create Date: 2026-02-06 03:30:00.000000 - -This migration adds a default_value column to the variables table. -The default_value is stored as JSON to handle different types (int, float, bool, str, etc.). 
-""" - -from typing import Sequence, Union - -import sqlalchemy as sa -from alembic import op -from sqlalchemy.dialects.postgresql import JSON - -# revision identifiers, used by Alembic. -revision: str = "0004_var_default" -down_revision: Union[str, Sequence[str], None] = "0003_param_idx" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - """Add default_value column to variables table.""" - op.add_column( - "variables", - sa.Column("default_value", JSON, nullable=True), - ) - - -def downgrade() -> None: - """Remove default_value column from variables table.""" - op.drop_column("variables", "default_value") diff --git a/alembic/versions/20260207_36f9d434e95b_initial_schema.py b/alembic/versions/20260207_36f9d434e95b_initial_schema.py new file mode 100644 index 0000000..0dce2e7 --- /dev/null +++ b/alembic/versions/20260207_36f9d434e95b_initial_schema.py @@ -0,0 +1,321 @@ +"""initial schema + +Revision ID: 36f9d434e95b +Revises: +Create Date: 2026-02-07 01:52:16.497121 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision: str = '36f9d434e95b' +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('dynamics', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('policies', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('tax_benefit_models', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('users', + sa.Column('first_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('last_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) + op.create_table('datasets', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('filepath', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('year', sa.Integer(), nullable=False), + sa.Column('is_output_dataset', sa.Boolean(), nullable=False), + sa.Column('tax_benefit_model_id', sa.Uuid(), 
nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['tax_benefit_model_id'], ['tax_benefit_models.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('household_jobs', + sa.Column('tax_benefit_model_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('request_data', sa.JSON(), nullable=True), + sa.Column('policy_id', sa.Uuid(), nullable=True), + sa.Column('dynamic_id', sa.Uuid(), nullable=True), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', name='householdjobstatus'), nullable=False), + sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('result', sa.JSON(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('started_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['dynamic_id'], ['dynamics.id'], ), + sa.ForeignKeyConstraint(['policy_id'], ['policies.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('tax_benefit_model_versions', + sa.Column('model_id', sa.Uuid(), nullable=False), + sa.Column('version', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['model_id'], ['tax_benefit_models.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('dataset_versions', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('dataset_id', sa.Uuid(), nullable=False), + sa.Column('tax_benefit_model_id', sa.Uuid(), nullable=False), + 
sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ), + sa.ForeignKeyConstraint(['tax_benefit_model_id'], ['tax_benefit_models.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('parameters', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('data_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('unit', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('tax_benefit_model_version_id', sa.Uuid(), nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['tax_benefit_model_version_id'], ['tax_benefit_model_versions.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('simulations', + sa.Column('dataset_id', sa.Uuid(), nullable=False), + sa.Column('policy_id', sa.Uuid(), nullable=True), + sa.Column('dynamic_id', sa.Uuid(), nullable=True), + sa.Column('tax_benefit_model_version_id', sa.Uuid(), nullable=False), + sa.Column('output_dataset_id', sa.Uuid(), nullable=True), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', name='simulationstatus'), nullable=False), + sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('started_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ), + sa.ForeignKeyConstraint(['dynamic_id'], ['dynamics.id'], ), + sa.ForeignKeyConstraint(['output_dataset_id'], ['datasets.id'], ), + 
sa.ForeignKeyConstraint(['policy_id'], ['policies.id'], ), + sa.ForeignKeyConstraint(['tax_benefit_model_version_id'], ['tax_benefit_model_versions.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('variables', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('data_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('possible_values', sa.JSON(), nullable=True), + sa.Column('tax_benefit_model_version_id', sa.Uuid(), nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['tax_benefit_model_version_id'], ['tax_benefit_model_versions.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('parameter_values', + sa.Column('parameter_id', sa.Uuid(), nullable=False), + sa.Column('value_json', sa.JSON(), nullable=True), + sa.Column('start_date', sa.DateTime(), nullable=False), + sa.Column('end_date', sa.DateTime(), nullable=True), + sa.Column('policy_id', sa.Uuid(), nullable=True), + sa.Column('dynamic_id', sa.Uuid(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['dynamic_id'], ['dynamics.id'], ), + sa.ForeignKeyConstraint(['parameter_id'], ['parameters.id'], ), + sa.ForeignKeyConstraint(['policy_id'], ['policies.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('reports', + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('user_id', sa.Uuid(), nullable=True), + sa.Column('markdown', sa.Text(), nullable=True), + sa.Column('parent_report_id', sa.Uuid(), nullable=True), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 
'FAILED', name='reportstatus'), nullable=False), + sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('baseline_simulation_id', sa.Uuid(), nullable=True), + sa.Column('reform_simulation_id', sa.Uuid(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['baseline_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['parent_report_id'], ['reports.id'], ), + sa.ForeignKeyConstraint(['reform_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('aggregates', + sa.Column('simulation_id', sa.Uuid(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=True), + sa.Column('report_id', sa.Uuid(), nullable=True), + sa.Column('variable', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('aggregate_type', sa.Enum('SUM', 'MEAN', 'COUNT', name='aggregatetype'), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('filter_config', sa.JSON(), nullable=True), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', name='aggregatestatus'), nullable=False), + sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('result', sa.Float(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.ForeignKeyConstraint(['simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('change_aggregates', + sa.Column('baseline_simulation_id', sa.Uuid(), nullable=False), + sa.Column('reform_simulation_id', sa.Uuid(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=True), + sa.Column('report_id', sa.Uuid(), 
nullable=True), + sa.Column('variable', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('aggregate_type', sa.Enum('SUM', 'MEAN', 'COUNT', name='changeaggregatetype'), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('filter_config', sa.JSON(), nullable=True), + sa.Column('change_geq', sa.Float(), nullable=True), + sa.Column('change_leq', sa.Float(), nullable=True), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', name='changeaggregatestatus'), nullable=False), + sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('result', sa.Float(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['baseline_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['reform_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('decile_impacts', + sa.Column('baseline_simulation_id', sa.Uuid(), nullable=False), + sa.Column('reform_simulation_id', sa.Uuid(), nullable=False), + sa.Column('report_id', sa.Uuid(), nullable=True), + sa.Column('income_variable', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('decile', sa.Integer(), nullable=False), + sa.Column('quantiles', sa.Integer(), nullable=False), + sa.Column('baseline_mean', sa.Float(), nullable=True), + sa.Column('reform_mean', sa.Float(), nullable=True), + sa.Column('absolute_change', sa.Float(), nullable=True), + sa.Column('relative_change', sa.Float(), nullable=True), + sa.Column('count_better_off', sa.Float(), nullable=True), + sa.Column('count_worse_off', sa.Float(), nullable=True), + sa.Column('count_no_change', sa.Float(), nullable=True), + 
sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['baseline_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['reform_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('inequality', + sa.Column('simulation_id', sa.Uuid(), nullable=False), + sa.Column('report_id', sa.Uuid(), nullable=True), + sa.Column('income_variable', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('gini', sa.Float(), nullable=True), + sa.Column('top_10_share', sa.Float(), nullable=True), + sa.Column('top_1_share', sa.Float(), nullable=True), + sa.Column('bottom_50_share', sa.Float(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.ForeignKeyConstraint(['simulation_id'], ['simulations.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('poverty', + sa.Column('simulation_id', sa.Uuid(), nullable=False), + sa.Column('report_id', sa.Uuid(), nullable=True), + sa.Column('poverty_type', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('filter_variable', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('headcount', sa.Float(), nullable=True), + sa.Column('total_population', sa.Float(), nullable=True), + sa.Column('rate', sa.Float(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.ForeignKeyConstraint(['simulation_id'], ['simulations.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('program_statistics', + 
sa.Column('baseline_simulation_id', sa.Uuid(), nullable=False), + sa.Column('reform_simulation_id', sa.Uuid(), nullable=False), + sa.Column('report_id', sa.Uuid(), nullable=True), + sa.Column('program_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('is_tax', sa.Boolean(), nullable=False), + sa.Column('baseline_total', sa.Float(), nullable=True), + sa.Column('reform_total', sa.Float(), nullable=True), + sa.Column('change', sa.Float(), nullable=True), + sa.Column('baseline_count', sa.Float(), nullable=True), + sa.Column('reform_count', sa.Float(), nullable=True), + sa.Column('winners', sa.Float(), nullable=True), + sa.Column('losers', sa.Float(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['baseline_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['reform_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table('program_statistics') + op.drop_table('poverty') + op.drop_table('inequality') + op.drop_table('decile_impacts') + op.drop_table('change_aggregates') + op.drop_table('aggregates') + op.drop_table('reports') + op.drop_table('parameter_values') + op.drop_table('variables') + op.drop_table('simulations') + op.drop_table('parameters') + op.drop_table('dataset_versions') + op.drop_table('tax_benefit_model_versions') + op.drop_table('household_jobs') + op.drop_table('datasets') + op.drop_index(op.f('ix_users_email'), table_name='users') + op.drop_table('users') + op.drop_table('tax_benefit_models') + op.drop_table('policies') + op.drop_table('dynamics') + # ### end Alembic commands ### diff --git a/alembic/versions/20260207_f419b5f4acba_add_household_support.py b/alembic/versions/20260207_f419b5f4acba_add_household_support.py new file mode 100644 index 0000000..cef65f3 --- /dev/null +++ b/alembic/versions/20260207_f419b5f4acba_add_household_support.py @@ -0,0 +1,81 @@ +"""add household support + +Revision ID: f419b5f4acba +Revises: 36f9d434e95b +Create Date: 2026-02-07 01:56:31.064511 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'f419b5f4acba' +down_revision: Union[str, Sequence[str], None] = '36f9d434e95b' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('households', + sa.Column('tax_benefit_model_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('year', sa.Integer(), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('household_data', sa.JSON(), nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('user_household_associations', + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('household_id', sa.Uuid(), nullable=False), + sa.Column('country_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['household_id'], ['households.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_user_household_associations_household_id'), 'user_household_associations', ['household_id'], unique=False) + op.create_index(op.f('ix_user_household_associations_user_id'), 'user_household_associations', ['user_id'], unique=False) + op.add_column('reports', sa.Column('report_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True)) + # Create enum type first + simulationtype = postgresql.ENUM('HOUSEHOLD', 'ECONOMY', name='simulationtype', create_type=False) + simulationtype.create(op.get_bind(), checkfirst=True) + op.add_column('simulations', sa.Column('simulation_type', sa.Enum('HOUSEHOLD', 'ECONOMY', name='simulationtype', create_type=False), nullable=False)) + op.add_column('simulations', sa.Column('household_id', sa.Uuid(), nullable=True)) + op.add_column('simulations', sa.Column('household_result', postgresql.JSON(astext_type=sa.Text()), nullable=True)) + 
op.alter_column('simulations', 'dataset_id', + existing_type=sa.UUID(), + nullable=True) + op.create_foreign_key(None, 'simulations', 'households', ['household_id'], ['id']) + op.add_column('variables', sa.Column('default_value', sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('variables', 'default_value') + op.drop_constraint(None, 'simulations', type_='foreignkey') + op.alter_column('simulations', 'dataset_id', + existing_type=sa.UUID(), + nullable=False) + op.drop_column('simulations', 'household_result') + op.drop_column('simulations', 'household_id') + op.drop_column('simulations', 'simulation_type') + # Drop enum type + postgresql.ENUM('HOUSEHOLD', 'ECONOMY', name='simulationtype').drop(op.get_bind(), checkfirst=True) + op.drop_column('reports', 'report_type') + op.drop_index(op.f('ix_user_household_associations_user_id'), table_name='user_household_associations') + op.drop_index(op.f('ix_user_household_associations_household_id'), table_name='user_household_associations') + op.drop_table('user_household_associations') + op.drop_table('households') + # ### end Alembic commands ### From bdebc9e0f169e123fdfc6bb04d9898ff528f52d6 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 7 Feb 2026 03:42:40 +0300 Subject: [PATCH 13/19] test: Add Variable model tests for default_value field MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add tests for Variable with int, float, bool, and string default values - Add test for null default_value handling - Add test for Household model creation šŸ¤– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/test_models.py | 89 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index 0f84140..e3a83d9 100644 
--- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,9 +6,11 @@ AggregateOutput, AggregateType, Dataset, + Household, Policy, Simulation, SimulationStatus, + Variable, ) @@ -66,3 +68,90 @@ def test_aggregate_output_creation(): assert output.simulation_id == simulation_id assert output.aggregate_type == AggregateType.SUM assert output.result is None + + +def test_variable_creation_with_default_value(): + """Test variable model creation with default_value field.""" + model_version_id = uuid4() + variable = Variable( + name="age", + entity="person", + description="Age of the person", + data_type="int", + default_value=40, + tax_benefit_model_version_id=model_version_id, + ) + assert variable.name == "age" + assert variable.entity == "person" + assert variable.data_type == "int" + assert variable.default_value == 40 + assert variable.id is not None + + +def test_variable_with_float_default_value(): + """Test variable model with float default value.""" + model_version_id = uuid4() + variable = Variable( + name="employment_income", + entity="person", + data_type="float", + default_value=0.0, + tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value == 0.0 + + +def test_variable_with_bool_default_value(): + """Test variable model with boolean default value.""" + model_version_id = uuid4() + variable = Variable( + name="is_disabled", + entity="person", + data_type="bool", + default_value=False, + tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value is False + + +def test_variable_with_string_default_value(): + """Test variable model with string default value (enum).""" + model_version_id = uuid4() + variable = Variable( + name="state_name", + entity="household", + data_type="Enum", + default_value="CA", + possible_values=["CA", "NY", "TX"], + tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value == "CA" + assert variable.possible_values == ["CA", "NY", "TX"] + + +def 
test_variable_with_null_default_value(): + """Test variable model with null default value.""" + model_version_id = uuid4() + variable = Variable( + name="optional_field", + entity="person", + data_type="str", + default_value=None, + tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value is None + + +def test_household_creation(): + """Test household model creation.""" + household = Household( + tax_benefit_model_name="policyengine_us", + year=2024, + label="Test household", + household_data={"people": [{"age": 30}], "household": {}}, + ) + assert household.household_data == {"people": [{"age": 30}], "household": {}} + assert household.label == "Test household" + assert household.tax_benefit_model_name == "policyengine_us" + assert household.year == 2024 + assert household.id is not None From b15c26842cab6b984e6511a43dd8a7644ee19940 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Sat, 7 Feb 2026 03:53:10 +0300 Subject: [PATCH 14/19] fix: Convert string report_id to UUID in _run_local_household_impact MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit session.get(Report, report_id) expects a UUID, but report_id was passed as a string. This caused 'str' object has no attribute 'hex' errors. šŸ¤– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/policyengine_api/api/household_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py index 5f36fda..d321be4 100644 --- a/src/policyengine_api/api/household_analysis.py +++ b/src/policyengine_api/api/household_analysis.py @@ -376,7 +376,7 @@ def _run_local_household_impact(report_id: str, session: Session) -> None: locally (agent_use_modal=False). This mirrors the economic impact behavior. True async execution requires Modal. 
""" - report = session.get(Report, report_id) + report = session.get(Report, UUID(report_id)) if not report: return From 298540b4372127bcba043ae13bfc70232c640053 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 11 Feb 2026 21:47:33 +0100 Subject: [PATCH 15/19] refactor: Use policyengine.py's Simulation class directly now that US reform bug is fixed MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit removes the workaround code that was added to bypass policyengine.py's Simulation class. The workaround was needed because policyengine.py's US simulation applied reforms via p.update() after Microsimulation construction, which didn't work due to the US package's shared singleton TaxBenefitSystem. That bug has now been fixed in policyengine.py (issue #232), so we can use policyengine.py's Simulation class directly again. Changes: - Revert household.py to use policyengine.core.Simulation instead of manually building Microsimulation with reform dicts - Revert modal_app.py to use PESimulation instead of custom helper functions (_pe_policy_to_reform_dict, _merge_reform_dicts, _run_us_economy_simulation, _run_uk_economy_simulation) - Remove now-obsolete test files for the workaround functions šŸ¤– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- src/policyengine_api/api/household.py | 388 +++----- src/policyengine_api/modal_app.py | 1153 +---------------------- test_fixtures/fixtures_policy_reform.py | 282 ------ tests/test_policy_reform.py | 327 ------- 4 files changed, 186 insertions(+), 1964 deletions(-) delete mode 100644 test_fixtures/fixtures_policy_reform.py delete mode 100644 tests/test_policy_reform.py diff --git a/src/policyengine_api/api/household.py b/src/policyengine_api/api/household.py index adb6ac9..0e89b5e 100644 --- a/src/policyengine_api/api/household.py +++ b/src/policyengine_api/api/household.py @@ -294,16 +294,17 @@ def _calculate_household_uk( 
Supports multiple households via entity relational dataframes. If entity IDs are not provided, defaults to single household with all people in it. - - Uses policyengine-uk Microsimulation directly with reform dict to ensure - policy changes are applied correctly. """ - import numpy as np + import tempfile + from datetime import datetime + from pathlib import Path + import pandas as pd + from policyengine.core import Simulation + from microdf import MicroDataFrame from policyengine.tax_benefit_models.uk import uk_latest - from policyengine_core.simulations.simulation_builder import SimulationBuilder - from policyengine_uk import Microsimulation - from policyengine_uk.system import system + from policyengine.tax_benefit_models.uk.datasets import PolicyEngineUKDataset + from policyengine.tax_benefit_models.uk.datasets import UKYearData n_people = len(people) n_benunits = max(1, len(benunit)) @@ -349,88 +350,68 @@ def _calculate_household_uk( household_data[key] = [0.0] * n_households household_data[key][i] = value - # Convert policy_data to policyengine-uk reform dict format - # Format: {"param.name": {"YYYY-MM-DD": value}} - reform = None - if policy_data and policy_data.get("parameter_values"): - reform = {} - for pv in policy_data["parameter_values"]: - param_name = pv.get("parameter_name") - value = pv.get("value") - start_date = pv.get("start_date") - - if param_name and start_date: - # Parse ISO date string to get just the date part - if "T" in start_date: - date_str = start_date.split("T")[0] - else: - date_str = start_date - - if param_name not in reform: - reform[param_name] = {} - reform[param_name][date_str] = value - - # Create Microsimulation with reform applied at construction time - microsim = Microsimulation(reform=reform) - - # Build simulation from entity data using SimulationBuilder - person_df = pd.DataFrame(person_data) - - # Determine column naming convention - benunit_id_col = ( - "person_benunit_id" - if "person_benunit_id" in person_df.columns 
- else "benunit_id" - ) - household_id_col = ( - "person_household_id" - if "person_household_id" in person_df.columns - else "household_id" + # Create MicroDataFrames + person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") + benunit_df = MicroDataFrame(pd.DataFrame(benunit_data), weights="benunit_weight") + household_df = MicroDataFrame( + pd.DataFrame(household_data), weights="household_weight" ) - # Declare entities using SimulationBuilder - builder = SimulationBuilder() - builder.populations = system.instantiate_entities() + # Create temporary dataset + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "household_calc.h5") + + dataset = PolicyEngineUKDataset( + name="Household calculation", + description="Household(s) for calculation", + filepath=filepath, + year=year, + data=UKYearData( + person=person_df, + benunit=benunit_df, + household=household_df, + ), + ) - builder.declare_person_entity("person", person_df["person_id"].values) - builder.declare_entity("benunit", np.unique(person_df[benunit_id_col].values)) - builder.declare_entity("household", np.unique(person_df[household_id_col].values)) + # Build policy if provided + policy = None + if policy_data: + from policyengine.core.policy import ParameterValue as PEParameterValue + from policyengine.core.policy import Policy as PEPolicy + + pe_param_values = [] + param_lookup = {p.name: p for p in uk_latest.parameters} + for pv in policy_data.get("parameter_values", []): + pe_param = param_lookup.get(pv["parameter_name"]) + if pe_param: + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv["value"], + start_date=datetime.fromisoformat(pv["start_date"]) + if pv.get("start_date") + else None, + end_date=datetime.fromisoformat(pv["end_date"]) + if pv.get("end_date") + else None, + ) + pe_param_values.append(pe_pv) + policy = PEPolicy( + name=policy_data.get("name", ""), + description=policy_data.get("description", ""), + parameter_values=pe_param_values, + ) - # Join 
persons to group entities - builder.join_with_persons( - builder.populations["benunit"], - person_df[benunit_id_col].values, - np.array(["member"] * len(person_df)), - ) - builder.join_with_persons( - builder.populations["household"], - person_df[household_id_col].values, - np.array(["member"] * len(person_df)), + # Run simulation + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=uk_latest, + policy=policy, ) + simulation.run() - # Build simulation from populations - microsim.build_from_populations(builder.populations) + # Extract outputs + output_data = simulation.output_dataset.data - # Set input variables for each entity - id_columns = { - "person_id", - "benunit_id", - "person_benunit_id", - "household_id", - "person_household_id", - } - - for entity_name, entity_df in [ - ("person", person_data), - ("benunit", benunit_data), - ("household", household_data), - ]: - df = pd.DataFrame(entity_df) - for column in df.columns: - if column not in id_columns and column in system.variables: - microsim.set_input(column, year, df[column].values) - - # Calculate output variables def safe_convert(value): try: return float(value) @@ -441,24 +422,21 @@ def safe_convert(value): for i in range(n_people): person_dict = {} for var in uk_latest.entity_variables["person"]: - val = microsim.calculate(var, period=year, map_to="person") - person_dict[var] = safe_convert(val.values[i]) + person_dict[var] = safe_convert(output_data.person[var].iloc[i]) person_outputs.append(person_dict) benunit_outputs = [] - for i in range(n_benunits): + for i in range(len(output_data.benunit)): benunit_dict = {} for var in uk_latest.entity_variables["benunit"]: - val = microsim.calculate(var, period=year, map_to="benunit") - benunit_dict[var] = safe_convert(val.values[i]) + benunit_dict[var] = safe_convert(output_data.benunit[var].iloc[i]) benunit_outputs.append(benunit_dict) household_outputs = [] - for i in range(n_households): + for i in range(len(output_data.household)): 
household_dict = {} for var in uk_latest.entity_variables["household"]: - val = microsim.calculate(var, period=year, map_to="household") - household_dict[var] = safe_convert(val.values[i]) + household_dict[var] = safe_convert(output_data.household[var].iloc[i]) household_outputs.append(household_dict) return { @@ -488,14 +466,7 @@ def _run_local_household_us( try: result = _calculate_household_us( - people, - marital_unit, - family, - spm_unit, - tax_unit, - household, - year, - policy_data, + people, marital_unit, family, spm_unit, tax_unit, household, year, policy_data ) # Update job with result @@ -535,16 +506,17 @@ def _calculate_household_us( Supports multiple households via entity relational dataframes. If entity IDs are not provided, defaults to single household with all people in it. - - Uses policyengine-us Microsimulation directly with reform dict to ensure - policy changes are applied correctly. """ - import numpy as np + import tempfile + from datetime import datetime + from pathlib import Path + import pandas as pd + from policyengine.core import Simulation + from microdf import MicroDataFrame from policyengine.tax_benefit_models.us import us_latest - from policyengine_core.simulations.simulation_builder import SimulationBuilder - from policyengine_us import Microsimulation - from policyengine_us.system import system + from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset + from policyengine.tax_benefit_models.us.datasets import USYearData n_people = len(people) n_households = max(1, len(household)) @@ -624,158 +596,108 @@ def _calculate_household_us( tax_unit_data[key] = [0.0] * n_tax_units tax_unit_data[key][i] = value - # Convert policy_data to policyengine-us reform dict format - # Format: {"param.name": {"YYYY-MM-DD": value}} - reform = None - if policy_data and policy_data.get("parameter_values"): - reform = {} - for pv in policy_data["parameter_values"]: - param_name = pv.get("parameter_name") - value = pv.get("value") - 
start_date = pv.get("start_date") - - if param_name and start_date: - # Parse ISO date string to get just the date part - if "T" in start_date: - date_str = start_date.split("T")[0] - else: - date_str = start_date - - if param_name not in reform: - reform[param_name] = {} - reform[param_name][date_str] = value - - # Create Microsimulation with reform applied at construction time - # This ensures the reform is properly integrated into the tax benefit system - microsim = Microsimulation(reform=reform) - - # Build simulation from entity data using SimulationBuilder - person_df = pd.DataFrame(person_data) - - # Determine column naming convention - household_id_col = ( - "person_household_id" - if "person_household_id" in person_df.columns - else "household_id" - ) - marital_unit_id_col = ( - "person_marital_unit_id" - if "person_marital_unit_id" in person_df.columns - else "marital_unit_id" - ) - family_id_col = ( - "person_family_id" if "person_family_id" in person_df.columns else "family_id" + # Create MicroDataFrames + person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") + household_df = MicroDataFrame( + pd.DataFrame(household_data), weights="household_weight" ) - spm_unit_id_col = ( - "person_spm_unit_id" - if "person_spm_unit_id" in person_df.columns - else "spm_unit_id" + marital_unit_df = MicroDataFrame( + pd.DataFrame(marital_unit_data), weights="marital_unit_weight" ) - tax_unit_id_col = ( - "person_tax_unit_id" - if "person_tax_unit_id" in person_df.columns - else "tax_unit_id" + family_df = MicroDataFrame(pd.DataFrame(family_data), weights="family_weight") + spm_unit_df = MicroDataFrame(pd.DataFrame(spm_unit_data), weights="spm_unit_weight") + tax_unit_df = MicroDataFrame(pd.DataFrame(tax_unit_data), weights="tax_unit_weight") + + # Create temporary dataset + tmpdir = tempfile.mkdtemp() + filepath = str(Path(tmpdir) / "household_calc.h5") + + dataset = PolicyEngineUSDataset( + name="Household calculation", + 
description="Household(s) for calculation", + filepath=filepath, + year=year, + data=USYearData( + person=person_df, + household=household_df, + marital_unit=marital_unit_df, + family=family_df, + spm_unit=spm_unit_df, + tax_unit=tax_unit_df, + ), ) - # Declare entities using SimulationBuilder - builder = SimulationBuilder() - builder.populations = system.instantiate_entities() - - builder.declare_person_entity("person", person_df["person_id"].values) - builder.declare_entity("household", np.unique(person_df[household_id_col].values)) - builder.declare_entity("spm_unit", np.unique(person_df[spm_unit_id_col].values)) - builder.declare_entity("family", np.unique(person_df[family_id_col].values)) - builder.declare_entity("tax_unit", np.unique(person_df[tax_unit_id_col].values)) - builder.declare_entity( - "marital_unit", np.unique(person_df[marital_unit_id_col].values) - ) + # Build policy if provided + policy = None + if policy_data: + from policyengine.core.policy import ParameterValue as PEParameterValue + from policyengine.core.policy import Policy as PEPolicy + + pe_param_values = [] + param_lookup = {p.name: p for p in us_latest.parameters} + for pv in policy_data.get("parameter_values", []): + pe_param = param_lookup.get(pv["parameter_name"]) + if pe_param: + pe_pv = PEParameterValue( + parameter=pe_param, + value=pv["value"], + start_date=datetime.fromisoformat(pv["start_date"]) + if pv.get("start_date") + else None, + end_date=datetime.fromisoformat(pv["end_date"]) + if pv.get("end_date") + else None, + ) + pe_param_values.append(pe_pv) + policy = PEPolicy( + name=policy_data.get("name", ""), + description=policy_data.get("description", ""), + parameter_values=pe_param_values, + ) - # Join persons to group entities - builder.join_with_persons( - builder.populations["household"], - person_df[household_id_col].values, - np.array(["member"] * len(person_df)), - ) - builder.join_with_persons( - builder.populations["spm_unit"], - person_df[spm_unit_id_col].values, 
- np.array(["member"] * len(person_df)), - ) - builder.join_with_persons( - builder.populations["family"], - person_df[family_id_col].values, - np.array(["member"] * len(person_df)), - ) - builder.join_with_persons( - builder.populations["tax_unit"], - person_df[tax_unit_id_col].values, - np.array(["member"] * len(person_df)), - ) - builder.join_with_persons( - builder.populations["marital_unit"], - person_df[marital_unit_id_col].values, - np.array(["member"] * len(person_df)), + # Run simulation + simulation = Simulation( + dataset=dataset, + tax_benefit_model_version=us_latest, + policy=policy, ) + simulation.run() - # Build simulation from populations - microsim.build_from_populations(builder.populations) - - # Set input variables for each entity - id_columns = { - "person_id", - "household_id", - "person_household_id", - "spm_unit_id", - "person_spm_unit_id", - "family_id", - "person_family_id", - "tax_unit_id", - "person_tax_unit_id", - "marital_unit_id", - "person_marital_unit_id", - } + # Extract outputs + output_data = simulation.output_dataset.data - for entity_name, entity_df in [ - ("person", person_data), - ("household", household_data), - ("spm_unit", spm_unit_data), - ("family", family_data), - ("tax_unit", tax_unit_data), - ("marital_unit", marital_unit_data), - ]: - df = pd.DataFrame(entity_df) - for column in df.columns: - if column not in id_columns and column in system.variables: - microsim.set_input(column, year, df[column].values) - - # Calculate output variables def safe_convert(value): try: return float(value) except (ValueError, TypeError): return str(value) - def extract_entity_outputs( - entity_name: str, n_rows: int, map_to: str - ) -> list[dict]: + def extract_entity_outputs(entity_name: str, entity_data, n_rows: int) -> list[dict]: outputs = [] for i in range(n_rows): row_dict = {} for var in us_latest.entity_variables[entity_name]: - val = microsim.calculate(var, period=year, map_to=map_to) - row_dict[var] = safe_convert(val.values[i]) 
+ row_dict[var] = safe_convert(entity_data[var].iloc[i]) outputs.append(row_dict) return outputs return { - "person": extract_entity_outputs("person", n_people, "person"), + "person": extract_entity_outputs("person", output_data.person, n_people), "marital_unit": extract_entity_outputs( - "marital_unit", n_marital_units, "marital_unit" + "marital_unit", output_data.marital_unit, len(output_data.marital_unit) + ), + "family": extract_entity_outputs( + "family", output_data.family, len(output_data.family) + ), + "spm_unit": extract_entity_outputs( + "spm_unit", output_data.spm_unit, len(output_data.spm_unit) + ), + "tax_unit": extract_entity_outputs( + "tax_unit", output_data.tax_unit, len(output_data.tax_unit) + ), + "household": extract_entity_outputs( + "household", output_data.household, len(output_data.household) ), - "family": extract_entity_outputs("family", n_families, "family"), - "spm_unit": extract_entity_outputs("spm_unit", n_spm_units, "spm_unit"), - "tax_unit": extract_entity_outputs("tax_unit", n_tax_units, "tax_unit"), - "household": extract_entity_outputs("household", n_households, "household"), } diff --git a/src/policyengine_api/modal_app.py b/src/policyengine_api/modal_app.py index 2b486f3..1aa8119 100644 --- a/src/policyengine_api/modal_app.py +++ b/src/policyengine_api/modal_app.py @@ -7,8 +7,7 @@ Function naming follows the API hierarchy: - simulate_household_*: Single household calculation (/simulate/household) - simulate_economy_*: Single economy simulation (/simulate/economy) -- economy_comparison_*: Full economy comparison analysis (/analysis/economic-impact) -- household_impact_*: Household impact analysis (/analysis/household-impact) +- economy_comparison_*: Full economy comparison analysis (/analysis/compare/economy) Deploy with: modal deploy src/policyengine_api/modal_app.py """ @@ -807,6 +806,7 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N raise ValueError(f"Dataset {simulation.dataset_id} not 
found") # Import policyengine + from policyengine.core import Simulation as PESimulation from policyengine.tax_benefit_models.uk import uk_latest from policyengine.tax_benefit_models.uk.datasets import ( PolicyEngineUKDataset, @@ -814,7 +814,7 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N pe_model_version = uk_latest - # Get policy and dynamic as PEPolicy/PEDynamic objects + # Get policy and dynamic policy = _get_pe_policy_uk( simulation.policy_id, pe_model_version, session ) @@ -822,13 +822,6 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N simulation.dynamic_id, pe_model_version, session ) - # Convert to reform dict format for Microsimulation - # This is necessary because policyengine.core.Simulation applies - # reforms AFTER creating Microsimulation, which doesn't work - policy_reform = _pe_policy_to_reform_dict(policy) - dynamic_reform = _pe_policy_to_reform_dict(dynamic) - reform = _merge_reform_dicts(policy_reform, dynamic_reform) - # Download dataset local_path = download_dataset( dataset.filepath, supabase_url, supabase_key, storage_bucket @@ -841,12 +834,15 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N year=dataset.year, ) - # Run simulation using Microsimulation directly with reform - # This ensures reforms are applied at construction time + # Create and run simulation with logfire.span("run_simulation"): - pe_output_dataset = _run_uk_economy_simulation( - pe_dataset, reform, pe_model_version, simulation_id + pe_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=policy, + dynamic=dynamic, ) + pe_sim.ensure() # Save output dataset with logfire.span("save_output_dataset"): @@ -856,8 +852,8 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N output_path = f"/tmp/{output_filename}" # Set filepath and save - pe_output_dataset.filepath = output_path - pe_output_dataset.save() 
+ pe_sim.output_dataset.filepath = output_path + pe_sim.output_dataset.save() # Upload to Supabase storage supabase = create_client(supabase_url, supabase_key) @@ -872,7 +868,7 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N ) # Create output dataset record - output_dataset_record = Dataset( + output_dataset = Dataset( name=f"Output: {dataset.name}", description=f"Output from simulation {simulation_id}", filepath=output_filename, @@ -880,12 +876,12 @@ def simulate_economy_uk(simulation_id: str, traceparent: str | None = None) -> N is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, ) - session.add(output_dataset_record) + session.add(output_dataset) session.commit() - session.refresh(output_dataset_record) + session.refresh(output_dataset) # Link to simulation - simulation.output_dataset_id = output_dataset_record.id + simulation.output_dataset_id = output_dataset.id # Mark as completed simulation.status = SimulationStatus.COMPLETED @@ -976,15 +972,15 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N raise ValueError(f"Dataset {simulation.dataset_id} not found") # Import policyengine + from policyengine.core import Simulation as PESimulation from policyengine.tax_benefit_models.us import us_latest from policyengine.tax_benefit_models.us.datasets import ( PolicyEngineUSDataset, - USYearData, ) pe_model_version = us_latest - # Get policy and dynamic as PEPolicy/PEDynamic objects + # Get policy and dynamic policy = _get_pe_policy_us( simulation.policy_id, pe_model_version, session ) @@ -992,13 +988,6 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N simulation.dynamic_id, pe_model_version, session ) - # Convert to reform dict format for Microsimulation - # This is necessary because policyengine.core.Simulation applies - # reforms AFTER creating Microsimulation, which doesn't work - policy_reform = _pe_policy_to_reform_dict(policy) - 
dynamic_reform = _pe_policy_to_reform_dict(dynamic) - reform = _merge_reform_dicts(policy_reform, dynamic_reform) - # Download dataset local_path = download_dataset( dataset.filepath, supabase_url, supabase_key, storage_bucket @@ -1011,12 +1000,15 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N year=dataset.year, ) - # Run simulation using Microsimulation directly with reform - # This ensures reforms are applied at construction time + # Create and run simulation with logfire.span("run_simulation"): - pe_output_dataset = _run_us_economy_simulation( - pe_dataset, reform, pe_model_version, simulation_id + pe_sim = PESimulation( + dataset=pe_dataset, + tax_benefit_model_version=pe_model_version, + policy=policy, + dynamic=dynamic, ) + pe_sim.ensure() # Save output dataset with logfire.span("save_output_dataset"): @@ -1026,8 +1018,8 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N output_path = f"/tmp/{output_filename}" # Set filepath and save - pe_output_dataset.filepath = output_path - pe_output_dataset.save() + pe_sim.output_dataset.filepath = output_path + pe_sim.output_dataset.save() # Upload to Supabase storage supabase = create_client(supabase_url, supabase_key) @@ -1042,7 +1034,7 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N ) # Create output dataset record - output_dataset_record = Dataset( + output_dataset = Dataset( name=f"Output: {dataset.name}", description=f"Output from simulation {simulation_id}", filepath=output_filename, @@ -1050,12 +1042,12 @@ def simulate_economy_us(simulation_id: str, traceparent: str | None = None) -> N is_output_dataset=True, tax_benefit_model_id=dataset.tax_benefit_model_id, ) - session.add(output_dataset_record) + session.add(output_dataset) session.commit() - session.refresh(output_dataset_record) + session.refresh(output_dataset) # Link to simulation - simulation.output_dataset_id = output_dataset_record.id + 
simulation.output_dataset_id = output_dataset.id # Mark as completed simulation.status = SimulationStatus.COMPLETED @@ -1823,403 +1815,6 @@ def _get_pe_dynamic_us(dynamic_id, model_version, session): return _get_pe_dynamic_uk(dynamic_id, model_version, session) -def _pe_policy_to_reform_dict(policy) -> dict | None: - """Convert a policyengine.core.policy.Policy to reform dict format. - - The policyengine-us/uk Microsimulation expects reforms in the format: - {"parameter.name": {"YYYY-MM-DD": value}} - - This is necessary because the policyengine.core.Simulation applies reforms - AFTER creating the Microsimulation, which doesn't work due to caching. - We need to pass the reform at Microsimulation construction time. - """ - if policy is None: - return None - - if not policy.parameter_values: - return None - - reform = {} - for pv in policy.parameter_values: - if not pv.parameter: - continue - param_name = pv.parameter.name - value = pv.value - start_date = pv.start_date - - if param_name and start_date: - # Format date as YYYY-MM-DD string - if hasattr(start_date, "strftime"): - date_str = start_date.strftime("%Y-%m-%d") - else: - date_str = str(start_date).split("T")[0] - - if param_name not in reform: - reform[param_name] = {} - reform[param_name][date_str] = value - - return reform if reform else None - - -def _merge_reform_dicts(reform1: dict | None, reform2: dict | None) -> dict | None: - """Merge two reform dicts, with reform2 taking precedence.""" - if reform1 is None and reform2 is None: - return None - if reform1 is None: - return reform2 - if reform2 is None: - return reform1 - - merged = dict(reform1) - for param_name, dates in reform2.items(): - if param_name not in merged: - merged[param_name] = {} - merged[param_name].update(dates) - return merged - - -def _run_us_economy_simulation(pe_dataset, reform, pe_model_version, simulation_id): - """Run US economy simulation using Microsimulation directly. 
- - This bypasses policyengine.core.Simulation which has a bug where reforms - are applied AFTER creating Microsimulation (when it's too late). - Instead, we pass the reform dict at Microsimulation construction time. - """ - from pathlib import Path - - import numpy as np - import pandas as pd - from microdf import MicroDataFrame - from policyengine.tax_benefit_models.us.datasets import ( - PolicyEngineUSDataset, - USYearData, - ) - from policyengine_core.simulations.simulation_builder import SimulationBuilder - from policyengine_us import Microsimulation - from policyengine_us.system import system - - # Load dataset - pe_dataset.load() - year = pe_dataset.year - - # Create Microsimulation with reform applied at construction time - microsim = Microsimulation(reform=reform) - - # Build simulation from dataset using SimulationBuilder - person_df = pd.DataFrame(pe_dataset.data.person) - - # Determine column naming convention - household_id_col = ( - "person_household_id" - if "person_household_id" in person_df.columns - else "household_id" - ) - marital_unit_id_col = ( - "person_marital_unit_id" - if "person_marital_unit_id" in person_df.columns - else "marital_unit_id" - ) - family_id_col = ( - "person_family_id" if "person_family_id" in person_df.columns else "family_id" - ) - spm_unit_id_col = ( - "person_spm_unit_id" - if "person_spm_unit_id" in person_df.columns - else "spm_unit_id" - ) - tax_unit_id_col = ( - "person_tax_unit_id" - if "person_tax_unit_id" in person_df.columns - else "tax_unit_id" - ) - - # Declare entities - builder = SimulationBuilder() - builder.populations = system.instantiate_entities() - - builder.declare_person_entity("person", person_df["person_id"].values) - builder.declare_entity("household", np.unique(person_df[household_id_col].values)) - builder.declare_entity("spm_unit", np.unique(person_df[spm_unit_id_col].values)) - builder.declare_entity("family", np.unique(person_df[family_id_col].values)) - builder.declare_entity("tax_unit", 
np.unique(person_df[tax_unit_id_col].values)) - builder.declare_entity( - "marital_unit", np.unique(person_df[marital_unit_id_col].values) - ) - - # Join persons to entities - builder.join_with_persons( - builder.populations["household"], - person_df[household_id_col].values, - np.array(["member"] * len(person_df)), - ) - builder.join_with_persons( - builder.populations["spm_unit"], - person_df[spm_unit_id_col].values, - np.array(["member"] * len(person_df)), - ) - builder.join_with_persons( - builder.populations["family"], - person_df[family_id_col].values, - np.array(["member"] * len(person_df)), - ) - builder.join_with_persons( - builder.populations["tax_unit"], - person_df[tax_unit_id_col].values, - np.array(["member"] * len(person_df)), - ) - builder.join_with_persons( - builder.populations["marital_unit"], - person_df[marital_unit_id_col].values, - np.array(["member"] * len(person_df)), - ) - - microsim.build_from_populations(builder.populations) - - # Set input variables - id_columns = { - "person_id", - "household_id", - "person_household_id", - "spm_unit_id", - "person_spm_unit_id", - "family_id", - "person_family_id", - "tax_unit_id", - "person_tax_unit_id", - "marital_unit_id", - "person_marital_unit_id", - } - - for entity_name, entity_data in [ - ("person", pe_dataset.data.person), - ("household", pe_dataset.data.household), - ("spm_unit", pe_dataset.data.spm_unit), - ("family", pe_dataset.data.family), - ("tax_unit", pe_dataset.data.tax_unit), - ("marital_unit", pe_dataset.data.marital_unit), - ]: - df = pd.DataFrame(entity_data) - for column in df.columns: - if column not in id_columns and column in system.variables: - microsim.set_input(column, year, df[column].values) - - # Calculate output variables and build output dataset - data = { - "person": pd.DataFrame(), - "marital_unit": pd.DataFrame(), - "family": pd.DataFrame(), - "spm_unit": pd.DataFrame(), - "tax_unit": pd.DataFrame(), - "household": pd.DataFrame(), - } - - weight_columns = { - 
"person_weight", - "household_weight", - "marital_unit_weight", - "family_weight", - "spm_unit_weight", - "tax_unit_weight", - } - - # Copy ID and weight columns from input dataset - for entity in data.keys(): - input_df = pd.DataFrame(getattr(pe_dataset.data, entity)) - entity_id_col = f"{entity}_id" - entity_weight_col = f"{entity}_weight" - - if entity_id_col in input_df.columns: - data[entity][entity_id_col] = input_df[entity_id_col].values - if entity_weight_col in input_df.columns: - data[entity][entity_weight_col] = input_df[entity_weight_col].values - - # Copy person-level group ID columns - for col in person_df.columns: - if col.startswith("person_") and col.endswith("_id"): - target_col = col.replace("person_", "") - if target_col in id_columns: - data["person"][target_col] = person_df[col].values - - # Calculate non-ID, non-weight variables - for entity, variables in pe_model_version.entity_variables.items(): - for var in variables: - if var not in id_columns and var not in weight_columns: - data[entity][var] = microsim.calculate( - var, period=year, map_to=entity - ).values - - # Convert to MicroDataFrames - data["person"] = MicroDataFrame(data["person"], weights="person_weight") - data["marital_unit"] = MicroDataFrame( - data["marital_unit"], weights="marital_unit_weight" - ) - data["family"] = MicroDataFrame(data["family"], weights="family_weight") - data["spm_unit"] = MicroDataFrame(data["spm_unit"], weights="spm_unit_weight") - data["tax_unit"] = MicroDataFrame(data["tax_unit"], weights="tax_unit_weight") - data["household"] = MicroDataFrame(data["household"], weights="household_weight") - - # Create output dataset - return PolicyEngineUSDataset( - id=simulation_id, - name=pe_dataset.name, - description=pe_dataset.description, - filepath=str(Path(pe_dataset.filepath).parent / (simulation_id + ".h5")), - year=year, - is_output_dataset=True, - data=USYearData( - person=data["person"], - marital_unit=data["marital_unit"], - family=data["family"], - 
spm_unit=data["spm_unit"], - tax_unit=data["tax_unit"], - household=data["household"], - ), - ) - - -def _run_uk_economy_simulation(pe_dataset, reform, pe_model_version, simulation_id): - """Run UK economy simulation using Microsimulation directly. - - This bypasses policyengine.core.Simulation which has a bug where reforms - are applied AFTER creating Microsimulation (when it's too late). - Instead, we pass the reform dict at Microsimulation construction time. - """ - from pathlib import Path - - import numpy as np - import pandas as pd - from microdf import MicroDataFrame - from policyengine.tax_benefit_models.uk.datasets import ( - PolicyEngineUKDataset, - UKYearData, - ) - from policyengine_core.simulations.simulation_builder import SimulationBuilder - from policyengine_uk import Microsimulation - from policyengine_uk.system import system - - # Load dataset - pe_dataset.load() - year = pe_dataset.year - - # Create Microsimulation with reform applied at construction time - microsim = Microsimulation(reform=reform) - - # Build simulation from dataset using SimulationBuilder - person_df = pd.DataFrame(pe_dataset.data.person) - - # Determine column naming convention - benunit_id_col = ( - "person_benunit_id" - if "person_benunit_id" in person_df.columns - else "benunit_id" - ) - household_id_col = ( - "person_household_id" - if "person_household_id" in person_df.columns - else "household_id" - ) - - # Declare entities - builder = SimulationBuilder() - builder.populations = system.instantiate_entities() - - builder.declare_person_entity("person", person_df["person_id"].values) - builder.declare_entity("benunit", np.unique(person_df[benunit_id_col].values)) - builder.declare_entity("household", np.unique(person_df[household_id_col].values)) - - # Join persons to entities - builder.join_with_persons( - builder.populations["benunit"], - person_df[benunit_id_col].values, - np.array(["member"] * len(person_df)), - ) - builder.join_with_persons( - 
builder.populations["household"], - person_df[household_id_col].values, - np.array(["member"] * len(person_df)), - ) - - microsim.build_from_populations(builder.populations) - - # Set input variables - id_columns = { - "person_id", - "benunit_id", - "person_benunit_id", - "household_id", - "person_household_id", - } - - for entity_name, entity_data in [ - ("person", pe_dataset.data.person), - ("benunit", pe_dataset.data.benunit), - ("household", pe_dataset.data.household), - ]: - df = pd.DataFrame(entity_data) - for column in df.columns: - if column not in id_columns and column in system.variables: - microsim.set_input(column, year, df[column].values) - - # Calculate output variables and build output dataset - data = { - "person": pd.DataFrame(), - "benunit": pd.DataFrame(), - "household": pd.DataFrame(), - } - - weight_columns = { - "person_weight", - "benunit_weight", - "household_weight", - } - - # Copy ID and weight columns from input dataset - for entity in data.keys(): - input_df = pd.DataFrame(getattr(pe_dataset.data, entity)) - entity_id_col = f"{entity}_id" - entity_weight_col = f"{entity}_weight" - - if entity_id_col in input_df.columns: - data[entity][entity_id_col] = input_df[entity_id_col].values - if entity_weight_col in input_df.columns: - data[entity][entity_weight_col] = input_df[entity_weight_col].values - - # Copy person-level group ID columns - for col in person_df.columns: - if col.startswith("person_") and col.endswith("_id"): - target_col = col.replace("person_", "") - if target_col in id_columns: - data["person"][target_col] = person_df[col].values - - # Calculate non-ID, non-weight variables - for entity, variables in pe_model_version.entity_variables.items(): - for var in variables: - if var not in id_columns and var not in weight_columns: - data[entity][var] = microsim.calculate( - var, period=year, map_to=entity - ).values - - # Convert to MicroDataFrames - data["person"] = MicroDataFrame(data["person"], weights="person_weight") - 
data["benunit"] = MicroDataFrame(data["benunit"], weights="benunit_weight") - data["household"] = MicroDataFrame(data["household"], weights="household_weight") - - # Create output dataset - return PolicyEngineUKDataset( - id=simulation_id, - name=pe_dataset.name, - description=pe_dataset.description, - filepath=str(Path(pe_dataset.filepath).parent / (simulation_id + ".h5")), - year=year, - is_output_dataset=True, - data=UKYearData( - person=data["person"], - benunit=data["benunit"], - household=data["household"], - ), - ) - - @app.function( image=uk_image, secrets=[db_secrets, logfire_secrets], @@ -2921,689 +2516,3 @@ def compute_change_aggregate_us( raise finally: logfire.force_flush() - - -# ============================================================================= -# Household Impact Functions -# ============================================================================= - - -@app.function( - image=uk_image, - secrets=[db_secrets, logfire_secrets], - memory=2048, - cpu=2, - timeout=300, -) -def household_impact_uk(report_id: str, traceparent: str | None = None) -> None: - """Run UK household impact analysis and write results to database.""" - import logfire - - configure_logfire("policyengine-modal-uk", traceparent) - - try: - with logfire.span("household_impact_uk", report_id=report_id): - from datetime import datetime, timezone - from uuid import UUID - - from sqlmodel import Session, create_engine - - database_url = get_database_url() - engine = create_engine(database_url) - - try: - from policyengine_api.models import ( - Household, - Report, - ReportStatus, - Simulation, - SimulationStatus, - ) - - with Session(engine) as session: - # Load report - report = session.get(Report, UUID(report_id)) - if not report: - raise ValueError(f"Report {report_id} not found") - - # Mark as running - report.status = ReportStatus.RUNNING - session.add(report) - session.commit() - - # Run baseline simulation - if report.baseline_simulation_id: - 
_run_household_simulation_uk( - report.baseline_simulation_id, session - ) - - # Run reform simulation if present - if report.reform_simulation_id: - _run_household_simulation_uk( - report.reform_simulation_id, session - ) - - # Mark report as completed - report.status = ReportStatus.COMPLETED - session.add(report) - session.commit() - - except Exception as e: - logfire.error( - "UK household impact failed", report_id=report_id, error=str(e) - ) - try: - from sqlmodel import text - - with Session(engine) as session: - session.execute( - text( - "UPDATE reports SET status = 'FAILED', error_message = :error " - "WHERE id = :report_id" - ), - {"report_id": report_id, "error": str(e)[:1000]}, - ) - session.commit() - except Exception as db_error: - logfire.error("Failed to update DB", error=str(db_error)) - raise - finally: - logfire.force_flush() - - -def _run_household_simulation_uk(simulation_id, session) -> None: - """Run a single UK household simulation.""" - from datetime import datetime, timezone - - from policyengine_api.models import ( - Household, - Simulation, - SimulationStatus, - ) - - simulation = session.get(Simulation, simulation_id) - if not simulation or simulation.status != SimulationStatus.PENDING: - return - - household = session.get(Household, simulation.household_id) - if not household: - raise ValueError(f"Household {simulation.household_id} not found") - - # Mark as running - simulation.status = SimulationStatus.RUNNING - simulation.started_at = datetime.now(timezone.utc) - session.add(simulation) - session.commit() - - try: - # Get policy data if present - policy_data = _get_household_policy_data(simulation.policy_id, session) - - # Run calculation - result = _calculate_uk_household( - household.household_data, - household.year, - policy_data, - ) - - # Store result - simulation.household_result = result - simulation.status = SimulationStatus.COMPLETED - simulation.completed_at = datetime.now(timezone.utc) - session.add(simulation) - 
session.commit() - except Exception as e: - simulation.status = SimulationStatus.FAILED - simulation.error_message = str(e) - simulation.completed_at = datetime.now(timezone.utc) - session.add(simulation) - session.commit() - raise - - -def _calculate_uk_household( - household_data: dict, year: int, policy_data: dict | None -) -> dict: - """Calculate UK household and return result dict.""" - import tempfile - from pathlib import Path - - import pandas as pd - from microdf import MicroDataFrame - from policyengine.core import Simulation - from policyengine.tax_benefit_models.uk import uk_latest - from policyengine.tax_benefit_models.uk.datasets import ( - PolicyEngineUKDataset, - UKYearData, - ) - - people = household_data.get("people", []) - benunit = household_data.get("benunit", []) - hh = household_data.get("household", []) - - # Ensure lists - if isinstance(benunit, dict): - benunit = [benunit] - if isinstance(hh, dict): - hh = [hh] - - n_people = len(people) - n_benunits = max(1, len(benunit) if benunit else 1) - n_households = max(1, len(hh) if hh else 1) - - # Build person data - person_data = { - "person_id": list(range(n_people)), - "person_benunit_id": [0] * n_people, - "person_household_id": [0] * n_people, - "person_weight": [1.0] * n_people, - } - for i, person in enumerate(people): - for key, value in person.items(): - if key not in person_data: - person_data[key] = [0.0] * n_people - person_data[key][i] = value - - # Build benunit data - benunit_data = { - "benunit_id": list(range(n_benunits)), - "benunit_weight": [1.0] * n_benunits, - } - for i, bu in enumerate(benunit if benunit else [{}]): - for key, value in bu.items(): - if key not in benunit_data: - benunit_data[key] = [0.0] * n_benunits - benunit_data[key][i] = value - - # Build household data - household_df_data = { - "household_id": list(range(n_households)), - "household_weight": [1.0] * n_households, - "region": ["LONDON"] * n_households, - "tenure_type": ["RENT_PRIVATELY"] * n_households, 
- "council_tax": [0.0] * n_households, - "rent": [0.0] * n_households, - } - for i, h in enumerate(hh if hh else [{}]): - for key, value in h.items(): - if key not in household_df_data: - household_df_data[key] = [0.0] * n_households - household_df_data[key][i] = value - - # Create MicroDataFrames - person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") - benunit_df = MicroDataFrame(pd.DataFrame(benunit_data), weights="benunit_weight") - household_df = MicroDataFrame( - pd.DataFrame(household_df_data), weights="household_weight" - ) - - # Create temporary dataset - tmpdir = tempfile.mkdtemp() - filepath = str(Path(tmpdir) / "household_calc.h5") - - dataset = PolicyEngineUKDataset( - name="Household calculation", - description="Household(s) for calculation", - person=person_df, - benunit=benunit_df, - household=household_df, - filepath=filepath, - year_data_class=UKYearData, - ) - dataset.save() - - # Build policy if provided - policy = None - if policy_data: - from policyengine.core.policy import ParameterValue, Policy - - pe_param_values = [] - param_lookup = {p.name: p for p in uk_latest.parameters} - for pv in policy_data.get("parameter_values", []): - param_name = pv.get("parameter_name") - if param_name and param_name in param_lookup: - pe_pv = ParameterValue( - parameter=param_lookup[param_name], - value=pv.get("value"), - start_date=pv.get("start_date"), - end_date=pv.get("end_date"), - ) - pe_param_values.append(pe_pv) - - if pe_param_values: - policy = Policy( - name=policy_data.get("name", "Reform"), - description=policy_data.get("description", ""), - parameter_values=pe_param_values, - ) - - # Run simulation - sim = Simulation( - dataset=dataset, - tax_benefit_model_version=uk_latest, - policy=policy, - ) - sim.ensure() - - # Extract results - result = {"person": [], "benunit": [], "household": []} - - for i in range(n_people): - person_result = {} - for var in sim.output_dataset.person.columns: - val = 
sim.output_dataset.person[var].iloc[i] - person_result[var] = float(val) if hasattr(val, "item") else val - result["person"].append(person_result) - - for i in range(n_benunits): - benunit_result = {} - for var in sim.output_dataset.benunit.columns: - val = sim.output_dataset.benunit[var].iloc[i] - benunit_result[var] = float(val) if hasattr(val, "item") else val - result["benunit"].append(benunit_result) - - for i in range(n_households): - household_result = {} - for var in sim.output_dataset.household.columns: - val = sim.output_dataset.household[var].iloc[i] - household_result[var] = float(val) if hasattr(val, "item") else val - result["household"].append(household_result) - - return result - - -@app.function( - image=us_image, - secrets=[db_secrets, logfire_secrets], - memory=2048, - cpu=2, - timeout=300, -) -def household_impact_us(report_id: str, traceparent: str | None = None) -> None: - """Run US household impact analysis and write results to database.""" - import logfire - - configure_logfire("policyengine-modal-us", traceparent) - - try: - with logfire.span("household_impact_us", report_id=report_id): - from datetime import datetime, timezone - from uuid import UUID - - from sqlmodel import Session, create_engine - - database_url = get_database_url() - engine = create_engine(database_url) - - try: - from policyengine_api.models import ( - Household, - Report, - ReportStatus, - Simulation, - SimulationStatus, - ) - - with Session(engine) as session: - # Load report - report = session.get(Report, UUID(report_id)) - if not report: - raise ValueError(f"Report {report_id} not found") - - # Mark as running - report.status = ReportStatus.RUNNING - session.add(report) - session.commit() - - # Run baseline simulation - if report.baseline_simulation_id: - _run_household_simulation_us( - report.baseline_simulation_id, session - ) - - # Run reform simulation if present - if report.reform_simulation_id: - _run_household_simulation_us( - report.reform_simulation_id, 
session - ) - - # Mark report as completed - report.status = ReportStatus.COMPLETED - session.add(report) - session.commit() - - except Exception as e: - logfire.error( - "US household impact failed", report_id=report_id, error=str(e) - ) - try: - from sqlmodel import text - - with Session(engine) as session: - session.execute( - text( - "UPDATE reports SET status = 'FAILED', error_message = :error " - "WHERE id = :report_id" - ), - {"report_id": report_id, "error": str(e)[:1000]}, - ) - session.commit() - except Exception as db_error: - logfire.error("Failed to update DB", error=str(db_error)) - raise - finally: - logfire.force_flush() - - -def _run_household_simulation_us(simulation_id, session) -> None: - """Run a single US household simulation.""" - from datetime import datetime, timezone - - from policyengine_api.models import ( - Household, - Simulation, - SimulationStatus, - ) - - simulation = session.get(Simulation, simulation_id) - if not simulation or simulation.status != SimulationStatus.PENDING: - return - - household = session.get(Household, simulation.household_id) - if not household: - raise ValueError(f"Household {simulation.household_id} not found") - - # Mark as running - simulation.status = SimulationStatus.RUNNING - simulation.started_at = datetime.now(timezone.utc) - session.add(simulation) - session.commit() - - try: - # Get policy data if present - policy_data = _get_household_policy_data(simulation.policy_id, session) - - # Run calculation - result = _calculate_us_household( - household.household_data, - household.year, - policy_data, - ) - - # Store result - simulation.household_result = result - simulation.status = SimulationStatus.COMPLETED - simulation.completed_at = datetime.now(timezone.utc) - session.add(simulation) - session.commit() - except Exception as e: - simulation.status = SimulationStatus.FAILED - simulation.error_message = str(e) - simulation.completed_at = datetime.now(timezone.utc) - session.add(simulation) - 
session.commit() - raise - - -def _calculate_us_household( - household_data: dict, year: int, policy_data: dict | None -) -> dict: - """Calculate US household and return result dict.""" - import tempfile - from pathlib import Path - - import pandas as pd - from microdf import MicroDataFrame - from policyengine.core import Simulation - from policyengine.tax_benefit_models.us import us_latest - from policyengine.tax_benefit_models.us.datasets import ( - PolicyEngineUSDataset, - USYearData, - ) - - people = household_data.get("people", []) - tax_unit = household_data.get("tax_unit", []) - family = household_data.get("family", []) - spm_unit = household_data.get("spm_unit", []) - marital_unit = household_data.get("marital_unit", []) - hh = household_data.get("household", []) - - # Ensure lists - if isinstance(tax_unit, dict): - tax_unit = [tax_unit] - if isinstance(family, dict): - family = [family] - if isinstance(spm_unit, dict): - spm_unit = [spm_unit] - if isinstance(marital_unit, dict): - marital_unit = [marital_unit] - if isinstance(hh, dict): - hh = [hh] - - n_people = len(people) - n_tax_units = max(1, len(tax_unit) if tax_unit else 1) - n_families = max(1, len(family) if family else 1) - n_spm_units = max(1, len(spm_unit) if spm_unit else 1) - n_marital_units = max(1, len(marital_unit) if marital_unit else 1) - n_households = max(1, len(hh) if hh else 1) - - # Build person data - person_data = { - "person_id": list(range(n_people)), - "person_tax_unit_id": [0] * n_people, - "person_family_id": [0] * n_people, - "person_spm_unit_id": [0] * n_people, - "person_marital_unit_id": [0] * n_people, - "person_household_id": [0] * n_people, - "person_weight": [1.0] * n_people, - } - for i, person in enumerate(people): - for key, value in person.items(): - if key not in person_data: - person_data[key] = [0.0] * n_people - person_data[key][i] = value - - # Build tax_unit data - tax_unit_data = { - "tax_unit_id": list(range(n_tax_units)), - "tax_unit_weight": [1.0] * 
n_tax_units, - } - for i, tu in enumerate(tax_unit if tax_unit else [{}]): - for key, value in tu.items(): - if key not in tax_unit_data: - tax_unit_data[key] = [0.0] * n_tax_units - tax_unit_data[key][i] = value - - # Build family data - family_data = { - "family_id": list(range(n_families)), - "family_weight": [1.0] * n_families, - } - for i, fam in enumerate(family if family else [{}]): - for key, value in fam.items(): - if key not in family_data: - family_data[key] = [0.0] * n_families - family_data[key][i] = value - - # Build spm_unit data - spm_unit_data = { - "spm_unit_id": list(range(n_spm_units)), - "spm_unit_weight": [1.0] * n_spm_units, - } - for i, spm in enumerate(spm_unit if spm_unit else [{}]): - for key, value in spm.items(): - if key not in spm_unit_data: - spm_unit_data[key] = [0.0] * n_spm_units - spm_unit_data[key][i] = value - - # Build marital_unit data - marital_unit_data = { - "marital_unit_id": list(range(n_marital_units)), - "marital_unit_weight": [1.0] * n_marital_units, - } - for i, mu in enumerate(marital_unit if marital_unit else [{}]): - for key, value in mu.items(): - if key not in marital_unit_data: - marital_unit_data[key] = [0.0] * n_marital_units - marital_unit_data[key][i] = value - - # Build household data - household_df_data = { - "household_id": list(range(n_households)), - "household_weight": [1.0] * n_households, - } - for i, h in enumerate(hh if hh else [{}]): - for key, value in h.items(): - if key not in household_df_data: - household_df_data[key] = [0.0] * n_households - household_df_data[key][i] = value - - # Create MicroDataFrames - person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") - tax_unit_df = MicroDataFrame( - pd.DataFrame(tax_unit_data), weights="tax_unit_weight" - ) - family_df = MicroDataFrame(pd.DataFrame(family_data), weights="family_weight") - spm_unit_df = MicroDataFrame( - pd.DataFrame(spm_unit_data), weights="spm_unit_weight" - ) - marital_unit_df = MicroDataFrame( - 
pd.DataFrame(marital_unit_data), weights="marital_unit_weight" - ) - household_df = MicroDataFrame( - pd.DataFrame(household_df_data), weights="household_weight" - ) - - # Create temporary dataset - tmpdir = tempfile.mkdtemp() - filepath = str(Path(tmpdir) / "household_calc.h5") - - dataset = PolicyEngineUSDataset( - name="Household calculation", - description="Household(s) for calculation", - person=person_df, - tax_unit=tax_unit_df, - family=family_df, - spm_unit=spm_unit_df, - marital_unit=marital_unit_df, - household=household_df, - filepath=filepath, - year_data_class=USYearData, - ) - dataset.save() - - # Build policy if provided - policy = None - if policy_data: - from policyengine.core.policy import ParameterValue, Policy - - pe_param_values = [] - param_lookup = {p.name: p for p in us_latest.parameters} - for pv in policy_data.get("parameter_values", []): - param_name = pv.get("parameter_name") - if param_name and param_name in param_lookup: - pe_pv = ParameterValue( - parameter=param_lookup[param_name], - value=pv.get("value"), - start_date=pv.get("start_date"), - end_date=pv.get("end_date"), - ) - pe_param_values.append(pe_pv) - - if pe_param_values: - policy = Policy( - name=policy_data.get("name", "Reform"), - description=policy_data.get("description", ""), - parameter_values=pe_param_values, - ) - - # Run simulation - sim = Simulation( - dataset=dataset, - tax_benefit_model_version=us_latest, - policy=policy, - ) - sim.ensure() - - # Extract results - result = { - "person": [], - "tax_unit": [], - "family": [], - "spm_unit": [], - "marital_unit": [], - "household": [], - } - - for i in range(n_people): - person_result = {} - for var in sim.output_dataset.person.columns: - val = sim.output_dataset.person[var].iloc[i] - person_result[var] = float(val) if hasattr(val, "item") else val - result["person"].append(person_result) - - for i in range(n_tax_units): - tu_result = {} - for var in sim.output_dataset.tax_unit.columns: - val = 
sim.output_dataset.tax_unit[var].iloc[i] - tu_result[var] = float(val) if hasattr(val, "item") else val - result["tax_unit"].append(tu_result) - - for i in range(n_families): - fam_result = {} - for var in sim.output_dataset.family.columns: - val = sim.output_dataset.family[var].iloc[i] - fam_result[var] = float(val) if hasattr(val, "item") else val - result["family"].append(fam_result) - - for i in range(n_spm_units): - spm_result = {} - for var in sim.output_dataset.spm_unit.columns: - val = sim.output_dataset.spm_unit[var].iloc[i] - spm_result[var] = float(val) if hasattr(val, "item") else val - result["spm_unit"].append(spm_result) - - for i in range(n_marital_units): - mu_result = {} - for var in sim.output_dataset.marital_unit.columns: - val = sim.output_dataset.marital_unit[var].iloc[i] - mu_result[var] = float(val) if hasattr(val, "item") else val - result["marital_unit"].append(mu_result) - - for i in range(n_households): - hh_result = {} - for var in sim.output_dataset.household.columns: - val = sim.output_dataset.household[var].iloc[i] - hh_result[var] = float(val) if hasattr(val, "item") else val - result["household"].append(hh_result) - - return result - - -def _get_household_policy_data(policy_id, session) -> dict | None: - """Get policy data for household calculation.""" - if policy_id is None: - return None - - from policyengine_api.models import Policy - - db_policy = session.get(Policy, policy_id) - if not db_policy: - return None - - return { - "name": db_policy.name, - "description": db_policy.description, - "parameter_values": [ - { - "parameter_name": pv.parameter.name if pv.parameter else None, - "value": pv.value_json.get("value") - if isinstance(pv.value_json, dict) - else pv.value_json, - "start_date": pv.start_date.isoformat() if pv.start_date else None, - "end_date": pv.end_date.isoformat() if pv.end_date else None, - } - for pv in db_policy.parameter_values - if pv.parameter - ], - } diff --git a/test_fixtures/fixtures_policy_reform.py 
b/test_fixtures/fixtures_policy_reform.py deleted file mode 100644 index f7534a5..0000000 --- a/test_fixtures/fixtures_policy_reform.py +++ /dev/null @@ -1,282 +0,0 @@ -"""Fixtures for policy reform conversion tests.""" - -from dataclasses import dataclass -from datetime import date, datetime -from typing import Any - - -# ============================================================================= -# Mock objects for testing _pe_policy_to_reform_dict -# ============================================================================= - - -@dataclass -class MockParameter: - """Mock policyengine.core.models.parameter.Parameter.""" - - name: str - - -@dataclass -class MockParameterValue: - """Mock policyengine.core.models.parameter_value.ParameterValue.""" - - parameter: MockParameter | None - value: Any - start_date: date | datetime | str | None - - -@dataclass -class MockPolicy: - """Mock policyengine.core.policy.Policy.""" - - parameter_values: list[MockParameterValue] | None - - -# ============================================================================= -# Test data constants -# ============================================================================= - -# Simple policy with single parameter change -SIMPLE_POLICY = MockPolicy( - parameter_values=[ - MockParameterValue( - parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), - value=3000, - start_date=date(2024, 1, 1), - ) - ] -) - -SIMPLE_POLICY_EXPECTED = { - "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000} -} - -# Policy with multiple parameter changes -MULTI_PARAM_POLICY = MockPolicy( - parameter_values=[ - MockParameterValue( - parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), - value=3000, - start_date=date(2024, 1, 1), - ), - MockParameterValue( - parameter=MockParameter(name="gov.irs.credits.ctc.refundable.fully_refundable"), - value=True, - start_date=date(2024, 1, 1), - ), - MockParameterValue( - parameter=MockParameter(name="gov.irs.income.bracket.rates.1"), - 
value=0.12, - start_date=date(2024, 1, 1), - ), - ] -) - -MULTI_PARAM_POLICY_EXPECTED = { - "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, - "gov.irs.credits.ctc.refundable.fully_refundable": {"2024-01-01": True}, - "gov.irs.income.bracket.rates.1": {"2024-01-01": 0.12}, -} - -# Policy with same parameter at different dates -MULTI_DATE_POLICY = MockPolicy( - parameter_values=[ - MockParameterValue( - parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), - value=2500, - start_date=date(2024, 1, 1), - ), - MockParameterValue( - parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), - value=3000, - start_date=date(2025, 1, 1), - ), - ] -) - -MULTI_DATE_POLICY_EXPECTED = { - "gov.irs.credits.ctc.amount.base": { - "2024-01-01": 2500, - "2025-01-01": 3000, - } -} - -# Policy with datetime start_date (has time component) -DATETIME_POLICY = MockPolicy( - parameter_values=[ - MockParameterValue( - parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), - value=3000, - start_date=datetime(2024, 1, 1, 12, 30, 45), - ) - ] -) - -DATETIME_POLICY_EXPECTED = { - "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000} -} - -# Policy with ISO string start_date -ISO_STRING_POLICY = MockPolicy( - parameter_values=[ - MockParameterValue( - parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), - value=3000, - start_date="2024-01-01T00:00:00", - ) - ] -) - -ISO_STRING_POLICY_EXPECTED = { - "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000} -} - -# Empty policy (no parameter values) -EMPTY_POLICY = MockPolicy(parameter_values=[]) - -# None policy -NONE_POLICY = None - -# Policy with None parameter_values -NONE_PARAM_VALUES_POLICY = MockPolicy(parameter_values=None) - -# Policy with invalid entries (missing parameter or start_date) -INVALID_ENTRIES_POLICY = MockPolicy( - parameter_values=[ - MockParameterValue( - parameter=None, # Missing parameter - value=3000, - start_date=date(2024, 1, 1), - ), - MockParameterValue( - 
parameter=MockParameter(name="gov.irs.credits.ctc.amount.base"), - value=3000, - start_date=None, # Missing start_date - ), - MockParameterValue( - parameter=MockParameter(name="gov.irs.credits.eitc.max.0"), - value=600, - start_date=date(2024, 1, 1), # This one is valid - ), - ] -) - -INVALID_ENTRIES_POLICY_EXPECTED = { - "gov.irs.credits.eitc.max.0": {"2024-01-01": 600} -} - - -# ============================================================================= -# Test data for _merge_reform_dicts -# ============================================================================= - -REFORM_DICT_1 = { - "gov.irs.credits.ctc.amount.base": {"2024-01-01": 2000}, - "gov.irs.income.bracket.rates.1": {"2024-01-01": 0.10}, -} - -REFORM_DICT_2 = { - "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, # Overwrites - "gov.irs.credits.eitc.max.0": {"2024-01-01": 600}, # New param -} - -MERGED_EXPECTED = { - "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, # From reform2 - "gov.irs.income.bracket.rates.1": {"2024-01-01": 0.10}, # From reform1 - "gov.irs.credits.eitc.max.0": {"2024-01-01": 600}, # From reform2 -} - -REFORM_DICT_3 = { - "gov.irs.credits.ctc.amount.base": { - "2024-01-01": 2500, - "2025-01-01": 2700, - }, -} - -REFORM_DICT_4 = { - "gov.irs.credits.ctc.amount.base": { - "2025-01-01": 3000, # Overwrites 2025 date - "2026-01-01": 3500, # New date - }, -} - -MERGED_MULTI_DATE_EXPECTED = { - "gov.irs.credits.ctc.amount.base": { - "2024-01-01": 2500, # From reform3 - "2025-01-01": 3000, # From reform4 (overwrites) - "2026-01-01": 3500, # From reform4 (new) - }, -} - - -# ============================================================================= -# Test data for household calculation policy conversion -# ============================================================================= - -# Policy data as it comes from the API (stored in database) -HOUSEHOLD_POLICY_DATA = { - "parameter_values": [ - { - "parameter_name": "gov.irs.credits.ctc.amount.base", - "value": 
3000, - "start_date": "2024-01-01", - }, - { - "parameter_name": "gov.irs.credits.ctc.refundable.fully_refundable", - "value": True, - "start_date": "2024-01-01", - }, - ] -} - -HOUSEHOLD_POLICY_DATA_EXPECTED = { - "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, - "gov.irs.credits.ctc.refundable.fully_refundable": {"2024-01-01": True}, -} - -# Policy data with ISO datetime strings -HOUSEHOLD_POLICY_DATA_DATETIME = { - "parameter_values": [ - { - "parameter_name": "gov.irs.credits.ctc.amount.base", - "value": 3000, - "start_date": "2024-01-01T00:00:00.000Z", - }, - ] -} - -HOUSEHOLD_POLICY_DATA_DATETIME_EXPECTED = { - "gov.irs.credits.ctc.amount.base": {"2024-01-01": 3000}, -} - -# Empty policy data -HOUSEHOLD_EMPTY_POLICY_DATA = {"parameter_values": []} - -# None policy data -HOUSEHOLD_NONE_POLICY_DATA = None - -# Policy data with missing fields -HOUSEHOLD_INCOMPLETE_POLICY_DATA = { - "parameter_values": [ - { - "parameter_name": None, # Missing - "value": 3000, - "start_date": "2024-01-01", - }, - { - "parameter_name": "gov.irs.credits.ctc.amount.base", - "value": 3000, - "start_date": None, # Missing - }, - { - "parameter_name": "gov.irs.credits.eitc.max.0", - "value": 600, - "start_date": "2024-01-01", # Valid - }, - ] -} - -HOUSEHOLD_INCOMPLETE_POLICY_DATA_EXPECTED = { - "gov.irs.credits.eitc.max.0": {"2024-01-01": 600}, -} diff --git a/tests/test_policy_reform.py b/tests/test_policy_reform.py deleted file mode 100644 index cfee3b8..0000000 --- a/tests/test_policy_reform.py +++ /dev/null @@ -1,327 +0,0 @@ -"""Tests for policy reform conversion logic. - -Tests the helper functions that convert policy objects to reform dict format -for use with Microsimulation. These are critical for fixing the bug where -reforms weren't being applied to economy-wide and household simulations. 
-""" - -import sys -from unittest.mock import MagicMock - -import pytest - -# Mock modal before importing modal_app -sys.modules["modal"] = MagicMock() - -from test_fixtures.fixtures_policy_reform import ( - DATETIME_POLICY, - DATETIME_POLICY_EXPECTED, - EMPTY_POLICY, - HOUSEHOLD_EMPTY_POLICY_DATA, - HOUSEHOLD_INCOMPLETE_POLICY_DATA, - HOUSEHOLD_INCOMPLETE_POLICY_DATA_EXPECTED, - HOUSEHOLD_NONE_POLICY_DATA, - HOUSEHOLD_POLICY_DATA, - HOUSEHOLD_POLICY_DATA_DATETIME, - HOUSEHOLD_POLICY_DATA_DATETIME_EXPECTED, - HOUSEHOLD_POLICY_DATA_EXPECTED, - INVALID_ENTRIES_POLICY, - INVALID_ENTRIES_POLICY_EXPECTED, - ISO_STRING_POLICY, - ISO_STRING_POLICY_EXPECTED, - MERGED_EXPECTED, - MERGED_MULTI_DATE_EXPECTED, - MULTI_DATE_POLICY, - MULTI_DATE_POLICY_EXPECTED, - MULTI_PARAM_POLICY, - MULTI_PARAM_POLICY_EXPECTED, - NONE_PARAM_VALUES_POLICY, - NONE_POLICY, - REFORM_DICT_1, - REFORM_DICT_2, - REFORM_DICT_3, - REFORM_DICT_4, - SIMPLE_POLICY, - SIMPLE_POLICY_EXPECTED, -) - -# Import after mocking modal -from policyengine_api.modal_app import _merge_reform_dicts, _pe_policy_to_reform_dict - - -class TestPePolicyToReformDict: - """Tests for _pe_policy_to_reform_dict function.""" - - # ========================================================================= - # Given: Valid policy with single parameter - # ========================================================================= - - def test__given_simple_policy_with_date_object__then_returns_correct_reform_dict( - self, - ): - """Given a policy with a single parameter using date object, - then returns correctly formatted reform dict.""" - # When - result = _pe_policy_to_reform_dict(SIMPLE_POLICY) - - # Then - assert result == SIMPLE_POLICY_EXPECTED - - def test__given_policy_with_datetime_object__then_extracts_date_correctly(self): - """Given a policy with datetime start_date (has time component), - then extracts just the date part for the reform dict.""" - # When - result = _pe_policy_to_reform_dict(DATETIME_POLICY) - - # Then - 
assert result == DATETIME_POLICY_EXPECTED - - def test__given_policy_with_iso_string_date__then_parses_date_correctly(self): - """Given a policy with ISO string start_date, - then parses and extracts the date correctly.""" - # When - result = _pe_policy_to_reform_dict(ISO_STRING_POLICY) - - # Then - assert result == ISO_STRING_POLICY_EXPECTED - - # ========================================================================= - # Given: Policy with multiple parameters - # ========================================================================= - - def test__given_policy_with_multiple_parameters__then_includes_all_in_dict(self): - """Given a policy with multiple parameter changes, - then includes all parameters in the reform dict.""" - # When - result = _pe_policy_to_reform_dict(MULTI_PARAM_POLICY) - - # Then - assert result == MULTI_PARAM_POLICY_EXPECTED - - def test__given_policy_with_same_param_multiple_dates__then_includes_all_dates( - self, - ): - """Given a policy with the same parameter changed at different dates, - then includes all date entries for that parameter.""" - # When - result = _pe_policy_to_reform_dict(MULTI_DATE_POLICY) - - # Then - assert result == MULTI_DATE_POLICY_EXPECTED - - # ========================================================================= - # Given: Empty or None policy - # ========================================================================= - - def test__given_none_policy__then_returns_none(self): - """Given None as policy, - then returns None.""" - # When - result = _pe_policy_to_reform_dict(NONE_POLICY) - - # Then - assert result is None - - def test__given_policy_with_empty_parameter_values__then_returns_none(self): - """Given a policy with empty parameter_values list, - then returns None.""" - # When - result = _pe_policy_to_reform_dict(EMPTY_POLICY) - - # Then - assert result is None - - def test__given_policy_with_none_parameter_values__then_returns_none(self): - """Given a policy with parameter_values=None, - then returns 
None.""" - # When - result = _pe_policy_to_reform_dict(NONE_PARAM_VALUES_POLICY) - - # Then - assert result is None - - # ========================================================================= - # Given: Policy with invalid entries - # ========================================================================= - - def test__given_policy_with_invalid_entries__then_skips_invalid_keeps_valid(self): - """Given a policy with some invalid entries (missing parameter or date), - then skips invalid entries and keeps valid ones.""" - # When - result = _pe_policy_to_reform_dict(INVALID_ENTRIES_POLICY) - - # Then - assert result == INVALID_ENTRIES_POLICY_EXPECTED - - -class TestMergeReformDicts: - """Tests for _merge_reform_dicts function.""" - - # ========================================================================= - # Given: Two valid reform dicts - # ========================================================================= - - def test__given_two_reform_dicts__then_merges_with_second_taking_precedence(self): - """Given two reform dicts with overlapping parameters, - then merges them with the second dict taking precedence.""" - # When - result = _merge_reform_dicts(REFORM_DICT_1, REFORM_DICT_2) - - # Then - assert result == MERGED_EXPECTED - - def test__given_dicts_with_multiple_dates__then_merges_date_entries_correctly(self): - """Given reform dicts with same parameter at multiple dates, - then merges date entries correctly with second taking precedence.""" - # When - result = _merge_reform_dicts(REFORM_DICT_3, REFORM_DICT_4) - - # Then - assert result == MERGED_MULTI_DATE_EXPECTED - - # ========================================================================= - # Given: None values - # ========================================================================= - - def test__given_both_none__then_returns_none(self): - """Given both reform dicts are None, - then returns None.""" - # When - result = _merge_reform_dicts(None, None) - - # Then - assert result is None - - 
def test__given_first_none__then_returns_second(self): - """Given first reform dict is None, - then returns the second dict.""" - # When - result = _merge_reform_dicts(None, REFORM_DICT_1) - - # Then - assert result == REFORM_DICT_1 - - def test__given_second_none__then_returns_first(self): - """Given second reform dict is None, - then returns the first dict.""" - # When - result = _merge_reform_dicts(REFORM_DICT_1, None) - - # Then - assert result == REFORM_DICT_1 - - # ========================================================================= - # Given: Original dict should not be mutated - # ========================================================================= - - def test__given_two_dicts__then_does_not_mutate_original_dicts(self): - """Given two reform dicts, - then merging does not mutate the original dicts.""" - # Given - original_dict1 = {"param.a": {"2024-01-01": 100}} - original_dict2 = {"param.b": {"2024-01-01": 200}} - dict1_copy = dict(original_dict1) - dict2_copy = dict(original_dict2) - - # When - _merge_reform_dicts(original_dict1, original_dict2) - - # Then - assert original_dict1 == dict1_copy - assert original_dict2 == dict2_copy - - -class TestHouseholdPolicyDataConversion: - """Tests for the policy data conversion logic used in household calculations. - - This tests the conversion logic as it appears in _calculate_household_us - and _calculate_household_uk functions. - """ - - def _convert_policy_data_to_reform(self, policy_data: dict | None) -> dict | None: - """Convert policy_data (from API) to reform dict format. - - This mirrors the conversion logic in _calculate_household_us. 
- """ - if not policy_data or not policy_data.get("parameter_values"): - return None - - reform = {} - for pv in policy_data["parameter_values"]: - param_name = pv.get("parameter_name") - value = pv.get("value") - start_date = pv.get("start_date") - - if param_name and start_date: - # Parse ISO date string to get just the date part - if "T" in start_date: - date_str = start_date.split("T")[0] - else: - date_str = start_date - - if param_name not in reform: - reform[param_name] = {} - reform[param_name][date_str] = value - - return reform if reform else None - - # ========================================================================= - # Given: Valid policy data from API - # ========================================================================= - - def test__given_valid_policy_data__then_converts_to_reform_dict(self): - """Given valid policy data from the API, - then converts it to the correct reform dict format.""" - # When - result = self._convert_policy_data_to_reform(HOUSEHOLD_POLICY_DATA) - - # Then - assert result == HOUSEHOLD_POLICY_DATA_EXPECTED - - def test__given_policy_data_with_datetime_strings__then_extracts_date_part(self): - """Given policy data with ISO datetime strings (with T and timezone), - then extracts just the date part.""" - # When - result = self._convert_policy_data_to_reform(HOUSEHOLD_POLICY_DATA_DATETIME) - - # Then - assert result == HOUSEHOLD_POLICY_DATA_DATETIME_EXPECTED - - # ========================================================================= - # Given: Empty or None policy data - # ========================================================================= - - def test__given_none_policy_data__then_returns_none(self): - """Given None policy data, - then returns None.""" - # When - result = self._convert_policy_data_to_reform(HOUSEHOLD_NONE_POLICY_DATA) - - # Then - assert result is None - - def test__given_empty_parameter_values__then_returns_none(self): - """Given policy data with empty parameter_values list, - then 
returns None.""" - # When - result = self._convert_policy_data_to_reform(HOUSEHOLD_EMPTY_POLICY_DATA) - - # Then - assert result is None - - # ========================================================================= - # Given: Incomplete policy data - # ========================================================================= - - def test__given_incomplete_entries__then_skips_invalid_keeps_valid(self): - """Given policy data with some entries missing required fields, - then skips invalid entries and keeps valid ones.""" - # When - result = self._convert_policy_data_to_reform(HOUSEHOLD_INCOMPLETE_POLICY_DATA) - - # Then - assert result == HOUSEHOLD_INCOMPLETE_POLICY_DATA_EXPECTED - - -if __name__ == "__main__": - pytest.main([__file__, "-v"]) From 6e9ca4121b3f295a5263bca040863bb1d51ec48f Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Wed, 11 Feb 2026 23:12:13 +0100 Subject: [PATCH 16/19] test: Add tests for US policy reform application MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add tests to verify that US policy reforms are applied correctly to household calculations. These tests cover: - Integration tests via API endpoints (TestUSPolicyReform, TestUKPolicyReform) - Unit tests for the calculation functions directly (test_household_calculation.py) The tests verify: 1. Baseline calculations work correctly 2. Reforms change household net income as expected 3. 
Running a reform doesn't pollute subsequent baseline calculations (regression test for the singleton pollution bug) šŸ¤– Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- tests/test_household.py | 197 ++++++++++++++++++++++++++++ tests/test_household_calculation.py | 128 ++++++++++++++++++ 2 files changed, 325 insertions(+) create mode 100644 tests/test_household_calculation.py diff --git a/tests/test_household.py b/tests/test_household.py index a7248b3..eab15a5 100644 --- a/tests/test_household.py +++ b/tests/test_household.py @@ -289,5 +289,202 @@ def test_missing_people(self): assert response.status_code == 422 +class TestUSPolicyReform: + """Tests for US household calculations with policy reforms.""" + + def _get_us_model_id(self) -> str: + """Get the US tax benefit model ID.""" + response = client.get("/tax-benefit-models/") + assert response.status_code == 200 + models = response.json() + for model in models: + if "us" in model["name"].lower(): + return model["id"] + raise AssertionError("US model not found") + + def _get_parameter_id(self, model_id: str, param_name: str) -> str: + """Get a parameter ID by name.""" + response = client.get( + f"/parameters/?tax_benefit_model_id={model_id}&limit=10000" + ) + assert response.status_code == 200 + params = response.json() + for param in params: + if param["name"] == param_name: + return param["id"] + raise AssertionError(f"Parameter {param_name} not found") + + def _create_policy(self, param_id: str, value: float) -> str: + """Create a policy with a parameter value.""" + response = client.post( + "/policies/", + json={ + "name": "Test Reform", + "description": "Test reform for household calculation", + "parameter_values": [ + { + "parameter_id": param_id, + "value_json": value, + "start_date": "2024-01-01T00:00:00Z", + } + ], + }, + ) + assert response.status_code == 200 + return response.json()["id"] + + def test_us_reform_changes_household_net_income(self): + 
"""Test that a US policy reform changes household net income. + + This test verifies the fix for the US reform application bug where + reforms were not being applied correctly due to the shared singleton + TaxBenefitSystem in policyengine-us. + """ + # Get the US model and a UBI parameter + model_id = self._get_us_model_id() + param_name = "gov.contrib.ubi_center.basic_income.amount.person.by_age[3].amount" + param_id = self._get_parameter_id(model_id, param_name) + + # Create a policy with $1000 UBI for older adults + policy_id = self._create_policy(param_id, 1000) + + # Run baseline calculation (no policy) + baseline_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [{"age": 40, "employment_income": 70000}], + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024, + }, + ) + assert baseline_response.status_code == 200 + baseline_data = _poll_job(baseline_response.json()["job_id"]) + baseline_net_income = baseline_data["result"]["household"][0][ + "household_net_income" + ] + + # Run reform calculation (with UBI policy) + reform_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [{"age": 40, "employment_income": 70000}], + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024, + "policy_id": policy_id, + }, + ) + assert reform_response.status_code == 200 + reform_data = _poll_job(reform_response.json()["job_id"]) + reform_net_income = reform_data["result"]["household"][0][ + "household_net_income" + ] + + # Verify the reform increased net income by approximately $1000 + difference = reform_net_income - baseline_net_income + assert abs(difference - 1000) < 1, ( + f"Expected ~$1000 difference, got ${difference:.2f}. 
" + f"Baseline: ${baseline_net_income:.2f}, Reform: ${reform_net_income:.2f}" + ) + + +class TestUKPolicyReform: + """Tests for UK household calculations with policy reforms.""" + + def _get_uk_model_id(self) -> str | None: + """Get the UK tax benefit model ID, or None if not seeded.""" + response = client.get("/tax-benefit-models/") + assert response.status_code == 200 + models = response.json() + for model in models: + if "uk" in model["name"].lower(): + return model["id"] + return None + + def _get_parameter_id(self, model_id: str, param_name: str) -> str: + """Get a parameter ID by name.""" + response = client.get( + f"/parameters/?tax_benefit_model_id={model_id}&limit=10000" + ) + assert response.status_code == 200 + params = response.json() + for param in params: + if param["name"] == param_name: + return param["id"] + raise AssertionError(f"Parameter {param_name} not found") + + def _create_policy(self, param_id: str, value: float) -> str: + """Create a policy with a parameter value.""" + response = client.post( + "/policies/", + json={ + "name": "Test UK Reform", + "description": "Test reform for UK household calculation", + "parameter_values": [ + { + "parameter_id": param_id, + "value_json": value, + "start_date": "2026-01-01T00:00:00Z", + } + ], + }, + ) + assert response.status_code == 200 + return response.json()["id"] + + def test_uk_reform_changes_household_net_income(self): + """Test that a UK policy reform changes household net income.""" + # Get the UK model and a UBI parameter + model_id = self._get_uk_model_id() + if model_id is None: + pytest.skip("UK model not seeded in database") + param_name = "gov.contrib.ubi_center.basic_income.adult" + param_id = self._get_parameter_id(model_id, param_name) + + # Create a policy with Ā£1000 UBI for adults + policy_id = self._create_policy(param_id, 1000) + + # Run baseline calculation (no policy) + baseline_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": 
"policyengine_uk", + "people": [{"age": 30, "employment_income": 30000}], + "year": 2026, + }, + ) + assert baseline_response.status_code == 200 + baseline_data = _poll_job(baseline_response.json()["job_id"]) + baseline_net_income = baseline_data["result"]["household"][0][ + "household_net_income" + ] + + # Run reform calculation (with UBI policy) + reform_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 30000}], + "year": 2026, + "policy_id": policy_id, + }, + ) + assert reform_response.status_code == 200 + reform_data = _poll_job(reform_response.json()["job_id"]) + reform_net_income = reform_data["result"]["household"][0][ + "household_net_income" + ] + + # Verify the reform increased net income + difference = reform_net_income - baseline_net_income + assert difference > 0, ( + f"Expected positive difference, got Ā£{difference:.2f}. " + f"Baseline: Ā£{baseline_net_income:.2f}, Reform: Ā£{reform_net_income:.2f}" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_household_calculation.py b/tests/test_household_calculation.py new file mode 100644 index 0000000..e4fc2a5 --- /dev/null +++ b/tests/test_household_calculation.py @@ -0,0 +1,128 @@ +"""Unit tests for household calculation functions. + +These tests verify that the calculation functions work correctly with policy reforms, +without requiring database setup or API calls. 
+""" + +import pytest + +from policyengine_api.api.household import _calculate_household_us + + +class TestUSHouseholdCalculation: + """Unit tests for US household calculation with policy reforms.""" + + @pytest.mark.slow + def test_baseline_calculation(self): + """Test basic US household calculation without policy.""" + result = _calculate_household_us( + people=[{"employment_income": 70000, "age": 40}], + marital_unit=[], + family=[], + spm_unit=[], + tax_unit=[{"state_code": "CA"}], + household=[{"state_fips": 6}], + year=2024, + policy_data=None, + ) + + assert "person" in result + assert "household" in result + assert "tax_unit" in result + assert len(result["person"]) == 1 + assert result["tax_unit"][0]["income_tax"] > 0 + + @pytest.mark.slow + def test_reform_changes_net_income(self): + """Test that a US policy reform changes household net income. + + This test verifies the fix for the US reform application bug where + reforms were not being applied correctly due to the shared singleton + TaxBenefitSystem in policyengine-us. 
+ """ + household_args = { + "people": [{"employment_income": 70000, "age": 40}], + "marital_unit": [], + "family": [], + "spm_unit": [], + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024, + } + + # Calculate baseline (no policy) + baseline = _calculate_household_us(**household_args, policy_data=None) + baseline_net_income = baseline["household"][0]["household_net_income"] + + # Calculate with $1000 UBI reform + policy_data = { + "name": "Test UBI", + "description": "Test UBI reform", + "parameter_values": [ + { + "parameter_name": "gov.contrib.ubi_center.basic_income.amount.person.by_age[3].amount", + "value": 1000, + "start_date": "2024-01-01T00:00:00", + "end_date": None, + } + ], + } + reform = _calculate_household_us(**household_args, policy_data=policy_data) + reform_net_income = reform["household"][0]["household_net_income"] + + # Verify the reform increased net income by exactly $1000 + difference = reform_net_income - baseline_net_income + assert abs(difference - 1000) < 1, ( + f"Expected ~$1000 difference, got ${difference:.2f}. " + f"Baseline: ${baseline_net_income:.2f}, Reform: ${reform_net_income:.2f}" + ) + + @pytest.mark.slow + def test_reform_does_not_affect_baseline(self): + """Test that running reform doesn't pollute baseline calculations. + + This is a regression test for the singleton pollution bug where running + a reform calculation would affect subsequent baseline calculations. 
+ """ + household_args = { + "people": [{"employment_income": 70000, "age": 40}], + "marital_unit": [], + "family": [], + "spm_unit": [], + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024, + } + + # First baseline + baseline1 = _calculate_household_us(**household_args, policy_data=None) + baseline1_net_income = baseline1["household"][0]["household_net_income"] + + # Run reform + policy_data = { + "name": "Test UBI", + "description": "Test UBI reform", + "parameter_values": [ + { + "parameter_name": "gov.contrib.ubi_center.basic_income.amount.person.by_age[3].amount", + "value": 5000, + "start_date": "2024-01-01T00:00:00", + "end_date": None, + } + ], + } + _calculate_household_us(**household_args, policy_data=policy_data) + + # Second baseline - should be same as first + baseline2 = _calculate_household_us(**household_args, policy_data=None) + baseline2_net_income = baseline2["household"][0]["household_net_income"] + + # Verify baselines are identical + assert abs(baseline1_net_income - baseline2_net_income) < 0.01, ( + f"Baseline changed after reform calculation! " + f"Before: ${baseline1_net_income:.2f}, After: ${baseline2_net_income:.2f}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 43135e26c25f348e5a82bf4688da768eea0fe4a6 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Mon, 16 Feb 2026 23:26:04 +0100 Subject: [PATCH 17/19] fix: Install policyengine.py from app-v2-migration branch The PyPI release (3.1.15) has a bug where US reforms silently fail due to the shared singleton TaxBenefitSystem (policyengine.py#232). The fix exists on the app-v2-migration branch but hasn't been released. 
Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1fe9093..91ec88a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,11 @@ dependencies = [ "psycopg2-binary>=2.9.10", "supabase>=2.10.0", "storage3>=0.8.1", - "policyengine>=3.1.15", + # IMPORTANT: Before merging app-v2-migration into main, replace this git + # dependency with the production PyPI version of policyengine (e.g., "policyengine>=X.Y.Z"). + # The git ref is used here because the app-v2-migration branch contains fixes + # (US reform application, regions support) not yet released to PyPI. + "policyengine @ git+https://github.com/PolicyEngine/policyengine.py.git@app-v2-migration", "policyengine-uk>=2.0.0", "policyengine-us>=1.0.0", "pydantic>=2.9.2", From 6e77d4d8f224970a12646e9921c74bc162a9225b Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 17 Feb 2026 00:12:04 +0100 Subject: [PATCH 18/19] fix: Allow direct references in hatch metadata Hatchling rejects git URL dependencies unless explicitly opted in. Co-Authored-By: Claude Opus 4.6 --- pyproject.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 91ec88a..d624a6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,9 @@ dev = [ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.metadata] +allow-direct-references = true + [tool.hatch.build.targets.wheel] packages = ["src/policyengine_api"] From 754126e06d1a7b3d4c414121cc41d610210d51a0 Mon Sep 17 00:00:00 2001 From: Anthony Volk Date: Tue, 17 Feb 2026 00:33:54 +0100 Subject: [PATCH 19/19] chore: Remove test scripts, Nevada seed, and archived Supabase migrations Test scripts in scripts/ were ad-hoc debugging aids, not part of the test suite. Nevada seed is no longer needed. Archived Supabase migrations are superseded by Alembic. 
Co-Authored-By: Claude Opus 4.6 --- scripts/seed_nevada.py | 128 ------- scripts/test_economy_simulation.py | 277 -------------- scripts/test_household_impact.py | 135 ------- scripts/test_household_scenarios.py | 344 ------------------ ...229000000_add_parameter_values_indexes.sql | 16 - .../20260103000000_add_poverty_inequality.sql | 33 -- .../20260111000000_add_aggregate_status.sql | 13 - .../20260203000000_create_households.sql | 14 - ...001_create_user_household_associations.sql | 14 - ...203000002_simulation_household_support.sql | 16 - 10 files changed, 990 deletions(-) delete mode 100644 scripts/seed_nevada.py delete mode 100644 scripts/test_economy_simulation.py delete mode 100644 scripts/test_household_impact.py delete mode 100644 scripts/test_household_scenarios.py delete mode 100644 supabase/migrations_archived/20251229000000_add_parameter_values_indexes.sql delete mode 100644 supabase/migrations_archived/20260103000000_add_poverty_inequality.sql delete mode 100644 supabase/migrations_archived/20260111000000_add_aggregate_status.sql delete mode 100644 supabase/migrations_archived/20260203000000_create_households.sql delete mode 100644 supabase/migrations_archived/20260203000001_create_user_household_associations.sql delete mode 100644 supabase/migrations_archived/20260203000002_simulation_household_support.sql diff --git a/scripts/seed_nevada.py b/scripts/seed_nevada.py deleted file mode 100644 index 0af2cb4..0000000 --- a/scripts/seed_nevada.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Seed Nevada datasets into local Supabase. - -This script seeds pre-created Nevada state and congressional district datasets -into the local Supabase database for testing purposes. 
- -Usage: - uv run python scripts/seed_nevada.py -""" - -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).parent.parent / "src")) - -from rich.console import Console -from sqlmodel import Session, create_engine, select - -from policyengine_api.config.settings import settings -from policyengine_api.models import Dataset, TaxBenefitModel -from policyengine_api.services.storage import upload_dataset_for_seeding - -console = Console() - -# Nevada datasets location -NEVADA_DATA_DIR = Path(__file__).parent.parent / "test_data" / "nevada_datasets" - - -def main(): - """Seed Nevada datasets.""" - console.print("[bold blue]Seeding Nevada datasets for testing...") - - engine = create_engine(settings.database_url, echo=False) - - with Session(engine) as session: - # Get or create US model - us_model = session.exec( - select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") - ).first() - - if not us_model: - console.print(" Creating US tax-benefit model...") - us_model = TaxBenefitModel( - name="policyengine-us", - description="US tax-benefit system model", - ) - session.add(us_model) - session.commit() - session.refresh(us_model) - console.print(" [green]āœ“[/green] Created policyengine-us model") - - # Seed state datasets - states_dir = NEVADA_DATA_DIR / "states" - if states_dir.exists(): - console.print("\n [bold]Nevada State Datasets:[/bold]") - for h5_file in sorted(states_dir.glob("*.h5")): - name = h5_file.stem # e.g., "NV_year_2024" - year = int(name.split("_")[-1]) - - # Check if already exists - existing = session.exec( - select(Dataset).where(Dataset.name == name) - ).first() - - if existing: - console.print(f" [yellow]ā­[/yellow] {name} (already exists)") - continue - - # Upload to storage - console.print(f" Uploading {name}...", end=" ") - try: - object_name = upload_dataset_for_seeding(str(h5_file)) - - # Create database record - db_dataset = Dataset( - name=name, - description=f"Nevada state dataset for year {year}", 
- filepath=object_name, - year=year, - tax_benefit_model_id=us_model.id, - ) - session.add(db_dataset) - session.commit() - console.print("[green]āœ“[/green]") - except Exception as e: - console.print(f"[red]āœ— {e}[/red]") - - # Seed district datasets - districts_dir = NEVADA_DATA_DIR / "districts" - if districts_dir.exists(): - console.print("\n [bold]Nevada Congressional District Datasets:[/bold]") - for h5_file in sorted(districts_dir.glob("*.h5")): - name = h5_file.stem # e.g., "NV-01_year_2024" - year = int(name.split("_")[-1]) - district = name.split("_")[0] # e.g., "NV-01" - - # Check if already exists - existing = session.exec( - select(Dataset).where(Dataset.name == name) - ).first() - - if existing: - console.print(f" [yellow]ā­[/yellow] {name} (already exists)") - continue - - # Upload to storage - console.print(f" Uploading {name}...", end=" ") - try: - object_name = upload_dataset_for_seeding(str(h5_file)) - - # Create database record - db_dataset = Dataset( - name=name, - description=f"{district} congressional district dataset for year {year}", - filepath=object_name, - year=year, - tax_benefit_model_id=us_model.id, - ) - session.add(db_dataset) - session.commit() - console.print("[green]āœ“[/green]") - except Exception as e: - console.print(f"[red]āœ— {e}[/red]") - - console.print("\n[bold green]āœ“ Nevada datasets seeded successfully![/bold green]") - - -if __name__ == "__main__": - main() diff --git a/scripts/test_economy_simulation.py b/scripts/test_economy_simulation.py deleted file mode 100644 index 3845fc4..0000000 --- a/scripts/test_economy_simulation.py +++ /dev/null @@ -1,277 +0,0 @@ -"""Test economy-wide simulation following the exact flow from modal_app.py. - -This script mimics the economy-wide simulation code path as closely as possible -to verify whether policy reforms are being applied correctly. 
-""" - -import tempfile -from datetime import datetime -from pathlib import Path - -import pandas as pd -from microdf import MicroDataFrame - -# Import exactly as modal_app.py does -from policyengine.core import Simulation as PESimulation -from policyengine.core.policy import ParameterValue as PEParameterValue -from policyengine.core.policy import Policy as PEPolicy -from policyengine.tax_benefit_models.us import us_latest -from policyengine.tax_benefit_models.us.datasets import PolicyEngineUSDataset, USYearData - - -def create_test_dataset(year: int) -> PolicyEngineUSDataset: - """Create a small test dataset similar to what would be loaded from storage. - - Uses the same structure as economy-wide datasets but with just a few households. - """ - # Create 3 test households with different income levels - # Each household has 2 adults + 2 children (to test CTC) - n_households = 3 - n_people = n_households * 4 # 4 people per household - - # Person data - person_data = { - "person_id": list(range(n_people)), - "person_household_id": [i // 4 for i in range(n_people)], - "person_marital_unit_id": [], - "person_family_id": [i // 4 for i in range(n_people)], - "person_spm_unit_id": [i // 4 for i in range(n_people)], - "person_tax_unit_id": [i // 4 for i in range(n_people)], - "person_weight": [1000.0] * n_people, # Weight for population scaling - "age": [], - "employment_income": [], - } - - # Build person details - marital_unit_counter = 0 - for hh in range(n_households): - base_income = 10000 + (hh * 20000) # 10k, 30k, 50k - # Adult 1 - person_data["age"].append(35) - person_data["employment_income"].append(base_income) - person_data["person_marital_unit_id"].append(marital_unit_counter) - # Adult 2 - person_data["age"].append(33) - person_data["employment_income"].append(0) - person_data["person_marital_unit_id"].append(marital_unit_counter) - marital_unit_counter += 1 - # Child 1 - person_data["age"].append(5) - person_data["employment_income"].append(0) - 
person_data["person_marital_unit_id"].append(marital_unit_counter) - marital_unit_counter += 1 - # Child 2 - person_data["age"].append(3) - person_data["employment_income"].append(0) - person_data["person_marital_unit_id"].append(marital_unit_counter) - marital_unit_counter += 1 - - n_marital_units = marital_unit_counter - - # Entity data - household_data = { - "household_id": list(range(n_households)), - "household_weight": [1000.0] * n_households, - "state_fips": [48] * n_households, # Texas - } - - marital_unit_data = { - "marital_unit_id": list(range(n_marital_units)), - "marital_unit_weight": [1000.0] * n_marital_units, - } - - family_data = { - "family_id": list(range(n_households)), - "family_weight": [1000.0] * n_households, - } - - spm_unit_data = { - "spm_unit_id": list(range(n_households)), - "spm_unit_weight": [1000.0] * n_households, - } - - tax_unit_data = { - "tax_unit_id": list(range(n_households)), - "tax_unit_weight": [1000.0] * n_households, - } - - # Create MicroDataFrames (same as economy datasets) - person_df = MicroDataFrame(pd.DataFrame(person_data), weights="person_weight") - household_df = MicroDataFrame(pd.DataFrame(household_data), weights="household_weight") - marital_unit_df = MicroDataFrame(pd.DataFrame(marital_unit_data), weights="marital_unit_weight") - family_df = MicroDataFrame(pd.DataFrame(family_data), weights="family_weight") - spm_unit_df = MicroDataFrame(pd.DataFrame(spm_unit_data), weights="spm_unit_weight") - tax_unit_df = MicroDataFrame(pd.DataFrame(tax_unit_data), weights="tax_unit_weight") - - # Create dataset file - tmpdir = tempfile.mkdtemp() - filepath = str(Path(tmpdir) / "test_economy.h5") - - return PolicyEngineUSDataset( - name="Test Economy Dataset", - description="Small test dataset for economy simulation", - filepath=filepath, - year=year, - data=USYearData( - person=person_df, - household=household_df, - marital_unit=marital_unit_df, - family=family_df, - spm_unit=spm_unit_df, - tax_unit=tax_unit_df, - ), - ) 
- - -def create_policy_like_modal_app(model_version) -> PEPolicy: - """Create a policy exactly like _get_pe_policy_us does in modal_app.py. - - This mimics the exact flow: - 1. Look up parameter by name from model_version.parameters - 2. Create PEParameterValue with the parameter, value, start_date, end_date - 3. Create PEPolicy with the parameter values - """ - param_lookup = {p.name: p for p in model_version.parameters} - - # This is exactly what _get_pe_policy_us does - pe_param = param_lookup.get("gov.irs.credits.ctc.refundable.fully_refundable") - if not pe_param: - raise ValueError("Parameter not found!") - - pe_pv = PEParameterValue( - parameter=pe_param, - value=True, # Make CTC fully refundable - start_date=datetime(2024, 1, 1), - end_date=None, - ) - - return PEPolicy( - name="CTC Fully Refundable", - description="Makes CTC fully refundable", - parameter_values=[pe_pv], - ) - - -def run_economy_simulation(dataset: PolicyEngineUSDataset, policy: PEPolicy | None, label: str) -> dict: - """Run an economy simulation exactly like modal_app.py does. - - This follows the exact flow from simulate_economy_us: - 1. Create PESimulation with dataset, model version, policy, dynamic - 2. Call pe_sim.ensure() (which calls run() internally) - 3. 
Access output via pe_sim.output_dataset - """ - print(f"\n=== {label} ===") - print(f" Policy: {policy.name if policy else 'None (baseline)'}") - if policy: - print(f" Policy parameter_values: {len(policy.parameter_values)}") - for pv in policy.parameter_values: - print(f" - {pv.parameter.name}: {pv.value} (start: {pv.start_date})") - - pe_model_version = us_latest - - # Create and run simulation - EXACTLY like modal_app.py lines 1006-1012 - pe_sim = PESimulation( - dataset=dataset, - tax_benefit_model_version=pe_model_version, - policy=policy, - dynamic=None, - ) - pe_sim.ensure() - - # Extract results from output dataset - output_data = pe_sim.output_dataset.data - - # Sum up key metrics across all tax units (weighted) - tax_unit_df = pd.DataFrame(output_data.tax_unit) - - # Get the variables we care about - total_ctc = 0 - total_income_tax = 0 - total_eitc = 0 - - for var in ["ctc", "income_tax", "eitc"]: - if var in tax_unit_df.columns: - # Weighted sum - weights = tax_unit_df.get("tax_unit_weight", pd.Series([1.0] * len(tax_unit_df))) - if var == "ctc": - total_ctc = (tax_unit_df[var] * weights).sum() - elif var == "income_tax": - total_income_tax = (tax_unit_df[var] * weights).sum() - elif var == "eitc": - total_eitc = (tax_unit_df[var] * weights).sum() - - print(f" Results (weighted totals across {len(tax_unit_df)} tax units):") - print(f" Total CTC: ${total_ctc:,.0f}") - print(f" Total Income Tax: ${total_income_tax:,.0f}") - print(f" Total EITC: ${total_eitc:,.0f}") - - # Also show per-household breakdown - print(f" Per tax unit breakdown:") - for i in range(len(tax_unit_df)): - ctc = tax_unit_df["ctc"].iloc[i] if "ctc" in tax_unit_df.columns else 0 - income_tax = tax_unit_df["income_tax"].iloc[i] if "income_tax" in tax_unit_df.columns else 0 - print(f" Tax Unit {i}: CTC=${ctc:,.0f}, Income Tax=${income_tax:,.0f}") - - return { - "total_ctc": total_ctc, - "total_income_tax": total_income_tax, - "total_eitc": total_eitc, - "tax_unit_df": tax_unit_df, - } - 
- -def main(): - print("=" * 60) - print("ECONOMY-WIDE SIMULATION TEST") - print("Following the exact code path from modal_app.py") - print("=" * 60) - - year = 2024 - - # Create test dataset (same for both simulations) - print("\nCreating test dataset...") - - # Run baseline simulation - baseline_dataset = create_test_dataset(year) - baseline_results = run_economy_simulation(baseline_dataset, None, "BASELINE (no policy)") - - # Create policy exactly like modal_app.py does - policy = create_policy_like_modal_app(us_latest) - - # Run reform simulation - reform_dataset = create_test_dataset(year) - reform_results = run_economy_simulation(reform_dataset, policy, "REFORM (CTC fully refundable)") - - # Compare results - print("\n" + "=" * 60) - print("COMPARISON") - print("=" * 60) - - ctc_diff = reform_results["total_ctc"] - baseline_results["total_ctc"] - tax_diff = reform_results["total_income_tax"] - baseline_results["total_income_tax"] - - print(f"\nTotal CTC:") - print(f" Baseline: ${baseline_results['total_ctc']:,.0f}") - print(f" Reform: ${reform_results['total_ctc']:,.0f}") - print(f" Change: ${ctc_diff:,.0f}") - - print(f"\nTotal Income Tax:") - print(f" Baseline: ${baseline_results['total_income_tax']:,.0f}") - print(f" Reform: ${reform_results['total_income_tax']:,.0f}") - print(f" Change: ${tax_diff:,.0f}") - - # Verdict - print("\n" + "=" * 60) - print("VERDICT") - print("=" * 60) - - if baseline_results["total_income_tax"] == reform_results["total_income_tax"]: - print("\nāŒ BUG CONFIRMED: Results are IDENTICAL!") - print(" The policy reform is NOT being applied to economy simulations.") - else: - print("\nāœ“ NO BUG: Results differ as expected!") - print(f" The fully refundable CTC reform changed income tax by ${tax_diff:,.0f}") - - -if __name__ == "__main__": - main() diff --git a/scripts/test_household_impact.py b/scripts/test_household_impact.py deleted file mode 100644 index 81c85b0..0000000 --- a/scripts/test_household_impact.py +++ /dev/null @@ 
-1,135 +0,0 @@ -"""Test household impact analysis end-to-end. - -This script tests the async household impact analysis workflow: -1. Create a stored household -2. Run household impact analysis (returns immediately with report_id) -3. Poll until completed -4. Verify results - -Usage: - uv run python scripts/test_household_impact.py -""" - -import sys -import time - -import requests - -BASE_URL = "http://127.0.0.1:8000" - - -def main(): - print("=" * 60) - print("Testing Household Impact Analysis (Async)") - print("=" * 60) - - # Step 1: Create a US household - print("\n1. Creating US household...") - household_data = { - "tax_benefit_model_name": "policyengine_us", - "year": 2024, - "label": "Test household for impact analysis", - "people": [ - { - "age": 35, - "employment_income": 50000, - } - ], - "tax_unit": {}, - "family": {}, - "spm_unit": {}, - "marital_unit": {}, - "household": {"state_code": "NV"}, - } - - resp = requests.post(f"{BASE_URL}/households/", json=household_data) - if resp.status_code != 201: - print(f" FAILED: {resp.status_code} - {resp.text}") - sys.exit(1) - - household = resp.json() - household_id = household["id"] - print(f" Created household: {household_id}") - - # Step 2: Run household impact analysis - print("\n2. Starting household impact analysis...") - impact_request = { - "household_id": household_id, - "policy_id": None, # Single run under current law - } - - resp = requests.post(f"{BASE_URL}/analysis/household-impact", json=impact_request) - if resp.status_code != 200: - print(f" FAILED: {resp.status_code} - {resp.text}") - sys.exit(1) - - result = resp.json() - report_id = result["report_id"] - status = result["status"] - print(f" Report ID: {report_id}") - print(f" Initial status: {status}") - - # Step 3: Poll until completed - print("\n3. 
Polling for results...") - max_attempts = 30 - for attempt in range(max_attempts): - resp = requests.get(f"{BASE_URL}/analysis/household-impact/{report_id}") - if resp.status_code != 200: - print(f" FAILED: {resp.status_code} - {resp.text}") - sys.exit(1) - - result = resp.json() - status = result["status"].upper() # Normalize to uppercase - print(f" Attempt {attempt + 1}: status={status}") - - if status == "COMPLETED": - break - elif status == "FAILED": - print(f" FAILED: {result.get('error_message', 'Unknown error')}") - sys.exit(1) - - time.sleep(0.5) - else: - print(f" FAILED: Timed out after {max_attempts} attempts") - sys.exit(1) - - # Step 4: Verify results - print("\n4. Verifying results...") - baseline_result = result.get("baseline_result") - if not baseline_result: - print(" FAILED: No baseline result") - sys.exit(1) - - print(f" Baseline result keys: {list(baseline_result.keys())}") - - # Check for expected entity types - expected_entities = ["person", "tax_unit", "spm_unit", "family", "marital_unit", "household"] - for entity in expected_entities: - if entity in baseline_result: - print(f" āœ“ {entity}: {len(baseline_result[entity])} entities") - else: - print(f" āœ— {entity}: missing") - - # Look for net_income in person output - if "person" in baseline_result and baseline_result["person"]: - person = baseline_result["person"][0] - if "household_net_income" in person: - print(f" household_net_income: ${person['household_net_income']:,.2f}") - elif "spm_unit_net_income" in person: - print(f" spm_unit_net_income: ${person['spm_unit_net_income']:,.2f}") - - # Step 5: Cleanup - delete household - print("\n5. 
Cleaning up...") - resp = requests.delete(f"{BASE_URL}/households/{household_id}") - if resp.status_code == 204: - print(f" Deleted household: {household_id}") - else: - print(f" Warning: Failed to delete household: {resp.status_code}") - - print("\n" + "=" * 60) - print("SUCCESS: Household impact analysis working correctly!") - print("=" * 60) - - -if __name__ == "__main__": - main() diff --git a/scripts/test_household_scenarios.py b/scripts/test_household_scenarios.py deleted file mode 100644 index fb418a4..0000000 --- a/scripts/test_household_scenarios.py +++ /dev/null @@ -1,344 +0,0 @@ -"""Test household calculation scenarios. - -Tests: -1. US California household under current law -2. Scotland household under current law -3. US household: current law vs CTC fully refundable reform -""" - -import sys -import time -import requests - -BASE_URL = "http://127.0.0.1:8000" - - -def poll_for_completion(report_id: str, max_attempts: int = 60) -> dict: - """Poll until report is completed or failed.""" - for attempt in range(max_attempts): - resp = requests.get(f"{BASE_URL}/analysis/household-impact/{report_id}") - if resp.status_code != 200: - raise Exception(f"Failed to get report: {resp.status_code} - {resp.text}") - - result = resp.json() - status = result["status"].upper() - - if status == "COMPLETED": - return result - elif status == "FAILED": - raise Exception(f"Report failed: {result.get('error_message', 'Unknown error')}") - - time.sleep(0.5) - - raise Exception(f"Timed out after {max_attempts} attempts") - - -def print_household_summary(result: dict, label: str): - """Print summary of household calculation result.""" - print(f"\n {label}:") - - baseline = result.get("baseline_result", {}) - reform = result.get("reform_result", {}) - - # Get key metrics from person/household - if "person" in baseline and baseline["person"]: - person = baseline["person"][0] - if "household_net_income" in person: - baseline_income = person["household_net_income"] - print(f" 
Baseline net income: ${baseline_income:,.2f}") - - if reform and "person" in reform and reform["person"]: - reform_income = reform["person"][0].get("household_net_income", 0) - print(f" Reform net income: ${reform_income:,.2f}") - print(f" Difference: ${reform_income - baseline_income:,.2f}") - - # Show some tax/benefit info if available - for key in ["income_tax", "federal_income_tax", "state_income_tax", "ctc", "refundable_ctc"]: - if key in person: - print(f" {key}: ${person[key]:,.2f}") - - -def test_us_california(): - """Test 1: US California household under current law.""" - print("\n" + "=" * 60) - print("TEST 1: US California Household - Current Law") - print("=" * 60) - - # Create California household - household_data = { - "tax_benefit_model_name": "policyengine_us", - "year": 2024, - "label": "California test household", - "people": [ - {"age": 35, "employment_income": 75000}, - {"age": 33, "employment_income": 45000}, - {"age": 8}, # Child - ], - "tax_unit": {}, - "family": {}, - "spm_unit": {}, - "marital_unit": {}, - "household": {"state_code": "CA"}, - } - - print("\n Creating household...") - resp = requests.post(f"{BASE_URL}/households/", json=household_data) - if resp.status_code != 201: - print(f" FAILED: {resp.status_code} - {resp.text}") - return None - - household = resp.json() - household_id = household["id"] - print(f" Household ID: {household_id}") - - # Run analysis under current law (no policy_id) - print(" Running analysis...") - resp = requests.post(f"{BASE_URL}/analysis/household-impact", json={ - "household_id": household_id, - "policy_id": None, - }) - - if resp.status_code != 200: - print(f" FAILED: {resp.status_code} - {resp.text}") - return household_id - - report_id = resp.json()["report_id"] - print(f" Report ID: {report_id}") - - # Poll for results - try: - result = poll_for_completion(report_id) - print(" Status: COMPLETED") - print_household_summary(result, "Results") - except Exception as e: - print(f" FAILED: {e}") - - 
return household_id - - -def test_scotland(): - """Test 2: Scotland household under current law.""" - print("\n" + "=" * 60) - print("TEST 2: Scotland Household - Current Law") - print("=" * 60) - - # Create Scotland household - household_data = { - "tax_benefit_model_name": "policyengine_uk", - "year": 2024, - "label": "Scotland test household", - "people": [ - {"age": 40, "employment_income": 45000}, - ], - "benunit": {}, - "household": {"region": "SCOTLAND"}, - } - - print("\n Creating household...") - resp = requests.post(f"{BASE_URL}/households/", json=household_data) - if resp.status_code != 201: - print(f" FAILED: {resp.status_code} - {resp.text}") - return None - - household = resp.json() - household_id = household["id"] - print(f" Household ID: {household_id}") - - # Run analysis under current law - print(" Running analysis...") - resp = requests.post(f"{BASE_URL}/analysis/household-impact", json={ - "household_id": household_id, - "policy_id": None, - }) - - if resp.status_code != 200: - print(f" FAILED: {resp.status_code} - {resp.text}") - return household_id - - report_id = resp.json()["report_id"] - print(f" Report ID: {report_id}") - - # Poll for results - try: - result = poll_for_completion(report_id) - print(" Status: COMPLETED") - print_household_summary(result, "Results") - except Exception as e: - print(f" FAILED: {e}") - - return household_id - - -def test_us_ctc_reform(): - """Test 3: US household - current law vs CTC fully refundable.""" - print("\n" + "=" * 60) - print("TEST 3: US Household - Current Law vs CTC Fully Refundable") - print("=" * 60) - - # First, find the CTC refundability parameter - print("\n Finding CTC refundability parameter...") - resp = requests.get(f"{BASE_URL}/parameters", params={"search": "ctc", "limit": 50}) - if resp.status_code != 200: - print(f" FAILED to search parameters: {resp.status_code}") - return None, None - - params = resp.json() - ctc_param = None - for p in params: - # Look for the refundable portion 
parameter - if "refundable" in p["name"].lower() and "ctc" in p["name"].lower(): - print(f" Found: {p['name']} (label: {p.get('label')})") - ctc_param = p - break - - if not ctc_param: - # Try searching for child tax credit parameters - print(" Searching for child_tax_credit parameters...") - resp = requests.get(f"{BASE_URL}/parameters", params={"search": "child_tax_credit", "limit": 50}) - params = resp.json() - for p in params: - print(f" - {p['name']}") - if "refundable" in p["name"].lower(): - ctc_param = p - break - - if not ctc_param: - print(" Could not find CTC refundability parameter") - print(" Continuing with household creation anyway...") - - # Create household with children (needed for CTC) - household_data = { - "tax_benefit_model_name": "policyengine_us", - "year": 2024, - "label": "CTC test household", - "people": [ - {"age": 35, "employment_income": 30000}, # Lower income to see CTC effect - {"age": 33, "employment_income": 0}, - {"age": 5}, # Child 1 - {"age": 3}, # Child 2 - ], - "tax_unit": {}, - "family": {}, - "spm_unit": {}, - "marital_unit": {}, - "household": {"state_code": "TX"}, # Texas - no state income tax - } - - print("\n Creating household...") - resp = requests.post(f"{BASE_URL}/households/", json=household_data) - if resp.status_code != 201: - print(f" FAILED: {resp.status_code} - {resp.text}") - return None, None - - household = resp.json() - household_id = household["id"] - print(f" Household ID: {household_id}") - - # Create a policy that makes CTC fully refundable - policy_id = None - if ctc_param: - print("\n Creating CTC fully refundable policy...") - policy_data = { - "name": "CTC Fully Refundable", - "description": "Makes the Child Tax Credit fully refundable", - } - resp = requests.post(f"{BASE_URL}/policies/", json=policy_data) - if resp.status_code == 201: - policy = resp.json() - policy_id = policy["id"] - print(f" Policy ID: {policy_id}") - - # Add parameter value to make CTC fully refundable - # The parameter should 
set refundable portion to 100% or max amount - pv_data = { - "parameter_id": ctc_param["id"], - "value_json": 1.0, # 100% refundable - "start_date": "2024-01-01T00:00:00Z", - "end_date": None, - "policy_id": policy_id, - } - resp = requests.post(f"{BASE_URL}/parameter-values/", json=pv_data) - if resp.status_code == 201: - print(" Added parameter value for full refundability") - else: - print(f" Warning: Failed to add parameter value: {resp.status_code} - {resp.text}") - else: - print(f" Warning: Failed to create policy: {resp.status_code}") - - # Run analysis with reform policy - print("\n Running analysis (baseline vs reform)...") - resp = requests.post(f"{BASE_URL}/analysis/household-impact", json={ - "household_id": household_id, - "policy_id": policy_id, - }) - - if resp.status_code != 200: - print(f" FAILED: {resp.status_code} - {resp.text}") - return household_id, policy_id - - report_id = resp.json()["report_id"] - print(f" Report ID: {report_id}") - - # Poll for results - try: - result = poll_for_completion(report_id) - print(" Status: COMPLETED") - print_household_summary(result, "Results") - except Exception as e: - print(f" FAILED: {e}") - - return household_id, policy_id - - -def main(): - print("=" * 60) - print("HOUSEHOLD CALCULATION SCENARIO TESTS") - print("=" * 60) - - # Track created resources for cleanup - households = [] - policies = [] - - # Test 1: US California - hh_id = test_us_california() - if hh_id: - households.append(hh_id) - - # Test 2: Scotland - hh_id = test_scotland() - if hh_id: - households.append(hh_id) - - # Test 3: CTC Reform - hh_id, policy_id = test_us_ctc_reform() - if hh_id: - households.append(hh_id) - if policy_id: - policies.append(policy_id) - - # Cleanup - print("\n" + "=" * 60) - print("CLEANUP") - print("=" * 60) - - for hh_id in households: - resp = requests.delete(f"{BASE_URL}/households/{hh_id}") - if resp.status_code == 204: - print(f" Deleted household: {hh_id}") - else: - print(f" Warning: Failed to delete 
household {hh_id}: {resp.status_code}") - - for policy_id in policies: - resp = requests.delete(f"{BASE_URL}/policies/{policy_id}") - if resp.status_code == 204: - print(f" Deleted policy: {policy_id}") - else: - print(f" Warning: Failed to delete policy {policy_id}: {resp.status_code}") - - print("\n" + "=" * 60) - print("TESTS COMPLETE") - print("=" * 60) - - -if __name__ == "__main__": - main() diff --git a/supabase/migrations_archived/20251229000000_add_parameter_values_indexes.sql b/supabase/migrations_archived/20251229000000_add_parameter_values_indexes.sql deleted file mode 100644 index c1713d5..0000000 --- a/supabase/migrations_archived/20251229000000_add_parameter_values_indexes.sql +++ /dev/null @@ -1,16 +0,0 @@ --- Add indexes to parameter_values table for query optimization --- This migration improves query performance for filtering by parameter_id and policy_id - --- Composite index for the most common query pattern (filtering by both) -CREATE INDEX IF NOT EXISTS idx_parameter_values_parameter_policy -ON parameter_values(parameter_id, policy_id); - --- Single index on policy_id for filtering by policy alone -CREATE INDEX IF NOT EXISTS idx_parameter_values_policy -ON parameter_values(policy_id); - --- Partial index for baseline values (policy_id IS NULL) --- This optimizes the common "get current law values" query -CREATE INDEX IF NOT EXISTS idx_parameter_values_baseline -ON parameter_values(parameter_id) -WHERE policy_id IS NULL; diff --git a/supabase/migrations_archived/20260103000000_add_poverty_inequality.sql b/supabase/migrations_archived/20260103000000_add_poverty_inequality.sql deleted file mode 100644 index f315d93..0000000 --- a/supabase/migrations_archived/20260103000000_add_poverty_inequality.sql +++ /dev/null @@ -1,33 +0,0 @@ --- Add poverty and inequality tables for economic analysis - -CREATE TABLE IF NOT EXISTS poverty ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE 
CASCADE, - report_id UUID REFERENCES reports(id) ON DELETE CASCADE, - poverty_type VARCHAR NOT NULL, - entity VARCHAR NOT NULL DEFAULT 'person', - filter_variable VARCHAR, - headcount FLOAT, - total_population FLOAT, - rate FLOAT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE TABLE IF NOT EXISTS inequality ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE CASCADE, - report_id UUID REFERENCES reports(id) ON DELETE CASCADE, - income_variable VARCHAR NOT NULL, - entity VARCHAR NOT NULL DEFAULT 'household', - gini FLOAT, - top_10_share FLOAT, - top_1_share FLOAT, - bottom_50_share FLOAT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - --- Indexes for efficient querying -CREATE INDEX IF NOT EXISTS idx_poverty_simulation_id ON poverty(simulation_id); -CREATE INDEX IF NOT EXISTS idx_poverty_report_id ON poverty(report_id); -CREATE INDEX IF NOT EXISTS idx_inequality_simulation_id ON inequality(simulation_id); -CREATE INDEX IF NOT EXISTS idx_inequality_report_id ON inequality(report_id); diff --git a/supabase/migrations_archived/20260111000000_add_aggregate_status.sql b/supabase/migrations_archived/20260111000000_add_aggregate_status.sql deleted file mode 100644 index b190620..0000000 --- a/supabase/migrations_archived/20260111000000_add_aggregate_status.sql +++ /dev/null @@ -1,13 +0,0 @@ --- Add status and error_message columns to aggregates table -ALTER TABLE aggregates -ADD COLUMN IF NOT EXISTS status VARCHAR(20) DEFAULT 'pending', -ADD COLUMN IF NOT EXISTS error_message TEXT; - --- Add status and error_message columns to change_aggregates table -ALTER TABLE change_aggregates -ADD COLUMN IF NOT EXISTS status VARCHAR(20) DEFAULT 'pending', -ADD COLUMN IF NOT EXISTS error_message TEXT; - --- Create indexes for status filtering -CREATE INDEX IF NOT EXISTS idx_aggregates_status ON aggregates(status); -CREATE INDEX IF NOT EXISTS idx_change_aggregates_status ON change_aggregates(status); 
diff --git a/supabase/migrations_archived/20260203000000_create_households.sql b/supabase/migrations_archived/20260203000000_create_households.sql deleted file mode 100644 index cc1907f..0000000 --- a/supabase/migrations_archived/20260203000000_create_households.sql +++ /dev/null @@ -1,14 +0,0 @@ --- Create stored households table for persisting household definitions. - -CREATE TABLE IF NOT EXISTS households ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - tax_benefit_model_name TEXT NOT NULL, - year INTEGER NOT NULL, - label TEXT, - household_data JSONB NOT NULL, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now() -); - -CREATE INDEX idx_households_model_name ON households (tax_benefit_model_name); -CREATE INDEX idx_households_year ON households (year); diff --git a/supabase/migrations_archived/20260203000001_create_user_household_associations.sql b/supabase/migrations_archived/20260203000001_create_user_household_associations.sql deleted file mode 100644 index 3fdcb03..0000000 --- a/supabase/migrations_archived/20260203000001_create_user_household_associations.sql +++ /dev/null @@ -1,14 +0,0 @@ --- Create user-household associations table for linking users to saved households. 
- -CREATE TABLE IF NOT EXISTS user_household_associations ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, - household_id UUID NOT NULL REFERENCES households(id) ON DELETE CASCADE, - country_id TEXT NOT NULL, - label TEXT, - created_at TIMESTAMPTZ NOT NULL DEFAULT now(), - updated_at TIMESTAMPTZ NOT NULL DEFAULT now() -); - -CREATE INDEX idx_user_household_assoc_user ON user_household_associations (user_id); -CREATE INDEX idx_user_household_assoc_household ON user_household_associations (household_id); diff --git a/supabase/migrations_archived/20260203000002_simulation_household_support.sql b/supabase/migrations_archived/20260203000002_simulation_household_support.sql deleted file mode 100644 index 6813f07..0000000 --- a/supabase/migrations_archived/20260203000002_simulation_household_support.sql +++ /dev/null @@ -1,16 +0,0 @@ --- Add simulation_type as TEXT (SQLModel enum maps to text) -ALTER TABLE simulations ADD COLUMN simulation_type TEXT NOT NULL DEFAULT 'economy'; - --- Make dataset_id nullable (was required) -ALTER TABLE simulations ALTER COLUMN dataset_id DROP NOT NULL; - --- Add household support columns -ALTER TABLE simulations ADD COLUMN household_id UUID REFERENCES households(id); -ALTER TABLE simulations ADD COLUMN household_result JSONB; - --- Indexes -CREATE INDEX idx_simulations_household ON simulations (household_id); -CREATE INDEX idx_simulations_type ON simulations (simulation_type); - --- Add report_type to reports -ALTER TABLE reports ADD COLUMN report_type TEXT;