diff --git a/.claude/skills/database-migrations.md b/.claude/skills/database-migrations.md new file mode 100644 index 0000000..fedbef8 --- /dev/null +++ b/.claude/skills/database-migrations.md @@ -0,0 +1,301 @@ +# Database Migration Guidelines + +## Overview + +This project uses **Alembic** for database migrations with **SQLModel** models. Alembic is the industry-standard migration tool for SQLAlchemy/SQLModel projects. + +**CRITICAL**: SQL migrations are the single source of truth for database schema. All table creation and schema changes MUST go through Alembic migrations. + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ SQLModel Models (src/policyengine_api/models/) │ +│ - Define Python classes │ +│ - Used for ORM queries │ +│ - NOT the source of truth for schema │ +└─────────────────────────────────────────────────────────────┘ + │ + │ alembic revision --autogenerate + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ Alembic Migrations (alembic/versions/) │ +│ - Create/alter tables │ +│ - Add indexes, constraints │ +│ - SOURCE OF TRUTH for schema │ +└─────────────────────────────────────────────────────────────┘ + │ + │ alembic upgrade head + ▼ +┌─────────────────────────────────────────────────────────────┐ +│ PostgreSQL Database (Supabase) │ +│ - Actual schema │ +│ - Tracked by alembic_version table │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Essential Rules + +### 1. NEVER use SQLModel.metadata.create_all() for schema creation + +The old pattern of using `SQLModel.metadata.create_all()` is deprecated. All tables are created via Alembic migrations. + +### 2. Every schema change requires a migration + +When you modify a SQLModel model (add column, change type, add index), you MUST: +1. Update the model in `src/policyengine_api/models/` +2. Generate a migration: `uv run alembic revision --autogenerate -m "Description"` +3. 
**Read and verify the generated migration** (see below) +4. Apply it: `uv run alembic upgrade head` + +### 3. ALWAYS verify auto-generated migrations before applying + +**This is critical for AI agents.** After running `alembic revision --autogenerate`, you MUST: + +1. **Read the generated migration file** in `alembic/versions/` +2. **Verify the `upgrade()` function** contains the expected changes: + - Correct table/column names + - Correct column types (e.g., `sa.String()`, `sa.Uuid()`, `sa.Integer()`) + - Proper foreign key references + - Appropriate nullable settings +3. **Verify the `downgrade()` function** properly reverses the changes +4. **Check for Alembic autogenerate limitations:** + - It may miss renamed columns (shows as drop + add instead) + - It may not detect some index changes + - It doesn't handle data migrations +5. **Edit the migration if needed** before applying + +Example verification: +```python +# Generated migration - verify this looks correct: +def upgrade() -> None: + op.add_column('users', sa.Column('phone', sa.String(), nullable=True)) + +def downgrade() -> None: + op.drop_column('users', 'phone') +``` + +**Never blindly apply a migration without reading it first.** + +### 4. Migrations must be self-contained + +Each migration should: +- Create tables it needs (never assume they exist from Python) +- Include both `upgrade()` and `downgrade()` functions +- Be idempotent where possible (use `IF NOT EXISTS` patterns) + +### 5. Never use conditional logic based on table existence + +Migrations should NOT check if tables exist. 
Instead: +- Ensure migrations run in the correct order (use `down_revision`) +- The initial migration creates all base tables +- Subsequent migrations build on that foundation + +## Common Commands + +```bash +# Apply all pending migrations +uv run alembic upgrade head + +# Generate migration from model changes +uv run alembic revision --autogenerate -m "Add users email index" + +# Create empty migration (for manual SQL) +uv run alembic revision -m "Add custom index" + +# Check current migration state +uv run alembic current + +# Show migration history +uv run alembic history + +# Downgrade one revision +uv run alembic downgrade -1 + +# Downgrade to specific revision +uv run alembic downgrade +``` + +## Local Development Workflow + +```bash +# 1. Start Supabase +supabase start + +# 2. Initialize database (runs migrations + applies RLS policies) +uv run python scripts/init.py + +# 3. Seed data +uv run python scripts/seed.py +``` + +### Reset database (DESTRUCTIVE) + +```bash +uv run python scripts/init.py --reset +``` + +## Adding a New Model + +1. Create the model in `src/policyengine_api/models/` + +```python +# src/policyengine_api/models/my_model.py +from sqlmodel import SQLModel, Field +from uuid import UUID, uuid4 + +class MyModel(SQLModel, table=True): + __tablename__ = "my_models" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + name: str +``` + +2. Export in `__init__.py`: + +```python +# src/policyengine_api/models/__init__.py +from .my_model import MyModel +``` + +3. Generate migration: + +```bash +uv run alembic revision --autogenerate -m "Add my_models table" +``` + +4. Review the generated migration in `alembic/versions/` + +5. Apply the migration: + +```bash +uv run alembic upgrade head +``` + +6. Update `scripts/init.py` to include the table in RLS policies if needed. + +## Adding an Index + +1. Generate a migration: + +```bash +uv run alembic revision -m "Add index on users.email" +``` + +2. 
Edit the migration: + +```python +def upgrade() -> None: + op.create_index("idx_users_email", "users", ["email"]) + +def downgrade() -> None: + op.drop_index("idx_users_email", "users") +``` + +3. Apply: + +```bash +uv run alembic upgrade head +``` + +## Production Considerations + +### Applying migrations to production + +1. Migrations are automatically applied when deploying +2. Always test migrations locally first +3. For data migrations, consider running during low-traffic periods + +### Transitioning production from old system to Alembic + +Production databases that were created before Alembic (using the old `SQLModel.metadata.create_all()` approach or raw Supabase migrations) need special handling. Running `alembic upgrade head` would fail because the tables already exist. + +**The solution: `alembic stamp`** + +The `alembic stamp` command marks a migration as "already applied" without actually running it. This tells Alembic "the database is already at this state, start tracking from here." + +**How it works:** + +1. `alembic stamp ` inserts a row into the `alembic_version` table with the specified revision ID +2. Alembic now thinks that migration (and all migrations before it) have been applied +3. Future migrations will run normally starting from that point + +**Step-by-step production transition:** + +```bash +# 1. Connect to production database +# (set SUPABASE_DB_URL or other connection env vars) + +# 2. Check if alembic_version table exists +# If not, Alembic will create it automatically + +# 3. Verify production schema matches the initial migration +# Compare tables/columns in production against alembic/versions/20260204_d6e30d3b834d_initial_schema.py + +# 4. Stamp the initial migration as applied +uv run alembic stamp d6e30d3b834d + +# 5. If production also has the indexes from the second migration, stamp that too +uv run alembic stamp a17ac554f4aa + +# 6. Verify the stamp worked +uv run alembic current +# Should show: a17ac554f4aa (head) + +# 7. 
From now on, new migrations will apply normally +uv run alembic upgrade head +``` + +**Handling partially applied migrations:** + +If production has some but not all changes from a migration: + +1. Manually apply the missing changes via SQL +2. Then stamp that migration as complete +3. Or: create a new migration that only adds the missing pieces + +**After stamping:** + +- All future schema changes go through Alembic migrations +- Developers generate migrations with `alembic revision --autogenerate` +- Deployments run `alembic upgrade head` to apply pending migrations +- The `alembic_version` table tracks what's been applied + +## File Structure + +``` +alembic/ +├── env.py # Alembic configuration (imports models, sets DB URL) +├── script.py.mako # Template for new migrations +├── versions/ # Migration files +│ ├── 20260204_d6e30d3b834d_initial_schema.py +│ └── 20260204_a17ac554f4aa_add_parameter_values_indexes.py +alembic.ini # Alembic settings + +supabase/ +├── migrations/ # Supabase-specific migrations (storage only) +│ ├── 20241119000000_storage_bucket.sql +│ └── 20241121000000_storage_policies.sql +└── migrations_archived/ # Old table migrations (now in Alembic) +``` + +## Troubleshooting + +### "Target database is not up to date" + +Run `alembic upgrade head` to apply pending migrations. + +### "Can't locate revision" + +The alembic_version table has a revision that doesn't exist in your migrations folder. This can happen if someone deleted a migration file. Fix by stamping to a known revision: + +```bash +alembic stamp head # If tables are current +alembic stamp d6e30d3b834d # If at initial schema +``` + +### "Table already exists" + +The migration is trying to create a table that already exists. Options: +1. If this is a fresh setup, drop and recreate: `uv run python scripts/init.py --reset` +2. 
If in production, stamp the migration as applied: `alembic stamp ` diff --git a/CLAUDE.md b/CLAUDE.md index 2df55fc..d6fb240 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -75,7 +75,21 @@ Use `gh` CLI for GitHub operations to ensure Actions run correctly. ## Database -`make init` resets tables and storage. `make seed` populates UK/US models with variables, parameters, and datasets. +This project uses **Alembic** for database migrations. See `.claude/skills/database-migrations.md` for detailed guidelines. + +**Key rules:** +- All schema changes go through Alembic migrations (never use `SQLModel.metadata.create_all()`) +- After modifying a model: `uv run alembic revision --autogenerate -m "Description"` +- Apply migrations: `uv run alembic upgrade head` + +**Local development:** +```bash +supabase start # Start local Supabase +uv run python scripts/init.py # Run migrations + apply RLS policies +uv run python scripts/seed.py # Seed data +``` + +`scripts/init.py --reset` drops and recreates everything (destructive). ## Modal sandbox + Claude Code CLI gotchas diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..ed54635 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,145 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = %(here)s/alembic + +# template used to generate migration file names +# Prepend with date for easier chronological ordering +file_template = %%(year)d%%(month).2d%%(day).2d_%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. for multiple paths, the path separator +# is defined by "path_separator" below. +prepend_sys_path = . + + +# timezone to use when rendering the date within the migration file +# as well as the filename. 
+# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. +# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "path_separator" +# below. +# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions + +# path_separator; This indicates what character is used to split lists of file +# paths, including version_locations and prepend_sys_path within configparser +# files such as alembic.ini. +# The default rendered in new alembic.ini files is "os", which uses os.pathsep +# to provide os-dependent path splitting. +# +# Note that in order to support legacy alembic.ini files, this default does NOT +# take place if path_separator is not present in alembic.ini. If this +# option is omitted entirely, fallback logic is as follows: +# +# 1. Parsing of the version_locations option falls back to using the legacy +# "version_path_separator" key, which if absent then falls back to the legacy +# behavior of splitting on spaces and/or commas. +# 2. Parsing of the prepend_sys_path option falls back to the legacy +# behavior of splitting on spaces, commas, or colons. +# +# Valid values for path_separator are: +# +# path_separator = : +# path_separator = ; +# path_separator = space +# path_separator = newline +# +# Use os.pathsep. Default configuration used for new projects. 
+path_separator = os + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +# database URL - This is overridden by env.py which reads from application settings. +# The placeholder below is only used if env.py doesn't set it. +sqlalchemy.url = postgresql://placeholder:placeholder@localhost/placeholder + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# NOTE: ruff is in dev dependencies, so this hook only works when dev deps are installed +# hooks = ruff +# ruff.type = module +# ruff.module = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Alternatively, use the exec runner to execute a binary found on your PATH +# hooks = ruff +# ruff.type = exec +# ruff.executable = ruff +# ruff.options = check --fix REVISION_SCRIPT_FILENAME + +# Logging configuration. This is also consumed by the user-maintained +# env.py script only. 
+[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..f930498 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,87 @@ +"""Alembic environment configuration for SQLModel migrations. + +This module configures Alembic to: +1. Use the database URL from application settings +2. Import all SQLModel models for autogenerate support +3. Run migrations in both offline and online modes +""" + +import sys +from logging.config import fileConfig +from pathlib import Path + +from sqlalchemy import engine_from_config, pool +from sqlmodel import SQLModel + +from alembic import context + +# Add src to path so we can import policyengine_api +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +# Import all models to register them with SQLModel.metadata +# This is required for autogenerate to detect model changes +from policyengine_api import models # noqa: F401 +from policyengine_api.config.settings import settings + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Override sqlalchemy.url with the actual database URL from settings +config.set_main_option("sqlalchemy.url", settings.database_url) + +# Interpret the config file for Python logging. 
+if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# SQLModel metadata for autogenerate support +# This allows Alembic to detect changes in your SQLModel models +target_metadata = SQLModel.metadata + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..1101630 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. 
+revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/20260207_36f9d434e95b_initial_schema.py b/alembic/versions/20260207_36f9d434e95b_initial_schema.py new file mode 100644 index 0000000..0dce2e7 --- /dev/null +++ b/alembic/versions/20260207_36f9d434e95b_initial_schema.py @@ -0,0 +1,321 @@ +"""initial schema + +Revision ID: 36f9d434e95b +Revises: +Create Date: 2026-02-07 01:52:16.497121 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes + + +# revision identifiers, used by Alembic. +revision: str = '36f9d434e95b' +down_revision: Union[str, Sequence[str], None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('dynamics', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('policies', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('tax_benefit_models', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('users', + sa.Column('first_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('last_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('email', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True) + op.create_table('datasets', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('filepath', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('year', sa.Integer(), nullable=False), + sa.Column('is_output_dataset', sa.Boolean(), nullable=False), + sa.Column('tax_benefit_model_id', sa.Uuid(), 
nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['tax_benefit_model_id'], ['tax_benefit_models.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('household_jobs', + sa.Column('tax_benefit_model_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('request_data', sa.JSON(), nullable=True), + sa.Column('policy_id', sa.Uuid(), nullable=True), + sa.Column('dynamic_id', sa.Uuid(), nullable=True), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', name='householdjobstatus'), nullable=False), + sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('result', sa.JSON(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('started_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['dynamic_id'], ['dynamics.id'], ), + sa.ForeignKeyConstraint(['policy_id'], ['policies.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('tax_benefit_model_versions', + sa.Column('model_id', sa.Uuid(), nullable=False), + sa.Column('version', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['model_id'], ['tax_benefit_models.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('dataset_versions', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('dataset_id', sa.Uuid(), nullable=False), + sa.Column('tax_benefit_model_id', sa.Uuid(), nullable=False), + 
sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ), + sa.ForeignKeyConstraint(['tax_benefit_model_id'], ['tax_benefit_models.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('parameters', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('data_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('unit', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('tax_benefit_model_version_id', sa.Uuid(), nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['tax_benefit_model_version_id'], ['tax_benefit_model_versions.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('simulations', + sa.Column('dataset_id', sa.Uuid(), nullable=False), + sa.Column('policy_id', sa.Uuid(), nullable=True), + sa.Column('dynamic_id', sa.Uuid(), nullable=True), + sa.Column('tax_benefit_model_version_id', sa.Uuid(), nullable=False), + sa.Column('output_dataset_id', sa.Uuid(), nullable=True), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', name='simulationstatus'), nullable=False), + sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.Column('started_at', sa.DateTime(), nullable=True), + sa.Column('completed_at', sa.DateTime(), nullable=True), + sa.ForeignKeyConstraint(['dataset_id'], ['datasets.id'], ), + sa.ForeignKeyConstraint(['dynamic_id'], ['dynamics.id'], ), + sa.ForeignKeyConstraint(['output_dataset_id'], ['datasets.id'], ), + 
sa.ForeignKeyConstraint(['policy_id'], ['policies.id'], ), + sa.ForeignKeyConstraint(['tax_benefit_model_version_id'], ['tax_benefit_model_versions.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('variables', + sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('data_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('possible_values', sa.JSON(), nullable=True), + sa.Column('tax_benefit_model_version_id', sa.Uuid(), nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['tax_benefit_model_version_id'], ['tax_benefit_model_versions.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('parameter_values', + sa.Column('parameter_id', sa.Uuid(), nullable=False), + sa.Column('value_json', sa.JSON(), nullable=True), + sa.Column('start_date', sa.DateTime(), nullable=False), + sa.Column('end_date', sa.DateTime(), nullable=True), + sa.Column('policy_id', sa.Uuid(), nullable=True), + sa.Column('dynamic_id', sa.Uuid(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['dynamic_id'], ['dynamics.id'], ), + sa.ForeignKeyConstraint(['parameter_id'], ['parameters.id'], ), + sa.ForeignKeyConstraint(['policy_id'], ['policies.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('reports', + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('description', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('user_id', sa.Uuid(), nullable=True), + sa.Column('markdown', sa.Text(), nullable=True), + sa.Column('parent_report_id', sa.Uuid(), nullable=True), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 
'FAILED', name='reportstatus'), nullable=False), + sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('baseline_simulation_id', sa.Uuid(), nullable=True), + sa.Column('reform_simulation_id', sa.Uuid(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['baseline_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['parent_report_id'], ['reports.id'], ), + sa.ForeignKeyConstraint(['reform_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('aggregates', + sa.Column('simulation_id', sa.Uuid(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=True), + sa.Column('report_id', sa.Uuid(), nullable=True), + sa.Column('variable', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('aggregate_type', sa.Enum('SUM', 'MEAN', 'COUNT', name='aggregatetype'), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('filter_config', sa.JSON(), nullable=True), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', name='aggregatestatus'), nullable=False), + sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('result', sa.Float(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.ForeignKeyConstraint(['simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('change_aggregates', + sa.Column('baseline_simulation_id', sa.Uuid(), nullable=False), + sa.Column('reform_simulation_id', sa.Uuid(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=True), + sa.Column('report_id', sa.Uuid(), 
nullable=True), + sa.Column('variable', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('aggregate_type', sa.Enum('SUM', 'MEAN', 'COUNT', name='changeaggregatetype'), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('filter_config', sa.JSON(), nullable=True), + sa.Column('change_geq', sa.Float(), nullable=True), + sa.Column('change_leq', sa.Float(), nullable=True), + sa.Column('status', sa.Enum('PENDING', 'RUNNING', 'COMPLETED', 'FAILED', name='changeaggregatestatus'), nullable=False), + sa.Column('error_message', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('result', sa.Float(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['baseline_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['reform_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('decile_impacts', + sa.Column('baseline_simulation_id', sa.Uuid(), nullable=False), + sa.Column('reform_simulation_id', sa.Uuid(), nullable=False), + sa.Column('report_id', sa.Uuid(), nullable=True), + sa.Column('income_variable', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('decile', sa.Integer(), nullable=False), + sa.Column('quantiles', sa.Integer(), nullable=False), + sa.Column('baseline_mean', sa.Float(), nullable=True), + sa.Column('reform_mean', sa.Float(), nullable=True), + sa.Column('absolute_change', sa.Float(), nullable=True), + sa.Column('relative_change', sa.Float(), nullable=True), + sa.Column('count_better_off', sa.Float(), nullable=True), + sa.Column('count_worse_off', sa.Float(), nullable=True), + sa.Column('count_no_change', sa.Float(), nullable=True), + 
sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['baseline_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['reform_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('inequality', + sa.Column('simulation_id', sa.Uuid(), nullable=False), + sa.Column('report_id', sa.Uuid(), nullable=True), + sa.Column('income_variable', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('gini', sa.Float(), nullable=True), + sa.Column('top_10_share', sa.Float(), nullable=True), + sa.Column('top_1_share', sa.Float(), nullable=True), + sa.Column('bottom_50_share', sa.Float(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.ForeignKeyConstraint(['simulation_id'], ['simulations.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('poverty', + sa.Column('simulation_id', sa.Uuid(), nullable=False), + sa.Column('report_id', sa.Uuid(), nullable=True), + sa.Column('poverty_type', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('filter_variable', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('headcount', sa.Float(), nullable=True), + sa.Column('total_population', sa.Float(), nullable=True), + sa.Column('rate', sa.Float(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.ForeignKeyConstraint(['simulation_id'], ['simulations.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('program_statistics', + 
sa.Column('baseline_simulation_id', sa.Uuid(), nullable=False), + sa.Column('reform_simulation_id', sa.Uuid(), nullable=False), + sa.Column('report_id', sa.Uuid(), nullable=True), + sa.Column('program_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('entity', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('is_tax', sa.Boolean(), nullable=False), + sa.Column('baseline_total', sa.Float(), nullable=True), + sa.Column('reform_total', sa.Float(), nullable=True), + sa.Column('change', sa.Float(), nullable=True), + sa.Column('baseline_count', sa.Float(), nullable=True), + sa.Column('reform_count', sa.Float(), nullable=True), + sa.Column('winners', sa.Float(), nullable=True), + sa.Column('losers', sa.Float(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['baseline_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['reform_simulation_id'], ['simulations.id'], ), + sa.ForeignKeyConstraint(['report_id'], ['reports.id'], ), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_table('program_statistics') + op.drop_table('poverty') + op.drop_table('inequality') + op.drop_table('decile_impacts') + op.drop_table('change_aggregates') + op.drop_table('aggregates') + op.drop_table('reports') + op.drop_table('parameter_values') + op.drop_table('variables') + op.drop_table('simulations') + op.drop_table('parameters') + op.drop_table('dataset_versions') + op.drop_table('tax_benefit_model_versions') + op.drop_table('household_jobs') + op.drop_table('datasets') + op.drop_index(op.f('ix_users_email'), table_name='users') + op.drop_table('users') + op.drop_table('tax_benefit_models') + op.drop_table('policies') + op.drop_table('dynamics') + # ### end Alembic commands ### diff --git a/alembic/versions/20260207_f419b5f4acba_add_household_support.py b/alembic/versions/20260207_f419b5f4acba_add_household_support.py new file mode 100644 index 0000000..cef65f3 --- /dev/null +++ b/alembic/versions/20260207_f419b5f4acba_add_household_support.py @@ -0,0 +1,81 @@ +"""add household support + +Revision ID: f419b5f4acba +Revises: 36f9d434e95b +Create Date: 2026-02-07 01:56:31.064511 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +import sqlmodel.sql.sqltypes +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'f419b5f4acba' +down_revision: Union[str, Sequence[str], None] = '36f9d434e95b' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table('households', + sa.Column('tax_benefit_model_name', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('year', sa.Integer(), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('household_data', sa.JSON(), nullable=False), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('user_household_associations', + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('household_id', sa.Uuid(), nullable=False), + sa.Column('country_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column('label', sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('created_at', sa.DateTime(), nullable=False), + sa.Column('updated_at', sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint(['household_id'], ['households.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_user_household_associations_household_id'), 'user_household_associations', ['household_id'], unique=False) + op.create_index(op.f('ix_user_household_associations_user_id'), 'user_household_associations', ['user_id'], unique=False) + op.add_column('reports', sa.Column('report_type', sqlmodel.sql.sqltypes.AutoString(), nullable=True)) + # Create enum type first + simulationtype = postgresql.ENUM('HOUSEHOLD', 'ECONOMY', name='simulationtype', create_type=False) + simulationtype.create(op.get_bind(), checkfirst=True) + op.add_column('simulations', sa.Column('simulation_type', sa.Enum('HOUSEHOLD', 'ECONOMY', name='simulationtype', create_type=False), nullable=False)) + op.add_column('simulations', sa.Column('household_id', sa.Uuid(), nullable=True)) + op.add_column('simulations', sa.Column('household_result', postgresql.JSON(astext_type=sa.Text()), nullable=True)) + 
op.alter_column('simulations', 'dataset_id', + existing_type=sa.UUID(), + nullable=True) + op.create_foreign_key(None, 'simulations', 'households', ['household_id'], ['id']) + op.add_column('variables', sa.Column('default_value', sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('variables', 'default_value') + op.drop_constraint(None, 'simulations', type_='foreignkey') + op.alter_column('simulations', 'dataset_id', + existing_type=sa.UUID(), + nullable=False) + op.drop_column('simulations', 'household_result') + op.drop_column('simulations', 'household_id') + op.drop_column('simulations', 'simulation_type') + # Drop enum type + postgresql.ENUM('HOUSEHOLD', 'ECONOMY', name='simulationtype').drop(op.get_bind(), checkfirst=True) + op.drop_column('reports', 'report_type') + op.drop_index(op.f('ix_user_household_associations_user_id'), table_name='user_household_associations') + op.drop_index(op.f('ix_user_household_associations_household_id'), table_name='user_household_associations') + op.drop_table('user_household_associations') + op.drop_table('households') + # ### end Alembic commands ### diff --git a/docker-compose.yml b/docker-compose.yml index 60e8645..60aa598 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -5,10 +5,10 @@ services: ports: - "${API_PORT:-8000}:${API_PORT:-8000}" environment: - SUPABASE_URL: http://supabase_kong_policyengine-api-v2:8000 + SUPABASE_URL: http://supabase_kong_policyengine-api-v2-alpha:8000 SUPABASE_KEY: ${SUPABASE_KEY} SUPABASE_SERVICE_KEY: ${SUPABASE_SERVICE_KEY} - SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2:5432/postgres + SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2-alpha:5432/postgres LOGFIRE_TOKEN: ${LOGFIRE_TOKEN} DEBUG: "false" API_PORT: ${API_PORT:-8000} @@ -19,7 +19,7 @@ services: - ./src:/app/src - 
./docs/out:/app/docs/out networks: - - supabase_network_policyengine-api-v2 + - supabase_network_policyengine-api-v2-alpha healthcheck: test: ["CMD", "python", "-c", "import httpx; exit(0 if httpx.get('http://localhost:${API_PORT:-8000}/health', timeout=2).status_code == 200 else 1)"] interval: 5s @@ -31,16 +31,16 @@ services: build: . command: pytest tests/ -v environment: - SUPABASE_URL: http://supabase_kong_policyengine-api-v2:8000 + SUPABASE_URL: http://supabase_kong_policyengine-api-v2-alpha:8000 SUPABASE_KEY: ${SUPABASE_KEY} SUPABASE_SERVICE_KEY: ${SUPABASE_SERVICE_KEY} - SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2:5432/postgres + SUPABASE_DB_URL: postgresql://postgres:postgres@supabase_db_policyengine-api-v2-alpha:5432/postgres LOGFIRE_TOKEN: ${LOGFIRE_TOKEN} volumes: - ./src:/app/src - ./tests:/app/tests networks: - - supabase_network_policyengine-api-v2 + - supabase_network_policyengine-api-v2-alpha depends_on: api: condition: service_healthy @@ -48,5 +48,5 @@ services: - test networks: - supabase_network_policyengine-api-v2: + supabase_network_policyengine-api-v2-alpha: external: true diff --git a/pyproject.toml b/pyproject.toml index 27eb310..d624a6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,11 @@ dependencies = [ "psycopg2-binary>=2.9.10", "supabase>=2.10.0", "storage3>=0.8.1", - "policyengine>=3.1.15", + # IMPORTANT: Before merging app-v2-migration into main, replace this git + # dependency with the production PyPI version of policyengine (e.g., "policyengine>=X.Y.Z"). + # The git ref is used here because the app-v2-migration branch contains fixes + # (US reform application, regions support) not yet released to PyPI. 
+ "policyengine @ git+https://github.com/PolicyEngine/policyengine.py.git@app-v2-migration", "policyengine-uk>=2.0.0", "policyengine-us>=1.0.0", "pydantic>=2.9.2", @@ -24,6 +28,7 @@ dependencies = [ "fastapi-mcp>=0.4.0", "modal>=0.68.0", "anthropic>=0.40.0", + "alembic>=1.13.0", ] [project.optional-dependencies] @@ -38,6 +43,9 @@ dev = [ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.metadata] +allow-direct-references = true + [tool.hatch.build.targets.wheel] packages = ["src/policyengine_api"] diff --git a/scripts/init.py b/scripts/init.py index cf7a04a..3aa925b 100644 --- a/scripts/init.py +++ b/scripts/init.py @@ -1,12 +1,19 @@ -"""Initialise Supabase: reset database, recreate tables, buckets, and permissions. +"""Initialise Supabase database with tables, buckets, and permissions. -This script performs a complete reset of the Supabase instance: -1. Drops and recreates the public schema (all tables) -2. Deletes and recreates the storage bucket -3. Creates all tables from SQLModel definitions -4. Applies RLS policies and storage permissions +This script can run in two modes: +1. Init mode (default): Creates tables via Alembic, applies RLS policies +2. Reset mode (--reset): Drops everything and recreates from scratch (DESTRUCTIVE) + +Usage: + uv run python scripts/init.py # Safe init (creates if not exists) + uv run python scripts/init.py --reset # Destructive reset (drops everything) + +For local development after `supabase start`, use init mode. +For production, use init mode to ensure tables and policies exist. +Reset mode should only be used when you need a completely fresh database. 
""" +import subprocess import sys from pathlib import Path @@ -14,16 +21,14 @@ from rich.console import Console from rich.panel import Panel -from sqlmodel import SQLModel, create_engine +from sqlmodel import create_engine -# Import all models to register them with SQLModel.metadata -from policyengine_api import models # noqa: F401 from policyengine_api.config.settings import settings from policyengine_api.services.storage import get_service_role_client console = Console() -MIGRATIONS_DIR = Path(__file__).parent.parent / "supabase" / "migrations" +PROJECT_ROOT = Path(__file__).parent.parent def reset_storage_bucket(): @@ -57,30 +62,61 @@ def reset_storage_bucket(): console.print(f"[yellow]⚠ Warning with storage bucket: {e}[/yellow]") +def ensure_storage_bucket(): + """Ensure storage bucket exists (non-destructive).""" + console.print("[bold blue]Ensuring storage bucket exists...") + + try: + supabase = get_service_role_client() + bucket_name = settings.storage_bucket + + # Try to get bucket info + try: + supabase.storage.get_bucket(bucket_name) + console.print(f"[green]✓[/green] Bucket '{bucket_name}' exists") + except Exception: + # Bucket doesn't exist, create it + supabase.storage.create_bucket(bucket_name, options={"public": True}) + console.print(f"[green]✓[/green] Created bucket '{bucket_name}'") + + except Exception as e: + console.print(f"[yellow]⚠ Warning with storage bucket: {e}[/yellow]") + + def reset_database(): - """Drop and recreate all tables.""" - console.print("[bold blue]Resetting database...") + """Drop and recreate the public schema (DESTRUCTIVE).""" + console.print("[bold red]Dropping database schema...") engine = create_engine(settings.database_url, echo=False) - # Drop and recreate public schema - console.print(" Dropping public schema...") with engine.begin() as conn: conn.exec_driver_sql("DROP SCHEMA public CASCADE") conn.exec_driver_sql("CREATE SCHEMA public") conn.exec_driver_sql("GRANT ALL ON SCHEMA public TO postgres") 
conn.exec_driver_sql("GRANT ALL ON SCHEMA public TO public") - # Create all tables from SQLModel - console.print(" Creating tables...") - SQLModel.metadata.create_all(engine) + console.print("[green]✓[/green] Schema dropped and recreated") + return engine - tables = list(SQLModel.metadata.tables.keys()) - console.print(f"[green]✓[/green] Created {len(tables)} tables:") - for table in sorted(tables): - console.print(f" {table}") - return engine +def run_alembic_migrations(): + """Run Alembic migrations to create/update tables.""" + console.print("[bold blue]Running Alembic migrations...") + + result = subprocess.run( + ["uv", "run", "alembic", "upgrade", "head"], + cwd=PROJECT_ROOT, + capture_output=True, + text=True, + ) + + if result.returncode != 0: + console.print(f"[red]✗ Alembic migration failed:[/red]") + console.print(result.stderr) + raise RuntimeError("Alembic migration failed") + + console.print("[green]✓[/green] Alembic migrations complete") + console.print(result.stdout) def apply_storage_policies(engine): @@ -158,6 +194,10 @@ def apply_rls_policies(engine): "parameter_values", "users", "household_jobs", + "households", + "user_household_associations", + "poverty", + "inequality", ] # Read-only tables (public can read, only service role can write) @@ -178,6 +218,7 @@ def apply_rls_policies(engine): "dynamics", "reports", "household_jobs", + "households", ] # Read-only results tables @@ -186,6 +227,8 @@ def apply_rls_policies(engine): "change_aggregates", "decile_impacts", "program_statistics", + "poverty", + "inequality", ] sql_parts = [] @@ -230,6 +273,13 @@ def apply_rls_policies(engine): FOR SELECT TO anon, authenticated USING (true); """) + # User-household associations need special handling + sql_parts.append(""" + DROP POLICY IF EXISTS "Users can manage own associations" ON user_household_associations; + CREATE POLICY "Users can manage own associations" ON user_household_associations + FOR ALL TO anon, authenticated USING (true) WITH CHECK (true); 
+ """) + sql = "\n".join(sql_parts) conn = engine.raw_connection() @@ -246,30 +296,53 @@ def apply_rls_policies(engine): def main(): - """Run full Supabase initialisation.""" - console.print( - Panel.fit( - "[bold red]⚠ WARNING: This will DELETE ALL DATA[/bold red]\n" - "This script resets the entire Supabase instance.", - title="Supabase init", + """Run Supabase initialisation.""" + reset_mode = "--reset" in sys.argv + + if reset_mode: + console.print( + Panel.fit( + "[bold red]⚠ WARNING: This will DELETE ALL DATA[/bold red]\n" + "This script will reset the entire Supabase instance.", + title="Supabase RESET", + ) ) - ) - # Confirm unless running non-interactively - if sys.stdin.isatty(): - response = console.input("\nType 'yes' to continue: ") - if response.lower() != "yes": - console.print("[yellow]Aborted[/yellow]") - return + # Confirm unless running non-interactively + if sys.stdin.isatty(): + response = console.input("\nType 'yes' to continue: ") + if response.lower() != "yes": + console.print("[yellow]Aborted[/yellow]") + return + + console.print() + + # Reset storage bucket + reset_storage_bucket() + console.print() + + # Drop database schema + engine = reset_database() + console.print() + else: + console.print( + Panel.fit( + "[bold blue]Initialising Supabase[/bold blue]\n" + "This will create tables if they don't exist (safe/idempotent).\n" + "Use [cyan]--reset[/cyan] flag to drop and recreate everything.", + title="Supabase init", + ) + ) + console.print() - console.print() + # Ensure storage bucket exists + ensure_storage_bucket() + console.print() - # Reset storage bucket - reset_storage_bucket() - console.print() + engine = create_engine(settings.database_url, echo=False) - # Reset database and create tables - engine = reset_database() + # Run Alembic migrations to create/update tables + run_alembic_migrations() console.print() # Apply storage policies diff --git a/scripts/seed.py b/scripts/seed.py index f3fbfa8..4274528 100644 --- a/scripts/seed.py 
+++ b/scripts/seed.py @@ -363,7 +363,7 @@ def seed_model(model_version, session, lite: bool = False) -> TaxBenefitModelVer return db_version -def seed_datasets(session, lite: bool = False): +def seed_datasets(session, lite: bool = False, skip_uk_datasets: bool = False): """Seed datasets and upload to S3.""" with logfire.span("seed_datasets"): mode_str = " (lite mode - 2026 only)" if lite else "" @@ -383,60 +383,64 @@ def seed_datasets(session, lite: bool = False): ) return - # UK datasets - console.print(" Creating UK datasets...") data_folder = str(Path(__file__).parent.parent / "data") - uk_datasets = ensure_uk_datasets(data_folder=data_folder) - - # In lite mode, only upload FRS 2026 - if lite: - uk_datasets = { - k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k - } - console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") + # UK datasets uk_created = 0 uk_skipped = 0 - with logfire.span("seed_uk_datasets", count=len(uk_datasets)): - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - console=console, - ) as progress: - task = progress.add_task("UK datasets", total=len(uk_datasets)) - for _, pe_dataset in uk_datasets.items(): - progress.update(task, description=f"UK: {pe_dataset.name}") - - # Check if dataset already exists - existing = session.exec( - select(Dataset).where(Dataset.name == pe_dataset.name) - ).first() - - if existing: - uk_skipped += 1 + if skip_uk_datasets: + console.print(" [yellow]Skipping UK datasets (--skip-uk-datasets)[/yellow]") + else: + console.print(" Creating UK datasets...") + uk_datasets = ensure_uk_datasets(data_folder=data_folder) + + # In lite mode, only upload FRS 2026 + if lite: + uk_datasets = { + k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k + } + console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") + + with logfire.span("seed_uk_datasets", count=len(uk_datasets)): + with Progress( + SpinnerColumn(), + 
TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("UK datasets", total=len(uk_datasets)) + for _, pe_dataset in uk_datasets.items(): + progress.update(task, description=f"UK: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + uk_skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=uk_model.id, + ) + session.add(db_dataset) + session.commit() + uk_created += 1 progress.advance(task) - continue - - # Upload to S3 - object_name = upload_dataset_for_seeding(pe_dataset.filepath) - - # Create database record - db_dataset = Dataset( - name=pe_dataset.name, - description=pe_dataset.description, - filepath=object_name, - year=pe_dataset.year, - tax_benefit_model_id=uk_model.id, - ) - session.add(db_dataset) - session.commit() - uk_created += 1 - progress.advance(task) - console.print( - f" [green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped" - ) + console.print( + f" [green]✓[/green] UK: {uk_created} created, {uk_skipped} skipped" + ) # US datasets console.print(" Creating US datasets...") @@ -622,6 +626,11 @@ def main(): action="store_true", help="Lite mode: skip US state parameters, only seed FRS 2026 and CPS 2026 datasets", ) + parser.add_argument( + "--skip-uk-datasets", + action="store_true", + help="Skip UK datasets (useful when HuggingFace token is not available)", + ) args = parser.parse_args() with logfire.span("database_seeding"): @@ -638,7 +647,7 @@ def main(): console.print(f"[green]✓[/green] US model seeded: {us_version.id}\n") # Seed datasets - seed_datasets(session, lite=args.lite) + 
seed_datasets(session, lite=args.lite, skip_uk_datasets=args.skip_uk_datasets) # Seed example policies seed_example_policies(session) diff --git a/scripts/seed_common.py b/scripts/seed_common.py new file mode 100644 index 0000000..49797cb --- /dev/null +++ b/scripts/seed_common.py @@ -0,0 +1,370 @@ +"""Shared utilities for seed scripts.""" + +import io +import json +import logging +import math +import sys +import warnings +from datetime import datetime, timezone +from pathlib import Path +from uuid import uuid4 + +import logfire +from rich.console import Console +from rich.progress import Progress, SpinnerColumn, TextColumn +from sqlmodel import Session, create_engine + +# Disable all SQLAlchemy and database logging BEFORE any imports +logging.basicConfig(level=logging.ERROR) +logging.getLogger("sqlalchemy").setLevel(logging.ERROR) +warnings.filterwarnings("ignore") + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from policyengine_api.config.settings import settings # noqa: E402 + +# Configure logfire +if settings.logfire_token: + logfire.configure( + token=settings.logfire_token, + environment=settings.logfire_environment, + console=False, + ) + +console = Console() + + +def get_session(): + """Get database session with logging disabled.""" + engine = create_engine(settings.database_url, echo=False) + return Session(engine) + + +def bulk_insert(session, table: str, columns: list[str], rows: list[dict]): + """Fast bulk insert using PostgreSQL COPY via StringIO.""" + if not rows: + return + + # Get raw psycopg2 connection + connection = session.connection() + raw_conn = connection.connection.dbapi_connection + cursor = raw_conn.cursor() + + # Build CSV-like data in memory + output = io.StringIO() + for row in rows: + values = [] + for col in columns: + val = row[col] + if val is None: + values.append("\\N") + elif isinstance(val, str): + # Escape special characters for COPY + val = ( + val.replace("\\", "\\\\").replace("\t", 
"\\t").replace("\n", "\\n") + ) + values.append(val) + else: + values.append(str(val)) + output.write("\t".join(values) + "\n") + + output.seek(0) + + # COPY is the fastest way to bulk load PostgreSQL + cursor.copy_from(output, table, columns=columns, null="\\N") + session.commit() + + +def seed_model(model_version, session, lite: bool = False): + """Seed a tax-benefit model with its variables and parameters. + + Args: + model_version: The policyengine package model version + session: Database session + lite: If True, skip state-level parameters + + Returns the TaxBenefitModelVersion that was created or found. + """ + from policyengine_api.models import ( + TaxBenefitModel, + TaxBenefitModelVersion, + ) + from sqlmodel import select + + with logfire.span( + "seed_model", + model=model_version.model.id, + version=model_version.version, + ): + # Create or get the model + console.print(f"[bold blue]Seeding {model_version.model.id}...") + + existing_model = session.exec( + select(TaxBenefitModel).where( + TaxBenefitModel.name == model_version.model.id + ) + ).first() + + if existing_model: + db_model = existing_model + console.print(f" Using existing model: {db_model.id}") + else: + db_model = TaxBenefitModel( + name=model_version.model.id, + description=model_version.model.description, + ) + session.add(db_model) + session.commit() + session.refresh(db_model) + console.print(f" Created model: {db_model.id}") + + # Create model version + existing_version = session.exec( + select(TaxBenefitModelVersion).where( + TaxBenefitModelVersion.model_id == db_model.id, + TaxBenefitModelVersion.version == model_version.version, + ) + ).first() + + if existing_version: + console.print( + f" Model version {model_version.version} already exists, skipping" + ) + return existing_version + + db_version = TaxBenefitModelVersion( + model_id=db_model.id, + version=model_version.version, + description=f"Version {model_version.version}", + ) + session.add(db_version) + session.commit() + 
session.refresh(db_version) + console.print(f" Created version: {db_version.version}") + + # Add variables + with logfire.span("add_variables", count=len(model_version.variables)): + var_rows = [] + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(model_version.variables)} variables", + total=len(model_version.variables), + ) + for var in model_version.variables: + # default_value is pre-serialized by policyengine.py: + # - Enum values are converted to their name (e.g., "SINGLE") + # - datetime.date values are converted to ISO format + # - Primitives (bool, int, float, str) are kept as-is + var_rows.append( + { + "id": uuid4(), + "name": var.name, + "entity": var.entity, + "description": var.description or "", + "data_type": var.data_type.__name__ + if hasattr(var.data_type, "__name__") + else str(var.data_type), + "possible_values": None, + "default_value": json.dumps(var.default_value), + "tax_benefit_model_version_id": db_version.id, + "created_at": datetime.now(timezone.utc), + } + ) + progress.advance(task) + + console.print(f" Inserting {len(var_rows)} variables...") + bulk_insert( + session, + "variables", + [ + "id", + "name", + "entity", + "description", + "data_type", + "possible_values", + "default_value", + "tax_benefit_model_version_id", + "created_at", + ], + var_rows, + ) + + console.print( + f" [green]✓[/green] Added {len(model_version.variables)} variables" + ) + + # Add parameters - deduplicate by name (keep first occurrence) + # + # WHY DEDUPLICATION IS NEEDED: + # The policyengine package can provide multiple parameter entries with the same + # name. This happens because parameters can have multiple bracket entries or + # state-specific variants that share the same base name. We keep only the first + # occurrence to avoid database unique constraint violations and reduce redundancy. + # + # NOTE: We do NOT filter by label. 
Parameters without labels (bracket params, + # breakdown params) are still valid and needed for policy analysis. + # + # In lite mode, exclude US state parameters (gov.states.*) + seen_names = set() + parameters_to_add = [] + skipped_state_params = 0 + skipped_duplicate = 0 + + for p in model_version.parameters: + if p.name in seen_names: + skipped_duplicate += 1 + continue + # In lite mode, skip state-level parameters for faster seeding + if lite and p.name.startswith("gov.states."): + skipped_state_params += 1 + continue + parameters_to_add.append(p) + seen_names.add(p.name) + + console.print(f" Parameter filtering:") + console.print(f" - Total from source: {len(model_version.parameters)}") + console.print(f" - Skipped (duplicate name): {skipped_duplicate}") + if lite and skipped_state_params > 0: + console.print(f" - Skipped (state params, lite mode): {skipped_state_params}") + console.print(f" - To add: {len(parameters_to_add)}") + + with logfire.span("add_parameters", count=len(parameters_to_add)): + # Build list of parameter dicts for bulk insert + param_rows = [] + param_names = [] # Track (pe_id, name, generated_uuid) + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(parameters_to_add)} parameters", + total=len(parameters_to_add), + ) + for param in parameters_to_add: + param_uuid = uuid4() + param_rows.append( + { + "id": param_uuid, + "name": param.name, + "label": param.label if hasattr(param, "label") else None, + "description": param.description or "", + "data_type": param.data_type.__name__ + if hasattr(param.data_type, "__name__") + else str(param.data_type), + "unit": param.unit, + "tax_benefit_model_version_id": db_version.id, + "created_at": datetime.now(timezone.utc), + } + ) + param_names.append((param.id, param.name, param_uuid)) + progress.advance(task) + + console.print(f" Inserting {len(param_rows)} parameters...") 
+ bulk_insert( + session, + "parameters", + [ + "id", + "name", + "label", + "description", + "data_type", + "unit", + "tax_benefit_model_version_id", + "created_at", + ], + param_rows, + ) + + # Build param_id_map from pre-generated UUIDs + param_id_map = {pe_id: db_uuid for pe_id, name, db_uuid in param_names} + + console.print( + f" [green]✓[/green] Added {len(parameters_to_add)} parameters" + ) + + # Add parameter values + # Filter to only include values for parameters we added + parameter_values_to_add = [ + pv + for pv in model_version.parameter_values + if pv.parameter.id in param_id_map + ] + console.print(f" Found {len(parameter_values_to_add)} parameter values to add") + + with logfire.span("add_parameter_values", count=len(parameter_values_to_add)): + pv_rows = [] + skipped = 0 + + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task( + f"Preparing {len(parameter_values_to_add)} parameter values", + total=len(parameter_values_to_add), + ) + for pv in parameter_values_to_add: + # Handle Infinity values - skip them as they can't be stored in JSON + if isinstance(pv.value, float) and ( + math.isinf(pv.value) or math.isnan(pv.value) + ): + skipped += 1 + progress.advance(task) + continue + + # Source data has dates swapped (start > end), fix ordering + # Only swap if both dates are set, otherwise keep original + if pv.start_date and pv.end_date: + start = pv.end_date # Swap: source end is our start + end = pv.start_date # Swap: source start is our end + else: + start = pv.start_date + end = pv.end_date + pv_rows.append( + { + "id": uuid4(), + "parameter_id": param_id_map[pv.parameter.id], + "value_json": json.dumps(pv.value), + "start_date": start, + "end_date": end, + "policy_id": None, + "dynamic_id": None, + "created_at": datetime.now(timezone.utc), + } + ) + progress.advance(task) + + console.print(f" Inserting {len(pv_rows)} parameter values...") + 
"""Seed example policy reforms for UK and US."""

import time
from datetime import datetime, timezone

import logfire
from sqlmodel import select

from seed_common import console, get_session


def _seed_single_parameter_policy(
    session,
    *,
    policy_name,
    description,
    parameter_name,
    version_id,
    value,
    country,
    missing_warning,
):
    """Create one single-parameter reform policy, if absent.

    Skips creation when a policy with the same name already exists, or when
    the target parameter cannot be found for the given model version.
    """
    # Import here to avoid slow import at module level
    from policyengine_api.models import Parameter, ParameterValue, Policy

    existing = session.exec(
        select(Policy).where(Policy.name == policy_name)
    ).first()
    if existing:
        console.print(f"  Policy '{policy_name}' already exists, skipping")
        return

    param = session.exec(
        select(Parameter).where(
            Parameter.name == parameter_name,
            Parameter.tax_benefit_model_version_id == version_id,
        )
    ).first()
    if not param:
        console.print(f"  [yellow]Warning: {missing_warning}[/yellow]")
        return

    policy = Policy(name=policy_name, description=description)
    session.add(policy)
    session.commit()
    session.refresh(policy)

    # The reform applies from 2024 onward (open-ended end date).
    session.add(
        ParameterValue(
            parameter_id=param.id,
            value_json={"value": value},
            start_date=datetime(2024, 1, 1, tzinfo=timezone.utc),
            end_date=None,
            policy_id=policy.id,
        )
    )
    session.commit()
    console.print(f"  [green]✓[/green] Created {country} policy: {policy_name}")


def main():
    """Seed the example UK and US reform policies."""
    from policyengine_api.models import TaxBenefitModel, TaxBenefitModelVersion

    console.print("[bold green]Seeding example policies...[/bold green]\n")

    start = time.time()
    with get_session() as session:
        with logfire.span("seed_example_policies"):
            uk_model = session.exec(
                select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk")
            ).first()
            us_model = session.exec(
                select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us")
            ).first()

            if not uk_model or not us_model:
                console.print(
                    "[red]Error: UK or US model not found. Run seed_*_model.py first.[/red]"
                )
                return

            def latest_version(model_id):
                # Parameters are scoped per model version; use the newest.
                return session.exec(
                    select(TaxBenefitModelVersion)
                    .where(TaxBenefitModelVersion.model_id == model_id)
                    .order_by(TaxBenefitModelVersion.created_at.desc())
                ).first()

            uk_version = latest_version(uk_model.id)
            us_version = latest_version(us_model.id)

            # FIX: previously a model row with no versions crashed below with
            # AttributeError on `.id`; fail with a clear message instead.
            if not uk_version or not us_version:
                console.print(
                    "[red]Error: UK or US model version not found. "
                    "Run seed_*_model.py first.[/red]"
                )
                return

            # UK example policy: raise basic rate to 22p
            _seed_single_parameter_policy(
                session,
                policy_name="UK basic rate 22p",
                description="Raise the UK income tax basic rate from 20p to 22p",
                parameter_name="gov.hmrc.income_tax.rates.uk[0].rate",
                version_id=uk_version.id,
                value=0.22,
                country="UK",
                missing_warning="UK basic rate parameter not found",
            )

            # US example policy: raise first bracket rate to 12%
            _seed_single_parameter_policy(
                session,
                policy_name="US 12% lowest bracket",
                description="Raise US federal income tax lowest bracket to 12%",
                parameter_name="gov.irs.income.bracket.rates.1",
                version_id=us_version.id,
                value=0.12,
                country="US",
                missing_warning="US first bracket parameter not found",
            )

            console.print("[green]✓[/green] Example policies seeded")

    elapsed = time.time() - start
    console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]")


if __name__ == "__main__":
    main()
+""" + +import argparse +import time +from pathlib import Path + +import logfire +from rich.progress import Progress, SpinnerColumn, TextColumn +from sqlmodel import select + +from seed_common import console, get_session + + +def main(): + parser = argparse.ArgumentParser(description="Seed UK datasets") + parser.add_argument( + "--lite", + action="store_true", + help="Lite mode: only seed FRS 2026", + ) + args = parser.parse_args() + + # Import here to avoid slow import at module level + from policyengine.tax_benefit_models.uk.datasets import ( + ensure_datasets as ensure_uk_datasets, + ) + + from policyengine_api.models import Dataset, TaxBenefitModel + from policyengine_api.services.storage import upload_dataset_for_seeding + + console.print("[bold green]Seeding UK datasets...[/bold green]\n") + console.print("[yellow]Note: Requires HUGGING_FACE_TOKEN environment variable[/yellow]\n") + + start = time.time() + with get_session() as session: + # Get UK model + uk_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-uk") + ).first() + + if not uk_model: + console.print("[red]Error: UK model not found. 
Run seed_uk_model.py first.[/red]") + return + + data_folder = str(Path(__file__).parent.parent / "data") + console.print(f" Data folder: {data_folder}") + + # Get datasets + console.print(" Loading UK datasets from policyengine package...") + ds_start = time.time() + uk_datasets = ensure_uk_datasets(data_folder=data_folder) + console.print(f" Loaded {len(uk_datasets)} datasets in {time.time() - ds_start:.1f}s") + + # In lite mode, only upload FRS 2026 + if args.lite: + uk_datasets = { + k: v for k, v in uk_datasets.items() if v.year == 2026 and "frs" in k + } + console.print(f" Lite mode: filtered to {len(uk_datasets)} dataset(s)") + + created = 0 + skipped = 0 + + with logfire.span("seed_uk_datasets", count=len(uk_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("UK datasets", total=len(uk_datasets)) + for name, pe_dataset in uk_datasets.items(): + progress.update(task, description=f"UK: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + upload_start = time.time() + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + console.print(f" Uploaded {pe_dataset.name} in {time.time() - upload_start:.1f}s") + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=uk_model.id, + ) + session.add(db_dataset) + session.commit() + created += 1 + progress.advance(task) + + console.print(f"[green]✓[/green] UK datasets: {created} created, {skipped} skipped") + + elapsed = time.time() - start + console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]") + + +if __name__ == "__main__": + main() diff --git a/scripts/seed_uk_model.py 
"""Seed UK model (variables, parameters, parameter values)."""

import argparse
import time

from seed_common import console, get_session, seed_model


def main():
    """Seed the UK tax-benefit model into the database."""
    arg_parser = argparse.ArgumentParser(description="Seed UK model")
    arg_parser.add_argument(
        "--lite",
        action="store_true",
        help="Lite mode: skip state parameters",
    )
    options = arg_parser.parse_args()

    # Deferred import: pulling in the UK model package is slow.
    from policyengine.tax_benefit_models.uk import uk_latest

    console.print("[bold green]Seeding UK model...[/bold green]\n")

    started_at = time.time()
    with get_session() as session:
        model_version = seed_model(uk_latest, session, lite=options.lite)
        console.print(f"[green]✓[/green] UK model seeded: {model_version.id}")

    elapsed = time.time() - started_at
    console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]")


if __name__ == "__main__":
    main()
green]Seeding US datasets...[/bold green]\n") + + start = time.time() + with get_session() as session: + # Get US model + us_model = session.exec( + select(TaxBenefitModel).where(TaxBenefitModel.name == "policyengine-us") + ).first() + + if not us_model: + console.print("[red]Error: US model not found. Run seed_us_model.py first.[/red]") + return + + data_folder = str(Path(__file__).parent.parent / "data") + console.print(f" Data folder: {data_folder}") + + # Get datasets + console.print(" Loading US datasets from policyengine package...") + ds_start = time.time() + us_datasets = ensure_us_datasets(data_folder=data_folder) + console.print(f" Loaded {len(us_datasets)} datasets in {time.time() - ds_start:.1f}s") + + # In lite mode, only upload CPS 2026 + if args.lite: + us_datasets = { + k: v for k, v in us_datasets.items() if v.year == 2026 and "cps" in k + } + console.print(f" Lite mode: filtered to {len(us_datasets)} dataset(s)") + + created = 0 + skipped = 0 + + with logfire.span("seed_us_datasets", count=len(us_datasets)): + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + console=console, + ) as progress: + task = progress.add_task("US datasets", total=len(us_datasets)) + for name, pe_dataset in us_datasets.items(): + progress.update(task, description=f"US: {pe_dataset.name}") + + # Check if dataset already exists + existing = session.exec( + select(Dataset).where(Dataset.name == pe_dataset.name) + ).first() + + if existing: + skipped += 1 + progress.advance(task) + continue + + # Upload to S3 + upload_start = time.time() + object_name = upload_dataset_for_seeding(pe_dataset.filepath) + console.print(f" Uploaded {pe_dataset.name} in {time.time() - upload_start:.1f}s") + + # Create database record + db_dataset = Dataset( + name=pe_dataset.name, + description=pe_dataset.description, + filepath=object_name, + year=pe_dataset.year, + tax_benefit_model_id=us_model.id, + ) + session.add(db_dataset) + session.commit() + 
"""Seed US model (variables, parameters, parameter values)."""

import argparse
import time

from seed_common import console, get_session, seed_model


def main():
    """Seed the US tax-benefit model into the database."""
    arg_parser = argparse.ArgumentParser(description="Seed US model")
    arg_parser.add_argument(
        "--lite",
        action="store_true",
        help="Lite mode: skip state parameters",
    )
    options = arg_parser.parse_args()

    # Deferred import: pulling in the US model package is slow.
    from policyengine.tax_benefit_models.us import us_latest

    console.print("[bold green]Seeding US model...[/bold green]\n")

    started_at = time.time()
    with get_session() as session:
        model_version = seed_model(us_latest, session, lite=options.lite)
        console.print(f"[green]✓[/green] US model seeded: {model_version.id}")

    elapsed = time.time() - started_at
    console.print(f"\n[bold]Total time: {elapsed:.1f}s[/bold]")


if __name__ == "__main__":
    main()
+api_router.include_router(households.router) api_router.include_router(analysis.router) api_router.include_router(agent.router) +api_router.include_router(user_household_associations.router) __all__ = ["api_router"] diff --git a/src/policyengine_api/api/analysis.py b/src/policyengine_api/api/analysis.py index c9aa86d..10e6fc5 100644 --- a/src/policyengine_api/api/analysis.py +++ b/src/policyengine_api/api/analysis.py @@ -35,6 +35,7 @@ ReportStatus, Simulation, SimulationStatus, + SimulationType, TaxBenefitModel, TaxBenefitModelVersion, ) @@ -138,19 +139,24 @@ def _get_model_version( def _get_deterministic_simulation_id( - dataset_id: UUID, + simulation_type: SimulationType, model_version_id: UUID, policy_id: UUID | None, dynamic_id: UUID | None, + dataset_id: UUID | None = None, + household_id: UUID | None = None, ) -> UUID: """Generate a deterministic UUID from simulation parameters.""" - key = f"{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}" + if simulation_type == SimulationType.ECONOMY: + key = f"economy:{dataset_id}:{model_version_id}:{policy_id}:{dynamic_id}" + else: + key = f"household:{household_id}:{model_version_id}:{policy_id}:{dynamic_id}" return uuid5(SIMULATION_NAMESPACE, key) def _get_deterministic_report_id( baseline_sim_id: UUID, - reform_sim_id: UUID, + reform_sim_id: UUID | None, ) -> UUID: """Generate a deterministic UUID from report parameters.""" key = f"{baseline_sim_id}:{reform_sim_id}" @@ -158,15 +164,22 @@ def _get_deterministic_report_id( def _get_or_create_simulation( - dataset_id: UUID, + simulation_type: SimulationType, model_version_id: UUID, policy_id: UUID | None, dynamic_id: UUID | None, session: Session, + dataset_id: UUID | None = None, + household_id: UUID | None = None, ) -> Simulation: """Get existing simulation or create a new one.""" sim_id = _get_deterministic_simulation_id( - dataset_id, model_version_id, policy_id, dynamic_id + simulation_type, + model_version_id, + policy_id, + dynamic_id, + 
dataset_id=dataset_id, + household_id=household_id, ) existing = session.get(Simulation, sim_id) @@ -175,7 +188,9 @@ def _get_or_create_simulation( simulation = Simulation( id=sim_id, + simulation_type=simulation_type, dataset_id=dataset_id, + household_id=household_id, tax_benefit_model_version_id=model_version_id, policy_id=policy_id, dynamic_id=dynamic_id, @@ -189,8 +204,9 @@ def _get_or_create_simulation( def _get_or_create_report( baseline_sim_id: UUID, - reform_sim_id: UUID, + reform_sim_id: UUID | None, label: str, + report_type: str, session: Session, ) -> Report: """Get existing report or create a new one.""" @@ -203,6 +219,7 @@ def _get_or_create_report( report = Report( id=report_id, label=label, + report_type=report_type, baseline_simulation_id=baseline_sim_id, reform_simulation_id=reform_sim_id, status=ReportStatus.PENDING, @@ -580,19 +597,21 @@ def economic_impact( # Get or create simulations baseline_sim = _get_or_create_simulation( - dataset_id=request.dataset_id, + simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, policy_id=None, dynamic_id=request.dynamic_id, session=session, + dataset_id=request.dataset_id, ) reform_sim = _get_or_create_simulation( - dataset_id=request.dataset_id, + simulation_type=SimulationType.ECONOMY, model_version_id=model_version.id, policy_id=request.policy_id, dynamic_id=request.dynamic_id, session=session, + dataset_id=request.dataset_id, ) # Get or create report @@ -600,7 +619,9 @@ def economic_impact( if request.policy_id: label += f" (policy {request.policy_id})" - report = _get_or_create_report(baseline_sim.id, reform_sim.id, label, session) + report = _get_or_create_report( + baseline_sim.id, reform_sim.id, label, "economy_comparison", session + ) # Trigger computation if report is pending if report.status == ReportStatus.PENDING: diff --git a/src/policyengine_api/api/household_analysis.py b/src/policyengine_api/api/household_analysis.py new file mode 100644 index 0000000..d321be4 --- 
"""Household impact analysis endpoints.

Use these endpoints to analyze household-level effects of policy reforms.
Supports single runs (current law) and comparisons (baseline vs reform).

WORKFLOW:
1. Create a stored household: POST /households
2. Optionally create a reform policy: POST /policies
3. Run analysis: POST /analysis/household-impact (returns report_id)
4. Poll GET /analysis/household-impact/{report_id} until status="completed"
5. Results include baseline_result, reform_result (if comparison), and impact diff
"""

from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Protocol
from uuid import UUID

import logfire
from fastapi import APIRouter, Depends, HTTPException
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import BaseModel, Field
from sqlmodel import Session

from policyengine_api.models import (
    Household,
    Policy,
    Report,
    ReportStatus,
    Simulation,
    SimulationStatus,
    SimulationType,
)
from policyengine_api.services.database import get_session

from .analysis import (
    _get_model_version,
    _get_or_create_report,
    _get_or_create_simulation,
)


def get_traceparent() -> str | None:
    """Return the current W3C ``traceparent`` header, if a trace is active.

    Used to propagate the caller's trace context into background jobs so
    spans line up across services; None when no trace context exists.
    """
    headers: dict[str, str] = {}
    TraceContextTextMapPropagator().inject(headers)
    return headers.get("traceparent")


router = APIRouter(prefix="/analysis", tags=["analysis"])


# =============================================================================
# Country Strategy Pattern
# =============================================================================


@dataclass(frozen=True)
class CountryConfig:
    """Configuration for a country's household calculation."""

    # Short country code, e.g. "uk" or "us".
    name: str
    # Entity group names present in this country's household results.
    entity_types: tuple[str, ...]
UK_CONFIG = CountryConfig(
    name="uk",
    entity_types=("person", "benunit", "household"),
)

US_CONFIG = CountryConfig(
    name="us",
    entity_types=(
        "person",
        "tax_unit",
        "spm_unit",
        "family",
        "marital_unit",
        "household",
    ),
)


def _is_uk_model(tax_benefit_model_name: str) -> bool:
    """True when the model name refers to the UK model.

    FIX: the model name appears both as "policyengine_uk" (the original
    check here) and "policyengine-uk" (the seeded TaxBenefitModel rows use
    a hyphen), so normalise the separator before comparing. Anything that
    is not the UK model is treated as US, matching the original fallback.
    """
    return tax_benefit_model_name.replace("-", "_") == "policyengine_uk"


def get_country_config(tax_benefit_model_name: str) -> CountryConfig:
    """Get country configuration from model name (defaults to US)."""
    return UK_CONFIG if _is_uk_model(tax_benefit_model_name) else US_CONFIG


class HouseholdCalculator(Protocol):
    """Protocol for country-specific household calculators."""

    def __call__(
        self,
        household_data: dict[str, Any],
        year: int,
        policy_data: dict | None,
    ) -> dict: ...


def calculate_uk_household(
    household_data: dict[str, Any],
    year: int,
    policy_data: dict | None,
) -> dict:
    """Calculate UK household using the existing implementation."""
    from policyengine_api.api.household import _calculate_household_uk

    return _calculate_household_uk(
        people=household_data.get("people", []),
        benunit=_ensure_list(household_data.get("benunit")),
        household=_ensure_list(household_data.get("household")),
        year=year,
        policy_data=policy_data,
    )


def calculate_us_household(
    household_data: dict[str, Any],
    year: int,
    policy_data: dict | None,
) -> dict:
    """Calculate US household using the existing implementation."""
    from policyengine_api.api.household import _calculate_household_us

    return _calculate_household_us(
        people=household_data.get("people", []),
        marital_unit=_ensure_list(household_data.get("marital_unit")),
        family=_ensure_list(household_data.get("family")),
        spm_unit=_ensure_list(household_data.get("spm_unit")),
        tax_unit=_ensure_list(household_data.get("tax_unit")),
        household=_ensure_list(household_data.get("household")),
        year=year,
        policy_data=policy_data,
    )


def get_calculator(tax_benefit_model_name: str) -> HouseholdCalculator:
    """Get the appropriate calculator for a country (defaults to US)."""
    if _is_uk_model(tax_benefit_model_name):
        return calculate_uk_household
    return calculate_us_household


# =============================================================================
# Data Transformation Helpers
# =============================================================================


def _ensure_list(value: Any) -> list:
    """Ensure value is a list: None becomes [], a single item is wrapped."""
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def _extract_policy_data(policy: "Policy | None") -> dict | None:
    """Extract policy data from a Policy model into calculation format.

    Returns format expected by _calculate_household_us/_calculate_household_uk:
    {
        "name": "policy name",
        "description": "policy description",
        "parameter_values": [
            {
                "parameter_name": "gov.irs.credits.ctc...",
                "value": 0.16,
                "start_date": "2024-01-01T00:00:00+00:00",
                "end_date": null
            }
        ]
    }

    Returns None when the policy is missing or carries no usable values.
    """
    if not policy or not policy.parameter_values:
        return None

    parameter_values = []
    for pv in policy.parameter_values:
        # Orphaned values (no parameter row) cannot be applied; skip them.
        if not pv.parameter:
            continue

        parameter_values.append({
            "parameter_name": pv.parameter.name,
            "value": _extract_value(pv.value_json),
            "start_date": _format_date(pv.start_date),
            "end_date": _format_date(pv.end_date),
        })

    if not parameter_values:
        return None

    return {
        "name": policy.name,
        "description": policy.description or "",
        "parameter_values": parameter_values,
    }


def _extract_value(value_json: Any) -> Any:
    """Unwrap {"value": x} containers; pass anything else through as-is."""
    if isinstance(value_json, dict):
        return value_json.get("value")
    return value_json


def _format_date(date: Any) -> str | None:
    """Render a date as ISO-8601 text for the policy data structure."""
    if date is None:
        return None
    if hasattr(date, "isoformat"):
        return date.isoformat()
    return str(date)
============================================================================= + + +def compute_variable_diff(baseline_val: Any, reform_val: Any) -> dict | None: + """Compute diff for a single variable if both are numeric.""" + if not isinstance(baseline_val, (int, float)): + return None + if not isinstance(reform_val, (int, float)): + return None + + return { + "baseline": baseline_val, + "reform": reform_val, + "change": reform_val - baseline_val, + } + + +def compute_entity_diff(baseline_entity: dict, reform_entity: dict) -> dict: + """Compute per-variable diffs for a single entity instance.""" + entity_diff = {} + + for key, baseline_val in baseline_entity.items(): + reform_val = reform_entity.get(key) + if reform_val is None: + continue + + diff = compute_variable_diff(baseline_val, reform_val) + if diff is not None: + entity_diff[key] = diff + + return entity_diff + + +def compute_entity_list_diff( + baseline_list: list[dict], + reform_list: list[dict], +) -> list[dict]: + """Compute diffs for a list of entity instances.""" + return [ + compute_entity_diff(b_entity, r_entity) + for b_entity, r_entity in zip(baseline_list, reform_list) + ] + + +def compute_household_impact( + baseline_result: dict, + reform_result: dict, + config: CountryConfig, +) -> dict[str, Any]: + """Compute difference between baseline and reform for all entity types.""" + impact: dict[str, Any] = {} + + for entity in config.entity_types: + baseline_entities = baseline_result.get(entity) + reform_entities = reform_result.get(entity) + + if baseline_entities is None or reform_entities is None: + continue + + impact[entity] = compute_entity_list_diff(baseline_entities, reform_entities) + + return impact + + +# ============================================================================= +# Simulation Execution +# ============================================================================= + + +def mark_simulation_running(simulation: Simulation, session: Session) -> None: + """Mark a 
simulation as running.""" + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + +def mark_simulation_completed( + simulation: Simulation, + result: dict, + session: Session, +) -> None: + """Mark a simulation as completed with result.""" + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + +def mark_simulation_failed( + simulation: Simulation, + error: Exception, + session: Session, +) -> None: + """Mark a simulation as failed with error.""" + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(error) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + +def run_household_simulation(simulation_id: UUID, session: Session) -> None: + """Run a single household simulation and store result.""" + simulation = _load_simulation(simulation_id, session) + household = _load_household(simulation.household_id, session) + policy_data = _load_policy_data(simulation.policy_id, session) + + mark_simulation_running(simulation, session) + + try: + calculator = get_calculator(household.tax_benefit_model_name) + result = calculator(household.household_data, household.year, policy_data) + mark_simulation_completed(simulation, result, session) + except Exception as e: + mark_simulation_failed(simulation, e, session) + + +def _load_simulation(simulation_id: UUID, session: Session) -> Simulation: + """Load simulation or raise error.""" + simulation = session.get(Simulation, simulation_id) + if not simulation: + raise ValueError(f"Simulation {simulation_id} not found") + return simulation + + +def _load_household(household_id: UUID | None, session: Session) -> Household: + """Load household or raise error.""" + if not household_id: + raise ValueError("Simulation has no household_id") + 
+ household = session.get(Household, household_id) + if not household: + raise ValueError(f"Household {household_id} not found") + return household + + +def _load_policy_data(policy_id: UUID | None, session: Session) -> dict | None: + """Load and extract policy data if policy_id is set.""" + if not policy_id: + return None + + policy = session.get(Policy, policy_id) + return _extract_policy_data(policy) + + +# ============================================================================= +# Report Orchestration (Async) +# ============================================================================= + + +def _run_local_household_impact(report_id: str, session: Session) -> None: + """Run household impact analysis locally. + + NOTE: This runs synchronously and blocks the HTTP request when running + locally (agent_use_modal=False). This mirrors the economic impact behavior. + True async execution requires Modal. + """ + report = session.get(Report, UUID(report_id)) + if not report: + return + + report.status = ReportStatus.RUNNING + session.add(report) + session.commit() + + try: + # Run baseline simulation + if report.baseline_simulation_id: + _run_simulation_in_session(report.baseline_simulation_id, session) + + # Run reform simulation if present + if report.reform_simulation_id: + _run_simulation_in_session(report.reform_simulation_id, session) + + report.status = ReportStatus.COMPLETED + session.add(report) + session.commit() + except Exception as e: + report.status = ReportStatus.FAILED + report.error_message = str(e) + session.add(report) + session.commit() + + +def _run_simulation_in_session(simulation_id: UUID, session: Session) -> None: + """Run a single household simulation within an existing session.""" + simulation = session.get(Simulation, simulation_id) + if not simulation or simulation.status != SimulationStatus.PENDING: + return + + household = session.get(Household, simulation.household_id) + if not household: + raise ValueError(f"Household 
{simulation.household_id} not found") + + policy_data = _load_policy_data(simulation.policy_id, session) + + simulation.status = SimulationStatus.RUNNING + simulation.started_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + + try: + calculator = get_calculator(household.tax_benefit_model_name) + result = calculator(household.household_data, household.year, policy_data) + + simulation.household_result = result + simulation.status = SimulationStatus.COMPLETED + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + except Exception as e: + simulation.status = SimulationStatus.FAILED + simulation.error_message = str(e) + simulation.completed_at = datetime.now(timezone.utc) + session.add(simulation) + session.commit() + raise + + +def _trigger_household_impact( + report_id: str, tax_benefit_model_name: str, session: Session | None = None +) -> None: + """Trigger household impact calculation (local or Modal based on settings).""" + from policyengine_api.config import settings + + traceparent = get_traceparent() + + if not settings.agent_use_modal and session is not None: + # Run locally (blocking - see _run_local_household_impact docstring) + _run_local_household_impact(report_id, session) + else: + # Use Modal + import modal + + if tax_benefit_model_name == "policyengine_uk": + fn = modal.Function.from_name("policyengine", "household_impact_uk") + else: + fn = modal.Function.from_name("policyengine", "household_impact_us") + + fn.spawn(report_id=report_id, traceparent=traceparent) + + +# Legacy functions kept for compatibility +def _load_report(report_id: UUID, session: Session) -> Report: + """Load report or raise error.""" + report = session.get(Report, report_id) + if not report: + raise ValueError(f"Report {report_id} not found") + return report + + +# ============================================================================= +# Request/Response Schemas +# 
# =============================================================================
# Request/Response Schemas
# =============================================================================


class HouseholdImpactRequest(BaseModel):
    """Request payload for household impact analysis."""

    household_id: UUID = Field(description="ID of the household to analyze")
    policy_id: UUID | None = Field(
        default=None,
        description="Reform policy ID. If None, runs single calculation under current law.",
    )
    dynamic_id: UUID | None = Field(
        default=None,
        description="Optional behavioural response specification ID",
    )


class HouseholdSimulationInfo(BaseModel):
    """Status summary for one household simulation."""

    id: UUID
    status: SimulationStatus
    error_message: str | None = None


class HouseholdImpactResponse(BaseModel):
    """Response payload for household impact analysis."""

    report_id: UUID
    report_type: str
    status: ReportStatus
    baseline_simulation: HouseholdSimulationInfo | None = None
    reform_simulation: HouseholdSimulationInfo | None = None
    baseline_result: dict | None = None
    reform_result: dict | None = None
    impact: dict | None = None
    error_message: str | None = None


# =============================================================================
# Response Building
# =============================================================================


def build_simulation_info(
    simulation: Simulation | None,
) -> HouseholdSimulationInfo | None:
    """Summarise a simulation's id, status and error, or None if absent."""
    if simulation is None:
        return None

    return HouseholdSimulationInfo(
        id=simulation.id,
        status=simulation.status,
        error_message=simulation.error_message,
    )


def build_household_response(
    report: Report,
    baseline_sim: Simulation,
    reform_sim: Simulation | None,
    session: Session,
) -> HouseholdImpactResponse:
    """Assemble the full response, computing the impact diff for comparisons."""
    base_payload = baseline_sim.household_result
    reform_payload = reform_sim.household_result if reform_sim else None

    return HouseholdImpactResponse(
        report_id=report.id,
        report_type=report.report_type or "household_single",
        status=report.status,
        baseline_simulation=build_simulation_info(baseline_sim),
        reform_simulation=build_simulation_info(reform_sim),
        baseline_result=base_payload,
        reform_result=reform_payload,
        impact=_compute_impact_if_comparison(
            baseline_sim, reform_sim, base_payload, reform_payload, session
        ),
        error_message=report.error_message,
    )


def _compute_impact_if_comparison(
    baseline_sim: Simulation,
    reform_sim: Simulation | None,
    baseline_result: dict | None,
    reform_result: dict | None,
    session: Session,
) -> dict | None:
    """Compute the impact only for comparisons where both results exist."""
    if reform_sim is None or not baseline_result or not reform_result:
        return None

    household = session.get(Household, baseline_sim.household_id)
    if household is None:
        return None

    country = get_country_config(household.tax_benefit_model_name)
    return compute_household_impact(baseline_result, reform_result, country)


# =============================================================================
# Validation Helpers
# =============================================================================


def validate_household_exists(household_id: UUID, session: Session) -> Household:
    """Return the household, or raise a 404 if it does not exist."""
    household = session.get(Household, household_id)
    if household is None:
        raise HTTPException(
            status_code=404,
            detail=f"Household {household_id} not found",
        )
    return household


def validate_policy_exists(policy_id: UUID | None, session: Session) -> None:
    """Raise a 404 when a policy_id is supplied but no such policy exists."""
    if policy_id is None:
        return

    if session.get(Policy, policy_id) is None:
        raise HTTPException(
            status_code=404,
            detail=f"Policy {policy_id} not found",
        )
============================================================================= + + +@router.post("/household-impact", response_model=HouseholdImpactResponse) +def household_impact( + request: HouseholdImpactRequest, + session: Session = Depends(get_session), +) -> HouseholdImpactResponse: + """Run household impact analysis. + + If policy_id is None: single run under current law. + If policy_id is set: comparison (baseline vs reform). + + This is an async operation. The endpoint returns immediately with a report_id + and status="pending". Poll GET /analysis/household-impact/{report_id} until + status="completed" to get results. + """ + household = validate_household_exists(request.household_id, session) + validate_policy_exists(request.policy_id, session) + + model_version = _get_model_version(household.tax_benefit_model_name, session) + + baseline_sim = _create_baseline_simulation( + household, model_version.id, request.dynamic_id, session + ) + reform_sim = _create_reform_simulation( + household, model_version.id, request.policy_id, request.dynamic_id, session + ) + + report_type = "household_comparison" if request.policy_id else "household_single" + report = _get_or_create_report( + baseline_sim_id=baseline_sim.id, + reform_sim_id=reform_sim.id if reform_sim else None, + label=f"Household impact: {household.tax_benefit_model_name}", + report_type=report_type, + session=session, + ) + + if report.status == ReportStatus.PENDING: + with logfire.span("trigger_household_impact", job_id=str(report.id)): + _trigger_household_impact( + str(report.id), household.tax_benefit_model_name, session + ) + + return build_household_response(report, baseline_sim, reform_sim, session) + + +@router.get("/household-impact/{report_id}", response_model=HouseholdImpactResponse) +def get_household_impact( + report_id: UUID, + session: Session = Depends(get_session), +) -> HouseholdImpactResponse: + """Get household impact analysis status and results.""" + report = session.get(Report, 
report_id) + if not report: + raise HTTPException(status_code=404, detail=f"Report {report_id} not found") + + if not report.baseline_simulation_id: + raise HTTPException( + status_code=500, + detail="Report missing baseline simulation ID", + ) + + baseline_sim = session.get(Simulation, report.baseline_simulation_id) + if not baseline_sim: + raise HTTPException(status_code=500, detail="Baseline simulation data missing") + + reform_sim = None + if report.reform_simulation_id: + reform_sim = session.get(Simulation, report.reform_simulation_id) + + return build_household_response(report, baseline_sim, reform_sim, session) + + +# ============================================================================= +# Simulation Creation Helpers +# ============================================================================= + + +def _create_baseline_simulation( + household: Household, + model_version_id: UUID, + dynamic_id: UUID | None, + session: Session, +) -> Simulation: + """Create baseline simulation (current law, no policy).""" + return _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version_id, + policy_id=None, + dynamic_id=dynamic_id, + session=session, + household_id=household.id, + ) + + +def _create_reform_simulation( + household: Household, + model_version_id: UUID, + policy_id: UUID | None, + dynamic_id: UUID | None, + session: Session, +) -> Simulation | None: + """Create reform simulation if policy_id is provided.""" + if not policy_id: + return None + + return _get_or_create_simulation( + simulation_type=SimulationType.HOUSEHOLD, + model_version_id=model_version_id, + policy_id=policy_id, + dynamic_id=dynamic_id, + session=session, + household_id=household.id, + ) diff --git a/src/policyengine_api/api/households.py b/src/policyengine_api/api/households.py new file mode 100644 index 0000000..fdee1f7 --- /dev/null +++ b/src/policyengine_api/api/households.py @@ -0,0 +1,119 @@ +"""Stored household CRUD endpoints. 
+ +Households represent saved household definitions that can be reused across +calculations and impact analyses. Create a household once, then reference +it by ID for repeated simulations. + +These endpoints manage stored household *definitions* (people, entity groups, +model name, year). For running calculations on a household, use the +/household/calculate and /household/impact endpoints instead. +""" + +from typing import Any +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.models import Household, HouseholdCreate, HouseholdRead +from policyengine_api.services.database import get_session + +router = APIRouter(prefix="/households", tags=["households"]) + +_ENTITY_GROUP_KEYS = ( + "tax_unit", + "family", + "spm_unit", + "marital_unit", + "household", + "benunit", +) + + +def _pack_household_data(body: HouseholdCreate) -> dict[str, Any]: + """Pack the flat request fields into a single JSON blob for storage.""" + data: dict[str, Any] = {"people": body.people} + for key in _ENTITY_GROUP_KEYS: + val = getattr(body, key) + if val is not None: + data[key] = val + return data + + +def _to_read(record: Household) -> HouseholdRead: + """Unpack the JSON blob back into the flat response shape.""" + data = record.household_data + return HouseholdRead( + id=record.id, + tax_benefit_model_name=record.tax_benefit_model_name, + year=record.year, + label=record.label, + people=data["people"], + tax_unit=data.get("tax_unit"), + family=data.get("family"), + spm_unit=data.get("spm_unit"), + marital_unit=data.get("marital_unit"), + household=data.get("household"), + benunit=data.get("benunit"), + created_at=record.created_at, + updated_at=record.updated_at, + ) + + +@router.post("/", response_model=HouseholdRead, status_code=201) +def create_household(body: HouseholdCreate, session: Session = Depends(get_session)): + """Create a stored household definition. 
+ + The household data (people + entity groups) is persisted so it can be + retrieved later by ID. Use the returned ID with /household/calculate + or /household/impact to run simulations. + """ + record = Household( + tax_benefit_model_name=body.tax_benefit_model_name, + year=body.year, + label=body.label, + household_data=_pack_household_data(body), + ) + session.add(record) + session.commit() + session.refresh(record) + return _to_read(record) + + +@router.get("/", response_model=list[HouseholdRead]) +def list_households( + tax_benefit_model_name: str | None = None, + limit: int = Query(default=50, le=200), + offset: int = Query(default=0, ge=0), + session: Session = Depends(get_session), +): + """List stored households with optional filtering.""" + query = select(Household) + if tax_benefit_model_name is not None: + query = query.where(Household.tax_benefit_model_name == tax_benefit_model_name) + query = query.offset(offset).limit(limit) + records = session.exec(query).all() + return [_to_read(r) for r in records] + + +@router.get("/{household_id}", response_model=HouseholdRead) +def get_household(household_id: UUID, session: Session = Depends(get_session)): + """Get a stored household by ID.""" + record = session.get(Household, household_id) + if not record: + raise HTTPException( + status_code=404, detail=f"Household {household_id} not found" + ) + return _to_read(record) + + +@router.delete("/{household_id}", status_code=204) +def delete_household(household_id: UUID, session: Session = Depends(get_session)): + """Delete a stored household.""" + record = session.get(Household, household_id) + if not record: + raise HTTPException( + status_code=404, detail=f"Household {household_id} not found" + ) + session.delete(record) + session.commit() diff --git a/src/policyengine_api/api/user_household_associations.py b/src/policyengine_api/api/user_household_associations.py new file mode 100644 index 0000000..fa40e06 --- /dev/null +++ 
b/src/policyengine_api/api/user_household_associations.py @@ -0,0 +1,125 @@ +"""User-household association endpoints. + +Associations link a user to a stored household definition with metadata +(label, country). A user can have multiple associations to the same +household (e.g. different labels or configurations). +""" + +from datetime import datetime, timezone +from uuid import UUID + +from fastapi import APIRouter, Depends, HTTPException, Query +from sqlmodel import Session, select + +from policyengine_api.models import ( + Household, + UserHouseholdAssociation, + UserHouseholdAssociationCreate, + UserHouseholdAssociationRead, + UserHouseholdAssociationUpdate, +) +from policyengine_api.services.database import get_session + +router = APIRouter( + prefix="/user-household-associations", + tags=["user-household-associations"], +) + + +@router.post("/", response_model=UserHouseholdAssociationRead, status_code=201) +def create_association( + body: UserHouseholdAssociationCreate, + session: Session = Depends(get_session), +): + """Create a user-household association.""" + household = session.get(Household, body.household_id) + if not household: + raise HTTPException( + status_code=404, + detail=f"Household {body.household_id} not found", + ) + + record = UserHouseholdAssociation( + user_id=body.user_id, + household_id=body.household_id, + country_id=body.country_id, + label=body.label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.get("/user/{user_id}", response_model=list[UserHouseholdAssociationRead]) +def list_by_user( + user_id: UUID, + country_id: str | None = None, + limit: int = Query(default=50, le=200), + offset: int = Query(default=0, ge=0), + session: Session = Depends(get_session), +): + """List all associations for a user, optionally filtered by country.""" + query = select(UserHouseholdAssociation).where( + UserHouseholdAssociation.user_id == user_id + ) + if country_id is not None: + query = 
query.where(UserHouseholdAssociation.country_id == country_id) + query = query.offset(offset).limit(limit) + return session.exec(query).all() + + +@router.get( + "/{user_id}/{household_id}", + response_model=list[UserHouseholdAssociationRead], +) +def list_by_user_and_household( + user_id: UUID, + household_id: UUID, + session: Session = Depends(get_session), +): + """List all associations for a specific user+household pair.""" + query = select(UserHouseholdAssociation).where( + UserHouseholdAssociation.user_id == user_id, + UserHouseholdAssociation.household_id == household_id, + ) + return session.exec(query).all() + + +@router.put("/{association_id}", response_model=UserHouseholdAssociationRead) +def update_association( + association_id: UUID, + body: UserHouseholdAssociationUpdate, + session: Session = Depends(get_session), +): + """Update a user-household association (label).""" + record = session.get(UserHouseholdAssociation, association_id) + if not record: + raise HTTPException( + status_code=404, + detail=f"Association {association_id} not found", + ) + update_data = body.model_dump(exclude_unset=True) + for key, value in update_data.items(): + setattr(record, key, value) + record.updated_at = datetime.now(timezone.utc) + session.add(record) + session.commit() + session.refresh(record) + return record + + +@router.delete("/{association_id}", status_code=204) +def delete_association( + association_id: UUID, + session: Session = Depends(get_session), +): + """Delete a user-household association.""" + record = session.get(UserHouseholdAssociation, association_id) + if not record: + raise HTTPException( + status_code=404, + detail=f"Association {association_id} not found", + ) + session.delete(record) + session.commit() diff --git a/src/policyengine_api/config/settings.py b/src/policyengine_api/config/settings.py index 76a1ab1..efba345 100644 --- a/src/policyengine_api/config/settings.py +++ b/src/policyengine_api/config/settings.py @@ -40,10 +40,21 @@ class 
Settings(BaseSettings): @property def database_url(self) -> str: - """Get database URL from Supabase.""" + """Get database URL from Supabase. + + For local development, the database runs on port 54322 (not 54321 which is the API). + Use supabase_db_url to override, or rely on the default local URL. + """ + if self.supabase_db_url: + return self.supabase_db_url + + # For local development, default to the standard Supabase local DB port + if "localhost" in self.supabase_url or "127.0.0.1" in self.supabase_url: + return "postgresql://postgres:postgres@127.0.0.1:54322/postgres" + + # For remote Supabase, construct URL from API URL (usually need supabase_db_url set) return ( - self.supabase_db_url - or self.supabase_url.replace( + self.supabase_url.replace( "http://", "postgresql://postgres:postgres@" ).replace("https://", "postgresql://postgres:postgres@") + "/postgres" diff --git a/src/policyengine_api/models/__init__.py b/src/policyengine_api/models/__init__.py index 4d64c02..c49b457 100644 --- a/src/policyengine_api/models/__init__.py +++ b/src/policyengine_api/models/__init__.py @@ -11,6 +11,7 @@ from .dataset_version import DatasetVersion, DatasetVersionCreate, DatasetVersionRead from .decile_impact import DecileImpact, DecileImpactCreate, DecileImpactRead from .dynamic import Dynamic, DynamicCreate, DynamicRead +from .household import Household, HouseholdCreate, HouseholdRead from .household_job import ( HouseholdJob, HouseholdJobCreate, @@ -35,7 +36,13 @@ ProgramStatisticsRead, ) from .report import Report, ReportCreate, ReportRead, ReportStatus -from .simulation import Simulation, SimulationCreate, SimulationRead, SimulationStatus +from .simulation import ( + Simulation, + SimulationCreate, + SimulationRead, + SimulationStatus, + SimulationType, +) from .tax_benefit_model import ( TaxBenefitModel, TaxBenefitModelCreate, @@ -47,6 +54,12 @@ TaxBenefitModelVersionRead, ) from .user import User, UserCreate, UserRead +from .user_household_association import ( + 
UserHouseholdAssociation, + UserHouseholdAssociationCreate, + UserHouseholdAssociationRead, + UserHouseholdAssociationUpdate, +) from .variable import Variable, VariableCreate, VariableRead __all__ = [ @@ -72,6 +85,9 @@ "Dynamic", "DynamicCreate", "DynamicRead", + "Household", + "HouseholdCreate", + "HouseholdRead", "HouseholdJob", "HouseholdJobCreate", "HouseholdJobRead", @@ -102,6 +118,7 @@ "SimulationCreate", "SimulationRead", "SimulationStatus", + "SimulationType", "TaxBenefitModel", "TaxBenefitModelCreate", "TaxBenefitModelRead", @@ -110,6 +127,10 @@ "TaxBenefitModelVersionRead", "User", "UserCreate", + "UserHouseholdAssociation", + "UserHouseholdAssociationCreate", + "UserHouseholdAssociationRead", + "UserHouseholdAssociationUpdate", "UserRead", "Variable", "VariableCreate", diff --git a/src/policyengine_api/models/household.py b/src/policyengine_api/models/household.py new file mode 100644 index 0000000..8a96850 --- /dev/null +++ b/src/policyengine_api/models/household.py @@ -0,0 +1,54 @@ +"""Stored household definition model.""" + +from datetime import datetime, timezone +from typing import Any, Literal +from uuid import UUID, uuid4 + +from sqlalchemy import JSON +from sqlmodel import Column, Field, SQLModel + + +class HouseholdBase(SQLModel): + """Base household fields.""" + + tax_benefit_model_name: str + year: int + label: str | None = None + household_data: dict[str, Any] = Field(sa_column=Column(JSON, nullable=False)) + + +class Household(HouseholdBase, table=True): + """Stored household database model.""" + + __tablename__ = "households" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class HouseholdCreate(SQLModel): + """Schema for creating a stored household. 
+ + Accepts the flat structure matching the frontend Household interface: + people as an array, entity groups as optional dicts. + """ + + tax_benefit_model_name: Literal["policyengine_us", "policyengine_uk"] + year: int + label: str | None = None + people: list[dict[str, Any]] + tax_unit: dict[str, Any] | None = None + family: dict[str, Any] | None = None + spm_unit: dict[str, Any] | None = None + marital_unit: dict[str, Any] | None = None + household: dict[str, Any] | None = None + benunit: dict[str, Any] | None = None + + +class HouseholdRead(HouseholdCreate): + """Schema for reading a stored household.""" + + id: UUID + created_at: datetime + updated_at: datetime diff --git a/src/policyengine_api/models/report.py b/src/policyengine_api/models/report.py index ee1b678..bc2cd40 100644 --- a/src/policyengine_api/models/report.py +++ b/src/policyengine_api/models/report.py @@ -19,6 +19,7 @@ class ReportBase(SQLModel): label: str description: str | None = None + report_type: str | None = None user_id: UUID | None = Field(default=None, foreign_key="users.id") markdown: str | None = Field(default=None, sa_column=Column(Text)) parent_report_id: UUID | None = Field(default=None, foreign_key="reports.id") diff --git a/src/policyengine_api/models/simulation.py b/src/policyengine_api/models/simulation.py index b23141e..985db3e 100644 --- a/src/policyengine_api/models/simulation.py +++ b/src/policyengine_api/models/simulation.py @@ -1,13 +1,16 @@ from datetime import datetime, timezone from enum import Enum -from typing import TYPE_CHECKING -from uuid import UUID, uuid4 +from typing import TYPE_CHECKING, Any +from sqlalchemy import Column +from sqlalchemy.dialects.postgresql import JSON from sqlmodel import Field, Relationship, SQLModel +from uuid import UUID, uuid4 if TYPE_CHECKING: from .dataset import Dataset from .dynamic import Dynamic + from .household import Household from .policy import Policy from .tax_benefit_model_version import TaxBenefitModelVersion @@ -21,10 
+24,19 @@ class SimulationStatus(str, Enum): FAILED = "failed" +class SimulationType(str, Enum): + """Type of simulation.""" + + HOUSEHOLD = "household" + ECONOMY = "economy" + + class SimulationBase(SQLModel): """Base simulation fields.""" - dataset_id: UUID = Field(foreign_key="datasets.id") + simulation_type: SimulationType = SimulationType.ECONOMY + dataset_id: UUID | None = Field(default=None, foreign_key="datasets.id") + household_id: UUID | None = Field(default=None, foreign_key="households.id") policy_id: UUID | None = Field(default=None, foreign_key="policies.id") dynamic_id: UUID | None = Field(default=None, foreign_key="dynamics.id") tax_benefit_model_version_id: UUID = Field( @@ -45,6 +57,9 @@ class Simulation(SimulationBase, table=True): updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) started_at: datetime | None = None completed_at: datetime | None = None + household_result: dict[str, Any] | None = Field( + default=None, sa_column=Column(JSON) + ) # Relationships dataset: "Dataset" = Relationship( @@ -53,6 +68,12 @@ class Simulation(SimulationBase, table=True): "primaryjoin": "Simulation.dataset_id==Dataset.id", } ) + household: "Household" = Relationship( + sa_relationship_kwargs={ + "foreign_keys": "[Simulation.household_id]", + "primaryjoin": "Simulation.household_id==Household.id", + } + ) policy: "Policy" = Relationship() dynamic: "Dynamic" = Relationship() tax_benefit_model_version: "TaxBenefitModelVersion" = Relationship() @@ -78,3 +99,4 @@ class SimulationRead(SimulationBase): updated_at: datetime started_at: datetime | None completed_at: datetime | None + household_result: dict[str, Any] | None = None diff --git a/src/policyengine_api/models/user_household_association.py b/src/policyengine_api/models/user_household_association.py new file mode 100644 index 0000000..9a961cc --- /dev/null +++ b/src/policyengine_api/models/user_household_association.py @@ -0,0 +1,49 @@ +"""User-household association model.""" + 
+from datetime import datetime, timezone +from uuid import UUID, uuid4 + +from sqlmodel import Field, SQLModel + + +class UserHouseholdAssociationBase(SQLModel): + """Base association fields.""" + + # user_id is a client-generated UUID stored in localStorage, not a foreign key + user_id: UUID = Field(index=True) + household_id: UUID = Field(foreign_key="households.id", index=True) + country_id: str + label: str | None = None + + +class UserHouseholdAssociation(UserHouseholdAssociationBase, table=True): + """User-household association database model.""" + + __tablename__ = "user_household_associations" + + id: UUID = Field(default_factory=uuid4, primary_key=True) + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + + +class UserHouseholdAssociationCreate(SQLModel): + """Schema for creating a user-household association.""" + + user_id: UUID + household_id: UUID + country_id: str + label: str | None = None + + +class UserHouseholdAssociationUpdate(SQLModel): + """Schema for updating a user-household association.""" + + label: str | None = None + + +class UserHouseholdAssociationRead(UserHouseholdAssociationBase): + """Schema for reading a user-household association.""" + + id: UUID + created_at: datetime + updated_at: datetime diff --git a/src/policyengine_api/models/variable.py b/src/policyengine_api/models/variable.py index f163577..eeebddc 100644 --- a/src/policyengine_api/models/variable.py +++ b/src/policyengine_api/models/variable.py @@ -1,5 +1,5 @@ from datetime import datetime, timezone -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from uuid import UUID, uuid4 from sqlmodel import JSON, Column, Field, Relationship, SQLModel @@ -18,6 +18,9 @@ class VariableBase(SQLModel): possible_values: str | None = Field( default=None, sa_column=Column(JSON) ) # Store as JSON list + default_value: Any = Field( + default=None, 
sa_column=Column(JSON) + ) # Store as JSON (handles int, float, bool, str, etc.) tax_benefit_model_version_id: UUID = Field( foreign_key="tax_benefit_model_versions.id" ) diff --git a/supabase/.temp/cli-latest b/supabase/.temp/cli-latest index 8c68db7..1dd6178 100644 --- a/supabase/.temp/cli-latest +++ b/supabase/.temp/cli-latest @@ -1 +1 @@ -v2.67.1 \ No newline at end of file +v2.75.0 \ No newline at end of file diff --git a/supabase/migrations/20251229000000_add_parameter_values_indexes.sql b/supabase/migrations/20251229000000_add_parameter_values_indexes.sql deleted file mode 100644 index c1713d5..0000000 --- a/supabase/migrations/20251229000000_add_parameter_values_indexes.sql +++ /dev/null @@ -1,16 +0,0 @@ --- Add indexes to parameter_values table for query optimization --- This migration improves query performance for filtering by parameter_id and policy_id - --- Composite index for the most common query pattern (filtering by both) -CREATE INDEX IF NOT EXISTS idx_parameter_values_parameter_policy -ON parameter_values(parameter_id, policy_id); - --- Single index on policy_id for filtering by policy alone -CREATE INDEX IF NOT EXISTS idx_parameter_values_policy -ON parameter_values(policy_id); - --- Partial index for baseline values (policy_id IS NULL) --- This optimizes the common "get current law values" query -CREATE INDEX IF NOT EXISTS idx_parameter_values_baseline -ON parameter_values(parameter_id) -WHERE policy_id IS NULL; diff --git a/supabase/migrations/20260103000000_add_poverty_inequality.sql b/supabase/migrations/20260103000000_add_poverty_inequality.sql deleted file mode 100644 index f315d93..0000000 --- a/supabase/migrations/20260103000000_add_poverty_inequality.sql +++ /dev/null @@ -1,33 +0,0 @@ --- Add poverty and inequality tables for economic analysis - -CREATE TABLE IF NOT EXISTS poverty ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE CASCADE, - report_id UUID REFERENCES 
reports(id) ON DELETE CASCADE, - poverty_type VARCHAR NOT NULL, - entity VARCHAR NOT NULL DEFAULT 'person', - filter_variable VARCHAR, - headcount FLOAT, - total_population FLOAT, - rate FLOAT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - -CREATE TABLE IF NOT EXISTS inequality ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - simulation_id UUID NOT NULL REFERENCES simulations(id) ON DELETE CASCADE, - report_id UUID REFERENCES reports(id) ON DELETE CASCADE, - income_variable VARCHAR NOT NULL, - entity VARCHAR NOT NULL DEFAULT 'household', - gini FLOAT, - top_10_share FLOAT, - top_1_share FLOAT, - bottom_50_share FLOAT, - created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() -); - --- Indexes for efficient querying -CREATE INDEX IF NOT EXISTS idx_poverty_simulation_id ON poverty(simulation_id); -CREATE INDEX IF NOT EXISTS idx_poverty_report_id ON poverty(report_id); -CREATE INDEX IF NOT EXISTS idx_inequality_simulation_id ON inequality(simulation_id); -CREATE INDEX IF NOT EXISTS idx_inequality_report_id ON inequality(report_id); diff --git a/supabase/migrations/20260111000000_add_aggregate_status.sql b/supabase/migrations/20260111000000_add_aggregate_status.sql deleted file mode 100644 index b190620..0000000 --- a/supabase/migrations/20260111000000_add_aggregate_status.sql +++ /dev/null @@ -1,13 +0,0 @@ --- Add status and error_message columns to aggregates table -ALTER TABLE aggregates -ADD COLUMN IF NOT EXISTS status VARCHAR(20) DEFAULT 'pending', -ADD COLUMN IF NOT EXISTS error_message TEXT; - --- Add status and error_message columns to change_aggregates table -ALTER TABLE change_aggregates -ADD COLUMN IF NOT EXISTS status VARCHAR(20) DEFAULT 'pending', -ADD COLUMN IF NOT EXISTS error_message TEXT; - --- Create indexes for status filtering -CREATE INDEX IF NOT EXISTS idx_aggregates_status ON aggregates(status); -CREATE INDEX IF NOT EXISTS idx_change_aggregates_status ON change_aggregates(status); diff --git a/test_fixtures/fixtures_household_analysis.py 
b/test_fixtures/fixtures_household_analysis.py new file mode 100644 index 0000000..573930a --- /dev/null +++ b/test_fixtures/fixtures_household_analysis.py @@ -0,0 +1,366 @@ +"""Fixtures and helpers for household analysis endpoint tests.""" + +from typing import Any +from unittest.mock import patch +from uuid import UUID + +import pytest +from sqlmodel import Session + +from policyengine_api.models import ( + Household, + Parameter, + ParameterValue, + Policy, + TaxBenefitModel, + TaxBenefitModelVersion, +) + + +# ============================================================================= +# Sample Calculation Results +# ============================================================================= + + +SAMPLE_UK_BASELINE_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 35000.0, + "income_tax": 4500.0, + "national_insurance": 2800.0, + "net_income": 27700.0, + } + ], + "benunit": [ + { + "universal_credit": 0.0, + "child_benefit": 0.0, + } + ], + "household": [ + { + "region": "LONDON", + "council_tax": 1500.0, + } + ], +} + + +SAMPLE_UK_REFORM_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 35000.0, + "income_tax": 4000.0, + "national_insurance": 2800.0, + "net_income": 28200.0, + } + ], + "benunit": [ + { + "universal_credit": 0.0, + "child_benefit": 0.0, + } + ], + "household": [ + { + "region": "LONDON", + "council_tax": 1500.0, + } + ], +} + + +SAMPLE_US_BASELINE_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 50000.0, + "income_tax": 6000.0, + "fica": 3825.0, + "net_income": 40175.0, + } + ], + "tax_unit": [ + { + "state_code": "CA", + "state_income_tax": 2500.0, + } + ], + "spm_unit": [{"snap": 0.0}], + "family": [{}], + "marital_unit": [{}], + "household": [{"state_fips": 6}], +} + + +SAMPLE_US_REFORM_RESULT: dict[str, Any] = { + "person": [ + { + "age": 30, + "employment_income": 50000.0, + "income_tax": 5500.0, + "fica": 3825.0, + "net_income": 40675.0, + } 
+ ], + "tax_unit": [ + { + "state_code": "CA", + "state_income_tax": 2500.0, + } + ], + "spm_unit": [{"snap": 0.0}], + "family": [{}], + "marital_unit": [{}], + "household": [{"state_fips": 6}], +} + + +# ============================================================================= +# Mock Calculator Functions +# ============================================================================= + + +def mock_calculate_uk_household( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock UK calculator that returns sample results.""" + if policy_data: + return SAMPLE_UK_REFORM_RESULT + return SAMPLE_UK_BASELINE_RESULT + + +def mock_calculate_us_household( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock US calculator that returns sample results.""" + if policy_data: + return SAMPLE_US_REFORM_RESULT + return SAMPLE_US_BASELINE_RESULT + + +def mock_calculate_household_failing( + household_data: dict[str, Any], + year: int, + policy_data: dict | None, +) -> dict: + """Mock calculator that raises an exception.""" + raise RuntimeError("Calculation failed") + + +# ============================================================================= +# Pytest Fixtures for Mocking +# ============================================================================= + + +@pytest.fixture +def mock_uk_calculator(): + """Fixture that patches UK calculator with mock.""" + with patch( + "policyengine_api.api.household_analysis.calculate_uk_household", + side_effect=mock_calculate_uk_household, + ) as mock: + yield mock + + +@pytest.fixture +def mock_us_calculator(): + """Fixture that patches US calculator with mock.""" + with patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_us_household, + ) as mock: + yield mock + + +@pytest.fixture +def mock_calculators(): + """Fixture that patches both UK and US calculators.""" + with ( + patch( + 
"policyengine_api.api.household_analysis.calculate_uk_household", + side_effect=mock_calculate_uk_household, + ) as uk_mock, + patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_us_household, + ) as us_mock, + ): + yield {"uk": uk_mock, "us": us_mock} + + +@pytest.fixture +def mock_failing_calculator(): + """Fixture that patches calculators to fail.""" + with ( + patch( + "policyengine_api.api.household_analysis.calculate_uk_household", + side_effect=mock_calculate_household_failing, + ), + patch( + "policyengine_api.api.household_analysis.calculate_us_household", + side_effect=mock_calculate_household_failing, + ), + ): + yield + + +# ============================================================================= +# Database Factory Functions +# ============================================================================= + + +def create_tax_benefit_model( + session: Session, + name: str = "policyengine-uk", + description: str = "UK tax benefit model", +) -> TaxBenefitModel: + """Create and persist a TaxBenefitModel record.""" + model = TaxBenefitModel( + name=name, + description=description, + ) + session.add(model) + session.commit() + session.refresh(model) + return model + + +def create_model_version( + session: Session, + model_id: UUID, + version: str = "1.0.0", + description: str = "Test version", +) -> TaxBenefitModelVersion: + """Create and persist a TaxBenefitModelVersion record.""" + model_version = TaxBenefitModelVersion( + model_id=model_id, + version=version, + description=description, + ) + session.add(model_version) + session.commit() + session.refresh(model_version) + return model_version + + +def create_parameter( + session: Session, + model_version_id: UUID, + name: str = "test_parameter", + label: str = "Test Parameter", + description: str = "A test parameter", +) -> Parameter: + """Create and persist a Parameter record.""" + param = Parameter( + tax_benefit_model_version_id=model_version_id, + 
name=name, + label=label, + description=description, + ) + session.add(param) + session.commit() + session.refresh(param) + return param + + +def create_policy( + session: Session, + model_version_id: UUID, + name: str = "Test Policy", + description: str = "A test policy", +) -> Policy: + """Create and persist a Policy record.""" + policy = Policy( + tax_benefit_model_version_id=model_version_id, + name=name, + description=description, + ) + session.add(policy) + session.commit() + session.refresh(policy) + return policy + + +def create_policy_with_parameter_value( + session: Session, + model_version_id: UUID, + parameter_id: UUID, + value: float, + name: str = "Test Policy", +) -> Policy: + """Create a Policy with an associated ParameterValue.""" + policy = create_policy(session, model_version_id, name=name) + + param_value = ParameterValue( + policy_id=policy.id, + parameter_id=parameter_id, + value_json={"value": value}, + ) + session.add(param_value) + session.commit() + session.refresh(policy) + return policy + + +def create_household_for_analysis( + session: Session, + tax_benefit_model_name: str = "policyengine_uk", + year: int = 2024, + label: str = "Test household for analysis", +) -> Household: + """Create a household suitable for analysis testing.""" + if tax_benefit_model_name == "policyengine_uk": + household_data = { + "people": [{"age": 30, "employment_income": 35000}], + "benunit": {}, + "household": {"region": "LONDON"}, + } + else: + household_data = { + "people": [{"age": 30, "employment_income": 50000}], + "tax_unit": {"state_code": "CA"}, + "family": {}, + "spm_unit": {}, + "marital_unit": {}, + "household": {"state_fips": 6}, + } + + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data=household_data, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def setup_uk_model_and_version( + session: Session, +) -> tuple[TaxBenefitModel, 
TaxBenefitModelVersion]: + """Create UK model and version for testing.""" + model = create_tax_benefit_model( + session, name="policyengine-uk", description="UK model" + ) + version = create_model_version(session, model.id) + return model, version + + +def setup_us_model_and_version( + session: Session, +) -> tuple[TaxBenefitModel, TaxBenefitModelVersion]: + """Create US model and version for testing.""" + model = create_tax_benefit_model( + session, name="policyengine-us", description="US model" + ) + version = create_model_version(session, model.id) + return model, version diff --git a/test_fixtures/fixtures_households.py b/test_fixtures/fixtures_households.py new file mode 100644 index 0000000..4e676f4 --- /dev/null +++ b/test_fixtures/fixtures_households.py @@ -0,0 +1,66 @@ +"""Fixtures and helpers for household CRUD tests.""" + +from policyengine_api.models import Household + +# ----------------------------------------------------------------------------- +# Request payloads (match HouseholdCreate schema) +# ----------------------------------------------------------------------------- + +MOCK_US_HOUSEHOLD_CREATE = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "label": "US test household", + "people": [ + {"age": 30, "employment_income": 50000}, + {"age": 28, "employment_income": 30000}, + ], + "tax_unit": {}, + "family": {}, + "household": {"state_name": "CA"}, +} + +MOCK_UK_HOUSEHOLD_CREATE = { + "tax_benefit_model_name": "policyengine_uk", + "year": 2024, + "label": "UK test household", + "people": [ + {"age": 40, "employment_income": 35000}, + ], + "benunit": {"is_married": False}, + "household": {"region": "LONDON"}, +} + +MOCK_HOUSEHOLD_MINIMAL = { + "tax_benefit_model_name": "policyengine_us", + "year": 2024, + "people": [{"age": 25}], +} + + +# ----------------------------------------------------------------------------- +# Factory functions +# ----------------------------------------------------------------------------- + + +def 
create_household( + session, + tax_benefit_model_name: str = "policyengine_us", + year: int = 2024, + label: str | None = "Test household", + people: list | None = None, + **entity_groups, +) -> Household: + """Create and persist a Household record.""" + household_data = {"people": people or [{"age": 30}]} + household_data.update(entity_groups) + + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data=household_data, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/test_fixtures/fixtures_user_household_associations.py b/test_fixtures/fixtures_user_household_associations.py new file mode 100644 index 0000000..66b0835 --- /dev/null +++ b/test_fixtures/fixtures_user_household_associations.py @@ -0,0 +1,62 @@ +"""Fixtures and helpers for user-household association tests.""" + +from uuid import UUID + +from policyengine_api.models import Household, User, UserHouseholdAssociation + +# ----------------------------------------------------------------------------- +# Factory functions +# ----------------------------------------------------------------------------- + + +def create_user( + session, + first_name: str = "Test", + last_name: str = "User", + email: str = "test@example.com", +) -> User: + """Create and persist a User record.""" + record = User(first_name=first_name, last_name=last_name, email=email) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_household( + session, + tax_benefit_model_name: str = "policyengine_us", + year: int = 2024, + label: str | None = "Test household", +) -> Household: + """Create and persist a Household record.""" + record = Household( + tax_benefit_model_name=tax_benefit_model_name, + year=year, + label=label, + household_data={"people": [{"age": 30}]}, + ) + session.add(record) + session.commit() + session.refresh(record) + return record + + +def create_association( + session, 
+ user_id: UUID, + household_id: UUID, + country_id: str = "us", + label: str | None = "My household", +) -> UserHouseholdAssociation: + """Create and persist a UserHouseholdAssociation record.""" + record = UserHouseholdAssociation( + user_id=user_id, + household_id=household_id, + country_id=country_id, + label=label, + ) + session.add(record) + session.commit() + session.refresh(record) + return record diff --git a/tests/test_analysis_household_impact.py b/tests/test_analysis_household_impact.py new file mode 100644 index 0000000..23465c7 --- /dev/null +++ b/tests/test_analysis_household_impact.py @@ -0,0 +1,526 @@ +"""Tests for household impact analysis endpoints.""" + +from datetime import date +from uuid import UUID, uuid4 + +import pytest + +from test_fixtures.fixtures_household_analysis import ( + SAMPLE_UK_BASELINE_RESULT, + SAMPLE_UK_REFORM_RESULT, + SAMPLE_US_BASELINE_RESULT, + SAMPLE_US_REFORM_RESULT, + create_household_for_analysis, + create_policy, + setup_uk_model_and_version, + setup_us_model_and_version, +) +from policyengine_api.api.household_analysis import ( + UK_CONFIG, + US_CONFIG, + _ensure_list, + _extract_value, + _format_date, + compute_entity_diff, + compute_entity_list_diff, + compute_household_impact, + compute_variable_diff, + get_calculator, + get_country_config, +) +from policyengine_api.models import Report, ReportStatus, Simulation, SimulationType + + +# --------------------------------------------------------------------------- +# Unit tests for helper functions +# --------------------------------------------------------------------------- + + +class TestEnsureList: + """Tests for _ensure_list helper.""" + + def test_none_returns_empty_list(self): + assert _ensure_list(None) == [] + + def test_list_returns_same_list(self): + input_list = [1, 2, 3] + assert _ensure_list(input_list) == input_list + + def test_dict_wrapped_in_list(self): + input_dict = {"key": "value"} + result = _ensure_list(input_dict) + assert result == 
[input_dict] + + def test_empty_list_returns_empty_list(self): + assert _ensure_list([]) == [] + + +class TestExtractValue: + """Tests for _extract_value helper.""" + + def test_dict_with_value_key(self): + assert _extract_value({"value": 100}) == 100 + + def test_dict_without_value_key(self): + assert _extract_value({"other": 100}) is None + + def test_non_dict_returns_as_is(self): + assert _extract_value(100) == 100 + assert _extract_value("string") == "string" + assert _extract_value([1, 2]) == [1, 2] + + +class TestFormatDate: + """Tests for _format_date helper.""" + + def test_none_returns_none(self): + assert _format_date(None) is None + + def test_date_object_formatted(self): + d = date(2024, 1, 15) + assert _format_date(d) == "2024-01-15" + + def test_string_returns_string(self): + assert _format_date("2024-01-15") == "2024-01-15" + + +class TestComputeVariableDiff: + """Tests for compute_variable_diff helper.""" + + def test_numeric_values_return_diff(self): + result = compute_variable_diff(100, 150) + assert result == {"baseline": 100, "reform": 150, "change": 50} + + def test_negative_change(self): + result = compute_variable_diff(150, 100) + assert result == {"baseline": 150, "reform": 100, "change": -50} + + def test_float_values(self): + result = compute_variable_diff(100.5, 200.5) + assert result == {"baseline": 100.5, "reform": 200.5, "change": 100.0} + + def test_non_numeric_baseline_returns_none(self): + assert compute_variable_diff("string", 100) is None + + def test_non_numeric_reform_returns_none(self): + assert compute_variable_diff(100, "string") is None + + def test_both_non_numeric_returns_none(self): + assert compute_variable_diff("a", "b") is None + + +class TestComputeEntityDiff: + """Tests for compute_entity_diff helper.""" + + def test_computes_diff_for_numeric_keys(self): + baseline = {"income": 1000, "tax": 200, "name": "John"} + reform = {"income": 1000, "tax": 150, "name": "John"} + result = compute_entity_diff(baseline, reform) + 
+ assert "income" in result + assert result["income"]["change"] == 0 + assert "tax" in result + assert result["tax"]["change"] == -50 + assert "name" not in result + + def test_missing_key_in_reform_skipped(self): + baseline = {"income": 1000, "tax": 200} + reform = {"income": 1000} + result = compute_entity_diff(baseline, reform) + + assert "income" in result + assert "tax" not in result + + def test_empty_entities(self): + assert compute_entity_diff({}, {}) == {} + + +class TestComputeEntityListDiff: + """Tests for compute_entity_list_diff helper.""" + + def test_computes_diff_for_each_pair(self): + baseline_list = [{"income": 100}, {"income": 200}] + reform_list = [{"income": 120}, {"income": 180}] + result = compute_entity_list_diff(baseline_list, reform_list) + + assert len(result) == 2 + assert result[0]["income"]["change"] == 20 + assert result[1]["income"]["change"] == -20 + + def test_empty_lists(self): + assert compute_entity_list_diff([], []) == [] + + +class TestComputeHouseholdImpact: + """Tests for compute_household_impact helper.""" + + def test_uk_household_impact(self): + result = compute_household_impact( + SAMPLE_UK_BASELINE_RESULT, + SAMPLE_UK_REFORM_RESULT, + UK_CONFIG, + ) + + assert "person" in result + assert "benunit" in result + assert "household" in result + + # Check person income_tax changed + person_diff = result["person"][0] + assert "income_tax" in person_diff + assert person_diff["income_tax"]["baseline"] == 4500.0 + assert person_diff["income_tax"]["reform"] == 4000.0 + assert person_diff["income_tax"]["change"] == -500.0 + + def test_us_household_impact(self): + result = compute_household_impact( + SAMPLE_US_BASELINE_RESULT, + SAMPLE_US_REFORM_RESULT, + US_CONFIG, + ) + + assert "person" in result + assert "tax_unit" in result + assert "spm_unit" in result + assert "family" in result + assert "marital_unit" in result + assert "household" in result + + # Check person income_tax changed + person_diff = result["person"][0] + assert 
person_diff["income_tax"]["change"] == -500.0 + + def test_missing_entity_skipped(self): + baseline = {"person": [{"income": 100}]} + reform = {"person": [{"income": 120}]} + result = compute_household_impact(baseline, reform, UK_CONFIG) + + assert "person" in result + assert "benunit" not in result + assert "household" not in result + + +class TestGetCountryConfig: + """Tests for get_country_config helper.""" + + def test_uk_model_returns_uk_config(self): + config = get_country_config("policyengine_uk") + assert config == UK_CONFIG + assert config.name == "uk" + assert "benunit" in config.entity_types + + def test_us_model_returns_us_config(self): + config = get_country_config("policyengine_us") + assert config == US_CONFIG + assert config.name == "us" + assert "tax_unit" in config.entity_types + + def test_unknown_model_defaults_to_us(self): + config = get_country_config("unknown_model") + assert config == US_CONFIG + + +class TestGetCalculator: + """Tests for get_calculator helper.""" + + def test_uk_model_returns_uk_calculator(self): + from policyengine_api.api.household_analysis import calculate_uk_household + + calc = get_calculator("policyengine_uk") + assert calc == calculate_uk_household + + def test_us_model_returns_us_calculator(self): + from policyengine_api.api.household_analysis import calculate_us_household + + calc = get_calculator("policyengine_us") + assert calc == calculate_us_household + + def test_unknown_model_defaults_to_us(self): + from policyengine_api.api.household_analysis import calculate_us_household + + calc = get_calculator("unknown_model") + assert calc == calculate_us_household + + +# --------------------------------------------------------------------------- +# Validation tests (no database required beyond session fixture) +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactValidation: + """Tests for request validation.""" + + def test_missing_household_id(self, client): + 
"""Test that missing household_id returns 422.""" + response = client.post( + "/analysis/household-impact", + json={}, + ) + assert response.status_code == 422 + + def test_invalid_uuid(self, client): + """Test that invalid UUID returns 422.""" + response = client.post( + "/analysis/household-impact", + json={ + "household_id": "not-a-uuid", + }, + ) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# 404 tests +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactNotFound: + """Tests for 404 responses.""" + + def test_household_not_found(self, client, session): + """Test that non-existent household returns 404.""" + # Need model for the model version lookup + setup_uk_model_and_version(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(uuid4()), + }, + ) + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_policy_not_found(self, client, session): + """Test that non-existent policy returns 404.""" + setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(uuid4()), + }, + ) + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + def test_get_report_not_found(self, client): + """Test that GET with non-existent report_id returns 404.""" + response = client.get(f"/analysis/household-impact/{uuid4()}") + assert response.status_code == 404 + + +# --------------------------------------------------------------------------- +# Record creation tests +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactRecordCreation: + """Tests for correct record creation.""" + + def 
test_single_run_creates_one_simulation(self, client, session): + """Single run (no policy_id) creates one simulation.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + }, + ) + # May fail during calculation since policyengine not available, + # but should create records + data = response.json() + assert "report_id" in data + assert data["report_type"] == "household_single" + assert data["baseline_simulation"] is not None + assert data["reform_simulation"] is None + + def test_comparison_creates_two_simulations(self, client, session): + """Comparison (with policy_id) creates two simulations.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + policy = create_policy(session, version.id) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy.id), + }, + ) + data = response.json() + assert "report_id" in data + assert data["report_type"] == "household_comparison" + assert data["baseline_simulation"] is not None + assert data["reform_simulation"] is not None + + def test_simulation_type_is_household(self, client, session): + """Created simulations have simulation_type=HOUSEHOLD.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + }, + ) + data = response.json() + + # Check simulation in database (convert string to UUID for query) + sim_id = UUID(data["baseline_simulation"]["id"]) + sim = session.get(Simulation, sim_id) + assert sim is not None + assert sim.simulation_type == SimulationType.HOUSEHOLD + assert sim.household_id == household.id + assert sim.dataset_id is None + + def test_report_links_simulations(self, 
client, session): + """Report correctly links baseline and reform simulations.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + policy = create_policy(session, version.id) + + response = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy.id), + }, + ) + data = response.json() + + # Check report in database (convert string to UUID for query) + report = session.get(Report, UUID(data["report_id"])) + assert report is not None + assert report.baseline_simulation_id == UUID(data["baseline_simulation"]["id"]) + assert report.reform_simulation_id == UUID(data["reform_simulation"]["id"]) + assert report.report_type == "household_comparison" + + +# --------------------------------------------------------------------------- +# Deduplication tests +# --------------------------------------------------------------------------- + + +class TestHouseholdImpactDeduplication: + """Tests for simulation/report deduplication.""" + + def test_same_request_returns_same_simulation(self, client, session): + """Same household + same parameters returns same simulation ID.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + # First request + response1 = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data1 = response1.json() + + # Second request with same parameters + response2 = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data2 = response2.json() + + # Should return same IDs + assert data1["report_id"] == data2["report_id"] + assert data1["baseline_simulation"]["id"] == data2["baseline_simulation"]["id"] + + def test_different_policy_creates_different_simulation(self, client, session): + """Different policy creates different simulation.""" + _, version = setup_uk_model_and_version(session) + household = 
create_household_for_analysis(session) + policy1 = create_policy(session, version.id, name="Policy 1") + policy2 = create_policy(session, version.id, name="Policy 2") + + # Request with policy1 + response1 = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy1.id), + }, + ) + data1 = response1.json() + + # Request with policy2 + response2 = client.post( + "/analysis/household-impact", + json={ + "household_id": str(household.id), + "policy_id": str(policy2.id), + }, + ) + data2 = response2.json() + + # Reports should be different + assert data1["report_id"] != data2["report_id"] + # Reform simulations should be different + assert ( + data1["reform_simulation"]["id"] != data2["reform_simulation"]["id"] + ) + # Baseline simulations should be the same (same household, no policy) + assert ( + data1["baseline_simulation"]["id"] == data2["baseline_simulation"]["id"] + ) + + +# --------------------------------------------------------------------------- +# GET endpoint tests +# --------------------------------------------------------------------------- + + +class TestGetHouseholdImpact: + """Tests for GET /analysis/household-impact/{report_id}.""" + + def test_get_returns_report_data(self, client, session): + """GET returns report with simulation info.""" + _, version = setup_uk_model_and_version(session) + household = create_household_for_analysis(session) + + # Create report via POST + post_response = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + report_id = post_response.json()["report_id"] + + # GET the report + get_response = client.get(f"/analysis/household-impact/{report_id}") + assert get_response.status_code == 200 + + data = get_response.json() + assert data["report_id"] == report_id + assert data["report_type"] == "household_single" + assert data["baseline_simulation"] is not None + + +# 
--------------------------------------------------------------------------- +# US household tests +# --------------------------------------------------------------------------- + + +class TestUSHouseholdImpact: + """Tests specific to US households.""" + + def test_us_household_creates_simulation(self, client, session): + """US household creates simulation with correct model.""" + _, version = setup_us_model_and_version(session) + household = create_household_for_analysis( + session, tax_benefit_model_name="policyengine_us" + ) + + response = client.post( + "/analysis/household-impact", + json={"household_id": str(household.id)}, + ) + data = response.json() + assert "report_id" in data + assert data["baseline_simulation"] is not None diff --git a/tests/test_household.py b/tests/test_household.py index a7248b3..eab15a5 100644 --- a/tests/test_household.py +++ b/tests/test_household.py @@ -289,5 +289,202 @@ def test_missing_people(self): assert response.status_code == 422 +class TestUSPolicyReform: + """Tests for US household calculations with policy reforms.""" + + def _get_us_model_id(self) -> str: + """Get the US tax benefit model ID.""" + response = client.get("/tax-benefit-models/") + assert response.status_code == 200 + models = response.json() + for model in models: + if "us" in model["name"].lower(): + return model["id"] + raise AssertionError("US model not found") + + def _get_parameter_id(self, model_id: str, param_name: str) -> str: + """Get a parameter ID by name.""" + response = client.get( + f"/parameters/?tax_benefit_model_id={model_id}&limit=10000" + ) + assert response.status_code == 200 + params = response.json() + for param in params: + if param["name"] == param_name: + return param["id"] + raise AssertionError(f"Parameter {param_name} not found") + + def _create_policy(self, param_id: str, value: float) -> str: + """Create a policy with a parameter value.""" + response = client.post( + "/policies/", + json={ + "name": "Test Reform", + 
"description": "Test reform for household calculation", + "parameter_values": [ + { + "parameter_id": param_id, + "value_json": value, + "start_date": "2024-01-01T00:00:00Z", + } + ], + }, + ) + assert response.status_code == 200 + return response.json()["id"] + + def test_us_reform_changes_household_net_income(self): + """Test that a US policy reform changes household net income. + + This test verifies the fix for the US reform application bug where + reforms were not being applied correctly due to the shared singleton + TaxBenefitSystem in policyengine-us. + """ + # Get the US model and a UBI parameter + model_id = self._get_us_model_id() + param_name = "gov.contrib.ubi_center.basic_income.amount.person.by_age[3].amount" + param_id = self._get_parameter_id(model_id, param_name) + + # Create a policy with $1000 UBI for older adults + policy_id = self._create_policy(param_id, 1000) + + # Run baseline calculation (no policy) + baseline_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [{"age": 40, "employment_income": 70000}], + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024, + }, + ) + assert baseline_response.status_code == 200 + baseline_data = _poll_job(baseline_response.json()["job_id"]) + baseline_net_income = baseline_data["result"]["household"][0][ + "household_net_income" + ] + + # Run reform calculation (with UBI policy) + reform_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_us", + "people": [{"age": 40, "employment_income": 70000}], + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024, + "policy_id": policy_id, + }, + ) + assert reform_response.status_code == 200 + reform_data = _poll_job(reform_response.json()["job_id"]) + reform_net_income = reform_data["result"]["household"][0][ + "household_net_income" + ] + + # Verify the reform increased net income by 
approximately $1000 + difference = reform_net_income - baseline_net_income + assert abs(difference - 1000) < 1, ( + f"Expected ~$1000 difference, got ${difference:.2f}. " + f"Baseline: ${baseline_net_income:.2f}, Reform: ${reform_net_income:.2f}" + ) + + +class TestUKPolicyReform: + """Tests for UK household calculations with policy reforms.""" + + def _get_uk_model_id(self) -> str | None: + """Get the UK tax benefit model ID, or None if not seeded.""" + response = client.get("/tax-benefit-models/") + assert response.status_code == 200 + models = response.json() + for model in models: + if "uk" in model["name"].lower(): + return model["id"] + return None + + def _get_parameter_id(self, model_id: str, param_name: str) -> str: + """Get a parameter ID by name.""" + response = client.get( + f"/parameters/?tax_benefit_model_id={model_id}&limit=10000" + ) + assert response.status_code == 200 + params = response.json() + for param in params: + if param["name"] == param_name: + return param["id"] + raise AssertionError(f"Parameter {param_name} not found") + + def _create_policy(self, param_id: str, value: float) -> str: + """Create a policy with a parameter value.""" + response = client.post( + "/policies/", + json={ + "name": "Test UK Reform", + "description": "Test reform for UK household calculation", + "parameter_values": [ + { + "parameter_id": param_id, + "value_json": value, + "start_date": "2026-01-01T00:00:00Z", + } + ], + }, + ) + assert response.status_code == 200 + return response.json()["id"] + + def test_uk_reform_changes_household_net_income(self): + """Test that a UK policy reform changes household net income.""" + # Get the UK model and a UBI parameter + model_id = self._get_uk_model_id() + if model_id is None: + pytest.skip("UK model not seeded in database") + param_name = "gov.contrib.ubi_center.basic_income.adult" + param_id = self._get_parameter_id(model_id, param_name) + + # Create a policy with £1000 UBI for adults + policy_id = 
self._create_policy(param_id, 1000) + + # Run baseline calculation (no policy) + baseline_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 30000}], + "year": 2026, + }, + ) + assert baseline_response.status_code == 200 + baseline_data = _poll_job(baseline_response.json()["job_id"]) + baseline_net_income = baseline_data["result"]["household"][0][ + "household_net_income" + ] + + # Run reform calculation (with UBI policy) + reform_response = client.post( + "/household/calculate", + json={ + "tax_benefit_model_name": "policyengine_uk", + "people": [{"age": 30, "employment_income": 30000}], + "year": 2026, + "policy_id": policy_id, + }, + ) + assert reform_response.status_code == 200 + reform_data = _poll_job(reform_response.json()["job_id"]) + reform_net_income = reform_data["result"]["household"][0][ + "household_net_income" + ] + + # Verify the reform increased net income + difference = reform_net_income - baseline_net_income + assert difference > 0, ( + f"Expected positive difference, got £{difference:.2f}. " + f"Baseline: £{baseline_net_income:.2f}, Reform: £{reform_net_income:.2f}" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_household_calculation.py b/tests/test_household_calculation.py new file mode 100644 index 0000000..e4fc2a5 --- /dev/null +++ b/tests/test_household_calculation.py @@ -0,0 +1,128 @@ +"""Unit tests for household calculation functions. + +These tests verify that the calculation functions work correctly with policy reforms, +without requiring database setup or API calls. 
+""" + +import pytest + +from policyengine_api.api.household import _calculate_household_us + + +class TestUSHouseholdCalculation: + """Unit tests for US household calculation with policy reforms.""" + + @pytest.mark.slow + def test_baseline_calculation(self): + """Test basic US household calculation without policy.""" + result = _calculate_household_us( + people=[{"employment_income": 70000, "age": 40}], + marital_unit=[], + family=[], + spm_unit=[], + tax_unit=[{"state_code": "CA"}], + household=[{"state_fips": 6}], + year=2024, + policy_data=None, + ) + + assert "person" in result + assert "household" in result + assert "tax_unit" in result + assert len(result["person"]) == 1 + assert result["tax_unit"][0]["income_tax"] > 0 + + @pytest.mark.slow + def test_reform_changes_net_income(self): + """Test that a US policy reform changes household net income. + + This test verifies the fix for the US reform application bug where + reforms were not being applied correctly due to the shared singleton + TaxBenefitSystem in policyengine-us. 
+ """ + household_args = { + "people": [{"employment_income": 70000, "age": 40}], + "marital_unit": [], + "family": [], + "spm_unit": [], + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024, + } + + # Calculate baseline (no policy) + baseline = _calculate_household_us(**household_args, policy_data=None) + baseline_net_income = baseline["household"][0]["household_net_income"] + + # Calculate with $1000 UBI reform + policy_data = { + "name": "Test UBI", + "description": "Test UBI reform", + "parameter_values": [ + { + "parameter_name": "gov.contrib.ubi_center.basic_income.amount.person.by_age[3].amount", + "value": 1000, + "start_date": "2024-01-01T00:00:00", + "end_date": None, + } + ], + } + reform = _calculate_household_us(**household_args, policy_data=policy_data) + reform_net_income = reform["household"][0]["household_net_income"] + + # Verify the reform increased net income by exactly $1000 + difference = reform_net_income - baseline_net_income + assert abs(difference - 1000) < 1, ( + f"Expected ~$1000 difference, got ${difference:.2f}. " + f"Baseline: ${baseline_net_income:.2f}, Reform: ${reform_net_income:.2f}" + ) + + @pytest.mark.slow + def test_reform_does_not_affect_baseline(self): + """Test that running reform doesn't pollute baseline calculations. + + This is a regression test for the singleton pollution bug where running + a reform calculation would affect subsequent baseline calculations. 
+ """ + household_args = { + "people": [{"employment_income": 70000, "age": 40}], + "marital_unit": [], + "family": [], + "spm_unit": [], + "tax_unit": [{"state_code": "CA"}], + "household": [{"state_fips": 6}], + "year": 2024, + } + + # First baseline + baseline1 = _calculate_household_us(**household_args, policy_data=None) + baseline1_net_income = baseline1["household"][0]["household_net_income"] + + # Run reform + policy_data = { + "name": "Test UBI", + "description": "Test UBI reform", + "parameter_values": [ + { + "parameter_name": "gov.contrib.ubi_center.basic_income.amount.person.by_age[3].amount", + "value": 5000, + "start_date": "2024-01-01T00:00:00", + "end_date": None, + } + ], + } + _calculate_household_us(**household_args, policy_data=policy_data) + + # Second baseline - should be same as first + baseline2 = _calculate_household_us(**household_args, policy_data=None) + baseline2_net_income = baseline2["household"][0]["household_net_income"] + + # Verify baselines are identical + assert abs(baseline1_net_income - baseline2_net_income) < 0.01, ( + f"Baseline changed after reform calculation! 
" + f"Before: ${baseline1_net_income:.2f}, After: ${baseline2_net_income:.2f}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_households.py b/tests/test_households.py new file mode 100644 index 0000000..4c60062 --- /dev/null +++ b/tests/test_households.py @@ -0,0 +1,155 @@ +"""Tests for stored household CRUD endpoints.""" + +from uuid import uuid4 + +from test_fixtures.fixtures_households import ( + MOCK_HOUSEHOLD_MINIMAL, + MOCK_UK_HOUSEHOLD_CREATE, + MOCK_US_HOUSEHOLD_CREATE, + create_household, +) + +# --------------------------------------------------------------------------- +# POST /households +# --------------------------------------------------------------------------- + + +def test_create_us_household(client): + """Create a US household returns 201 with id and timestamps.""" + response = client.post("/households", json=MOCK_US_HOUSEHOLD_CREATE) + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["tax_benefit_model_name"] == "policyengine_us" + assert data["year"] == 2024 + assert data["label"] == "US test household" + + +def test_create_household_returns_people_and_entities(client): + """Created household response includes people and entity groups.""" + response = client.post("/households", json=MOCK_US_HOUSEHOLD_CREATE) + data = response.json() + assert len(data["people"]) == 2 + assert data["people"][0]["age"] == 30 + assert data["people"][0]["employment_income"] == 50000 + assert data["household"] == {"state_name": "CA"} + assert data["tax_unit"] == {} + assert data["family"] == {} + + +def test_create_uk_household(client): + """Create a UK household with benunit.""" + response = client.post("/households", json=MOCK_UK_HOUSEHOLD_CREATE) + assert response.status_code == 201 + data = response.json() + assert data["tax_benefit_model_name"] == "policyengine_uk" + assert data["benunit"] == {"is_married": 
False} + assert data["household"] == {"region": "LONDON"} + + +def test_create_household_minimal(client): + """Create a household with minimal fields.""" + response = client.post("/households", json=MOCK_HOUSEHOLD_MINIMAL) + assert response.status_code == 201 + data = response.json() + assert data["label"] is None + assert data["tax_unit"] is None + assert data["benunit"] is None + + +def test_create_household_invalid_model_name(client): + """Reject invalid tax_benefit_model_name.""" + payload = {**MOCK_HOUSEHOLD_MINIMAL, "tax_benefit_model_name": "invalid"} + response = client.post("/households", json=payload) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /households/{id} +# --------------------------------------------------------------------------- + + +def test_get_household(client, session): + """Get a stored household by ID.""" + record = create_household(session) + response = client.get(f"/households/{record.id}") + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(record.id) + assert data["tax_benefit_model_name"] == "policyengine_us" + + +def test_get_household_not_found(client): + """Get a non-existent household returns 404.""" + fake_id = uuid4() + response = client.get(f"/households/{fake_id}") + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# GET /households +# --------------------------------------------------------------------------- + + +def test_list_households_empty(client): + """List households returns empty list when none exist.""" + response = client.get("/households") + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_households_with_data(client, session): + """List households returns all stored households.""" + create_household(session, label="first") + 
create_household(session, label="second") + response = client.get("/households") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_households_filter_by_model_name(client, session): + """Filter households by tax_benefit_model_name.""" + create_household(session, tax_benefit_model_name="policyengine_us") + create_household(session, tax_benefit_model_name="policyengine_uk") + response = client.get( + "/households", params={"tax_benefit_model_name": "policyengine_uk"} + ) + data = response.json() + assert len(data) == 1 + assert data[0]["tax_benefit_model_name"] == "policyengine_uk" + + +def test_list_households_limit_and_offset(client, session): + """Respect limit and offset pagination.""" + for i in range(5): + create_household(session, label=f"household-{i}") + response = client.get("/households", params={"limit": 2, "offset": 1}) + data = response.json() + assert len(data) == 2 + + +# --------------------------------------------------------------------------- +# DELETE /households/{id} +# --------------------------------------------------------------------------- + + +def test_delete_household(client, session): + """Delete a household returns 204.""" + record = create_household(session) + response = client.delete(f"/households/{record.id}") + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/households/{record.id}") + assert response.status_code == 404 + + +def test_delete_household_not_found(client): + """Delete a non-existent household returns 404.""" + fake_id = uuid4() + response = client.delete(f"/households/{fake_id}") + assert response.status_code == 404 diff --git a/tests/test_models.py b/tests/test_models.py index 0f84140..e3a83d9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,9 +6,11 @@ AggregateOutput, AggregateType, Dataset, + Household, Policy, Simulation, SimulationStatus, + Variable, ) @@ -66,3 +68,90 @@ def test_aggregate_output_creation(): 
assert output.simulation_id == simulation_id assert output.aggregate_type == AggregateType.SUM assert output.result is None + + +def test_variable_creation_with_default_value(): + """Test variable model creation with default_value field.""" + model_version_id = uuid4() + variable = Variable( + name="age", + entity="person", + description="Age of the person", + data_type="int", + default_value=40, + tax_benefit_model_version_id=model_version_id, + ) + assert variable.name == "age" + assert variable.entity == "person" + assert variable.data_type == "int" + assert variable.default_value == 40 + assert variable.id is not None + + +def test_variable_with_float_default_value(): + """Test variable model with float default value.""" + model_version_id = uuid4() + variable = Variable( + name="employment_income", + entity="person", + data_type="float", + default_value=0.0, + tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value == 0.0 + + +def test_variable_with_bool_default_value(): + """Test variable model with boolean default value.""" + model_version_id = uuid4() + variable = Variable( + name="is_disabled", + entity="person", + data_type="bool", + default_value=False, + tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value is False + + +def test_variable_with_string_default_value(): + """Test variable model with string default value (enum).""" + model_version_id = uuid4() + variable = Variable( + name="state_name", + entity="household", + data_type="Enum", + default_value="CA", + possible_values=["CA", "NY", "TX"], + tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value == "CA" + assert variable.possible_values == ["CA", "NY", "TX"] + + +def test_variable_with_null_default_value(): + """Test variable model with null default value.""" + model_version_id = uuid4() + variable = Variable( + name="optional_field", + entity="person", + data_type="str", + default_value=None, + 
tax_benefit_model_version_id=model_version_id, + ) + assert variable.default_value is None + + +def test_household_creation(): + """Test household model creation.""" + household = Household( + tax_benefit_model_name="policyengine_us", + year=2024, + label="Test household", + household_data={"people": [{"age": 30}], "household": {}}, + ) + assert household.household_data == {"people": [{"age": 30}], "household": {}} + assert household.label == "Test household" + assert household.tax_benefit_model_name == "policyengine_us" + assert household.year == 2024 + assert household.id is not None diff --git a/tests/test_user_household_associations.py b/tests/test_user_household_associations.py new file mode 100644 index 0000000..25d8989 --- /dev/null +++ b/tests/test_user_household_associations.py @@ -0,0 +1,189 @@ +"""Tests for user-household association endpoints.""" + +from uuid import uuid4 + +from test_fixtures.fixtures_user_household_associations import ( + create_association, + create_household, + create_user, +) + +# --------------------------------------------------------------------------- +# POST /user-household-associations +# --------------------------------------------------------------------------- + + +def test_create_association(client, session): + """Create an association returns 201 with id and timestamps.""" + user = create_user(session) + household = create_household(session) + payload = { + "user_id": str(user.id), + "household_id": str(household.id), + "country_id": "us", + "label": "My US household", + } + response = client.post("/user-household-associations", json=payload) + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert "created_at" in data + assert "updated_at" in data + assert data["user_id"] == str(user.id) + assert data["household_id"] == str(household.id) + assert data["country_id"] == "us" + assert data["label"] == "My US household" + + +def test_create_association_allows_duplicates(client, session): 
+ """Multiple associations to the same household are allowed.""" + user = create_user(session) + household = create_household(session) + payload = { + "user_id": str(user.id), + "household_id": str(household.id), + "country_id": "us", + "label": "First label", + } + r1 = client.post("/user-household-associations", json=payload) + assert r1.status_code == 201 + + payload["label"] = "Second label" + r2 = client.post("/user-household-associations", json=payload) + assert r2.status_code == 201 + assert r1.json()["id"] != r2.json()["id"] + + +def test_create_association_household_not_found(client, session): + """Creating with a non-existent household returns 404.""" + user = create_user(session) + payload = { + "user_id": str(user.id), + "household_id": str(uuid4()), + "country_id": "us", + } + response = client.post("/user-household-associations", json=payload) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# GET /user-household-associations/user/{user_id} +# --------------------------------------------------------------------------- + + +def test_list_by_user_empty(client): + """List associations for a user with none returns empty list.""" + response = client.get(f"/user-household-associations/user/{uuid4()}") + assert response.status_code == 200 + assert response.json() == [] + + +def test_list_by_user(client, session): + """List all associations for a user.""" + user = create_user(session) + h1 = create_household(session, label="H1") + h2 = create_household(session, label="H2") + create_association(session, user.id, h1.id, label="First") + create_association(session, user.id, h2.id, label="Second") + + response = client.get(f"/user-household-associations/user/{user.id}") + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + + +def test_list_by_user_filter_country(client, session): + """Filter associations by 
country_id."""
+    user = create_user(session)
+    household = create_household(session)
+    create_association(session, user.id, household.id, country_id="us")
+    create_association(session, user.id, household.id, country_id="uk")
+
+    response = client.get(
+        f"/user-household-associations/user/{user.id}",
+        params={"country_id": "uk"},
+    )
+    data = response.json()
+    assert len(data) == 1
+    assert data[0]["country_id"] == "uk"
+
+
+# ---------------------------------------------------------------------------
+# GET /user-household-associations/{user_id}/{household_id}
+# ---------------------------------------------------------------------------
+
+
+def test_list_by_user_and_household(client, session):
+    """List associations for a specific user+household pair."""
+    user = create_user(session)
+    household = create_household(session)
+    create_association(session, user.id, household.id, label="Label A")
+    create_association(session, user.id, household.id, label="Label B")
+
+    response = client.get(f"/user-household-associations/{user.id}/{household.id}")
+    assert response.status_code == 200
+    data = response.json()
+    assert len(data) == 2
+
+
+def test_list_by_user_and_household_empty(client):
+    """Returns empty list when no associations exist for the pair."""
+    response = client.get(f"/user-household-associations/{uuid4()}/{uuid4()}")
+    assert response.status_code == 200
+    assert response.json() == []
+
+
+# ---------------------------------------------------------------------------
+# PUT /user-household-associations/{association_id}
+# ---------------------------------------------------------------------------
+
+
+def test_update_association_label(client, session):
+    """Update the label of an existing association."""
+    user = create_user(session)
+    household = create_household(session)
+    assoc = create_association(session, user.id, household.id, label="Old")
+
+    response = client.put(
+        f"/user-household-associations/{assoc.id}",
+        json={"label": "New label"},
+    )
assert response.status_code == 200 + data = response.json() + assert data["label"] == "New label" + + +def test_update_association_not_found(client): + """Update a non-existent association returns 404.""" + response = client.put( + f"/user-household-associations/{uuid4()}", + json={"label": "Something"}, + ) + assert response.status_code == 404 + assert "not found" in response.json()["detail"] + + +# --------------------------------------------------------------------------- +# DELETE /user-household-associations/{association_id} +# --------------------------------------------------------------------------- + + +def test_delete_association(client, session): + """Delete an association returns 204.""" + user = create_user(session) + household = create_household(session) + assoc = create_association(session, user.id, household.id) + + response = client.delete(f"/user-household-associations/{assoc.id}") + assert response.status_code == 204 + + # Confirm it's gone + response = client.get(f"/user-household-associations/{user.id}/{household.id}") + assert response.json() == [] + + +def test_delete_association_not_found(client): + """Delete a non-existent association returns 404.""" + response = client.delete(f"/user-household-associations/{uuid4()}") + assert response.status_code == 404 diff --git a/uv.lock b/uv.lock index 094ebf8..466caf4 100644 --- a/uv.lock +++ b/uv.lock @@ -91,6 +91,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" }, ] +[[package]] +name = "alembic" +version = "1.18.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "typing-extensions" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/79/41/ab8f624929847b49f84955c594b165855efd829b0c271e1a8cac694138e5/alembic-1.18.3.tar.gz", hash = "sha256:1212aa3778626f2b0f0aa6dd4e99a5f99b94bd25a0c1ac0bba3be65e081e50b0", size = 2052564, upload-time = "2026-01-29T20:24:15.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/45/8e/d79281f323e7469b060f15bd229e48d7cdd219559e67e71c013720a88340/alembic-1.18.3-py3-none-any.whl", hash = "sha256:12a0359bfc068a4ecbb9b3b02cf77856033abfdb59e4a5aca08b7eacd7b74ddd", size = 262282, upload-time = "2026-01-29T20:24:17.488Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -1057,6 +1071,18 @@ sqlalchemy = [ { name = "opentelemetry-instrumentation-sqlalchemy" }, ] +[[package]] +name = "mako" +version = "1.3.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, +] + [[package]] name = "markdown-it-py" version = "4.0.0" @@ -1757,6 +1783,7 @@ name = "policyengine-api-v2" version = "0.1.0" source = { editable = "." } dependencies = [ + { name = "alembic" }, { name = "anthropic" }, { name = "boto3" }, { name = "fastapi" }, @@ -1793,6 +1820,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "alembic", specifier = ">=1.13.0" }, { name = "anthropic", specifier = ">=0.40.0" }, { name = "boto3", specifier = ">=1.41.1" }, { name = "fastapi", specifier = ">=0.115.0" },