Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ lint:
uv run ruff check .
uv run basedpyright

coverage:
rm -rf .coverage/*
uv run pytest --cov --cov-report=html
open .coverage/htmlcov/index.html

install-deps:
uv sync

Expand Down
13 changes: 7 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ classifiers = [
requires-python = ">=3.13"
dependencies = [
"inflection>=0.5.1",
"psycopg>=3.2.12",
"psycopg-pool>=3.2.7",
"psycopg>=3.3.2",
"psycopg-pool>=3.3.0",
"pydantic>=2.12.4",
]

Expand All @@ -36,9 +36,11 @@ build-backend = "uv_build"
[dependency-groups]
dev = [
"basedpyright>=1.31.7",
"psycopg[binary]>=3.3.2",
"pytest>=8.4.2",
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"pytest-randomly>=4.0.1",
"ruff>=0.14.1",
"testcontainers>=4",
]
Expand Down Expand Up @@ -105,20 +107,19 @@ ban-relative-imports = "all"
testpaths = ["tests"]
filterwarnings = [
"error",
"ignore:The @wait_container_is_ready decorator is deprecated:DeprecationWarning",
"ignore:.*wait_container_is_ready:DeprecationWarning",
]
addopts = ["--no-cov-on-fail", "--cov-report=term-missing:skip-covered"]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"


[tool.coverage.run]
branch = true
omit = ["*_test.py"]
omit = ["test_*.py", "testdb.py"]
data_file = ".coverage/db.sqlite"

[tool.coverage.html]
directory = ".coverage/htmlcov"

[tool.coverage.report]
exclude_also = ["if TYPE_CHECKING:", "@overload"]
exclude_also = ["@overload"]
116 changes: 108 additions & 8 deletions src/iron_sql/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

from iron_sql.sqlc import Catalog
from iron_sql.sqlc import Column
from iron_sql.sqlc import Enum
from iron_sql.sqlc import Query
from iron_sql.sqlc import SQLCResult
from iron_sql.sqlc import run_sqlc

logger = logging.getLogger(__name__)
Expand All @@ -31,13 +33,38 @@ class ColumnPySpec:
py_type: str


def generate_sql_package( # noqa: PLR0914
def _collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
used = set()
catalog = sqlc_res.catalog
default_schema = catalog.default_schema

def check_column(col: Column) -> None:
schema_name = col.type.schema_name or default_schema
if schema_name and catalog.schema_by_name(schema_name).has_enum(col.type.name):
used.add((schema_name, col.type.name))

for q in sqlc_res.queries:
for c in q.columns:
check_column(c)
for p in q.params:
check_column(p.column)

for schema_name in sqlc_res.used_schemas():
for table in catalog.schema_by_name(schema_name).tables:
for c in table.columns:
check_column(c)

return used


def generate_sql_package( # noqa: PLR0913, PLR0914
*,
schema_path: Path,
package_full_name: str,
dsn_import: str,
application_name: str | None = None,
to_pascal_fn=alias_generators.to_pascal,
to_snake_fn=alias_generators.to_snake,
debug_path: Path | None = None,
src_path: Path = Path(),
sqlc_path: Path | None = None,
Expand All @@ -54,6 +81,8 @@ def generate_sql_package( # noqa: PLR0914
application_name: Optional application name for connection pool
to_pascal_fn: Function to convert names to PascalCase (default:
pydantic's to_pascal)
to_snake_fn: Function to convert names to snake_case (default:
pydantic's to_snake)
debug_path: Optional path to save sqlc inputs for inspection
src_path: Base source path for scanning queries (default: Path())
sqlc_path: Optional path to sqlc binary if not in PATH
Expand Down Expand Up @@ -100,14 +129,30 @@ def generate_sql_package( # noqa: PLR0914

entities = [render_entity(e.name, e.column_specs) for e in ordered_entities]

used_enums = _collect_used_enums(sqlc_res)

enums = [
render_enum_class(e, package_name, to_pascal_fn, to_snake_fn)
for schema in sqlc_res.catalog.schemas
for e in schema.enums
if (schema.name, e.name) in used_enums
]

query_classes = [
render_query_class(
q.name,
q.text,
package_name,
[
(
column_py_spec(p.column, sqlc_res.catalog, p.number),
column_py_spec(
p.column,
sqlc_res.catalog,
package_name,
to_pascal_fn,
to_snake_fn,
p.number,
),
p.column.is_named_param,
)
for p in q.params
Expand All @@ -130,6 +175,7 @@ def generate_sql_package( # noqa: PLR0914
package_name,
sql_fn_name,
sorted(entities),
sorted(enums),
sorted(query_classes),
sorted(query_overloads),
sorted(query_dict_entries),
Expand All @@ -147,6 +193,7 @@ def render_package(
package_name: str,
sql_fn_name: str,
entities: list[str],
enums: list[str],
query_classes: list[str],
query_overloads: list[str],
query_dict_entries: list[str],
Expand Down Expand Up @@ -181,6 +228,7 @@ def render_package(
from contextlib import asynccontextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from enum import StrEnum
from typing import Literal
from typing import overload

Expand Down Expand Up @@ -218,6 +266,9 @@ async def {package_name}_transaction() -> AsyncIterator[None]:
yield


{"\n\n\n".join(enums)}


{"\n\n\n".join(entities)}


Expand All @@ -229,7 +280,7 @@ class Query:


_QUERIES: dict[str, type[Query]] = {{
{("," + chr(10) + " ").join(query_dict_entries)},
{(",\n ").join(query_dict_entries)}
}}


Expand All @@ -247,6 +298,37 @@ def {sql_fn_name}(stmt: str, row_type: str | None = None) -> Query:
""".strip()


def render_enum_class(
enum: Enum,
package_name: str,
to_pascal_fn: Callable[[str], str],
to_snake_fn: Callable[[str], str],
) -> str:
class_name = to_pascal_fn(f"{package_name}_{enum.name}")
members = []
seen_names: dict[str, int] = {}

for val in enum.vals:
name = to_snake_fn(val).upper()
name = "".join(c if c.isalnum() else "_" for c in name)
name = name.strip("_") or "EMPTY"
if name[0].isdigit():
name = "_" + name
if name in seen_names:
seen_names[name] += 1
name = f"{name}_{seen_names[name]}"
else:
seen_names[name] = 1
members.append(f'{name} = "{val}"')

return f"""

class {class_name}(StrEnum):
{indent_block("\n".join(members), " ")}

""".strip()


def render_entity(
name: str,
columns: tuple[ColumnPySpec, ...],
Expand Down Expand Up @@ -418,6 +500,7 @@ class SQLEntity:
columns: list[Column]
catalog: Catalog = dataclasses.field(repr=False)
to_pascal_fn: Callable[[str], str]
to_snake_fn: Callable[[str], str] = inflection.underscore

@property
def name(self) -> str:
Expand All @@ -433,7 +516,12 @@ def name(self) -> str:

@property
def column_specs(self) -> tuple[ColumnPySpec, ...]:
return tuple(column_py_spec(c, self.catalog) for c in self.columns)
return tuple(
column_py_spec(
c, self.catalog, self.package_name, self.to_pascal_fn, self.to_snake_fn
)
for c in self.columns
)


def map_entities(
Expand All @@ -443,6 +531,7 @@ def map_entities(
used_schemas: list[str],
queries_from_code: list[CodeQuery],
to_pascal_fn: Callable[[str], str],
to_snake_fn: Callable[[str], str] = inflection.underscore,
):
row_types = {q.name: q.row_type for q in queries_from_code}

Expand All @@ -454,6 +543,7 @@ def map_entities(
columns=t.columns,
catalog=catalog,
to_pascal_fn=to_pascal_fn,
to_snake_fn=to_snake_fn,
)
for sch in used_schemas
for t in catalog.schema_by_name(sch).tables
Expand All @@ -476,6 +566,7 @@ def map_entities(
columns=q.columns,
catalog=catalog,
to_pascal_fn=to_pascal_fn,
to_snake_fn=to_snake_fn,
)
for q in queries_from_sqlc
if len(q.columns) > 1
Expand All @@ -495,7 +586,9 @@ def map_entities(
if len(q.columns) == 0:
result_types[q.name] = "None"
elif len(q.columns) == 1:
result_types[q.name] = column_py_spec(q.columns[0], catalog).py_type
result_types[q.name] = column_py_spec(
q.columns[0], catalog, package_name, to_pascal_fn, to_snake_fn
).py_type
else:
column_spec = query_result_entities[q.name].column_specs
result_types[q.name] = unique_entities[column_spec].name
Expand All @@ -504,7 +597,12 @@ def map_entities(


def column_py_spec( # noqa: C901, PLR0912
column: Column, catalog: Catalog, number: int = 0
column: Column,
catalog: Catalog,
package_name: str,
to_pascal_fn: Callable[[str], str],
_to_snake_fn: Callable[[str], str] = inflection.underscore,
number: int = 0,
) -> ColumnPySpec:
db_type = column.type.name.removeprefix("pg_catalog.")
match db_type:
Expand Down Expand Up @@ -541,8 +639,10 @@ def column_py_spec( # noqa: C901, PLR0912
py_type = "uuid.UUID"
case "any" | "anyelement":
py_type = "object"
case enum if catalog.schema_by_ref(column.table).has_enum(enum):
py_type = "str"
case enum if (
sch := column.type.schema_name or catalog.default_schema
) and catalog.schema_by_name(sch).has_enum(enum):
py_type = to_pascal_fn(f"{package_name}_{enum}") if package_name else "str"
case _:
logger.warning(f"Unknown SQL type: {column.type.name} ({column.name})")
py_type = "object"
Expand Down
3 changes: 3 additions & 0 deletions src/iron_sql/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Sequence
from contextlib import asynccontextmanager
from contextvars import ContextVar
from enum import Enum
from typing import Any
from typing import Literal
from typing import Self
Expand Down Expand Up @@ -123,6 +124,8 @@ def typed_scalar_row__(values: Sequence[Any]) -> T | None:
if not not_null and val is None:
return None
if not isinstance(val, typ):
if issubclass(typ, Enum):
return typ(val)
msg = f"Expected scalar of type {typ}, got {type(val)}"
raise TypeError(msg)
return val
Expand Down
7 changes: 5 additions & 2 deletions src/iron_sql/sqlc.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,16 +100,19 @@ class SQLCResult(pydantic.BaseModel):
queries: list[Query]

def used_schemas(self) -> list[str]:
result = {
table_schemas = {
c.table.schema_name
for q in self.queries
for c in q.columns
if c.table is not None
}
type_schemas = {c.type.schema_name for q in self.queries for c in q.columns}
result = {*table_schemas, *type_schemas}
if "" in result:
result.remove("")
result.add(self.catalog.default_schema)
return list(result)
catalog_schema_names = {s.name for s in self.catalog.schemas}
return [s for s in result if s in catalog_schema_names]


def _resolve_sqlc_command(
Expand Down
Loading