Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ ignore = [
"RUF003",
"TC006",
"TD",
"TRY300",
]

[tool.ruff.lint.per-file-ignores]
Expand All @@ -104,18 +103,23 @@ max-args = 10
ban-relative-imports = "all"

[tool.pytest.ini_options]
strict = true
testpaths = ["tests"]
filterwarnings = [
"error",
"ignore:.*wait_container_is_ready:DeprecationWarning",
]
addopts = ["--no-cov-on-fail", "--cov-report=term-missing:skip-covered"]
addopts = [
"--import-mode=importlib",
"--no-cov-on-fail",
"--cov-report=term-missing:skip-covered",
]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"

[tool.coverage.run]
branch = true
omit = ["test_*.py", "testdb.py"]
omit = ["tests/*.py", "testdb.py"]
data_file = ".coverage/db.sqlite"

[tool.coverage.html]
Expand Down
40 changes: 16 additions & 24 deletions src/iron_sql/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,27 +34,21 @@ class ColumnPySpec:


def _collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
used = set()
catalog = sqlc_res.catalog
default_schema = catalog.default_schema

def check_column(col: Column) -> None:
schema_name = col.type.schema_name or default_schema
if schema_name and catalog.schema_by_name(schema_name).has_enum(col.type.name):
used.add((schema_name, col.type.name))

for q in sqlc_res.queries:
for c in q.columns:
check_column(c)
for p in q.params:
check_column(p.column)

for schema_name in sqlc_res.used_schemas():
for table in catalog.schema_by_name(schema_name).tables:
for c in table.columns:
check_column(c)

return used
return {
(schema.name, col.type.name)
for col in (
*(c for q in sqlc_res.queries for c in q.columns),
*(p.column for q in sqlc_res.queries for p in q.params),
*(
c
for schema_name in sqlc_res.used_schemas()
for table in sqlc_res.catalog.schema_by_name(schema_name).tables
for c in table.columns
),
)
for schema in (sqlc_res.catalog.schema_by_ref(col.type),)
if schema.has_enum(col.type.name)
}


def generate_sql_package( # noqa: PLR0913, PLR0914
Expand Down Expand Up @@ -639,9 +633,7 @@ def column_py_spec( # noqa: C901, PLR0912
py_type = "uuid.UUID"
case "any" | "anyelement":
py_type = "object"
case enum if (
sch := column.type.schema_name or catalog.default_schema
) and catalog.schema_by_name(sch).has_enum(enum):
case enum if catalog.schema_by_ref(column.type).has_enum(enum):
py_type = to_pascal_fn(f"{package_name}_{enum}") if package_name else "str"
case _:
logger.warning(f"Unknown SQL type: {column.type.name} ({column.name})")
Expand Down
7 changes: 2 additions & 5 deletions src/iron_sql/sqlc.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,8 @@ def schema_by_name(self, name: str) -> Schema:
msg = f"Schema not found: {name}"
raise ValueError(msg)

def schema_by_ref(self, ref: CatalogReference | None) -> Schema:
schema = self.default_schema
if ref and ref.schema_name:
schema = ref.schema_name
return self.schema_by_name(schema)
def schema_by_ref(self, ref: CatalogReference) -> Schema:
return self.schema_by_name(ref.schema_name or self.default_schema)


class QueryParameter(pydantic.BaseModel):
Expand Down
39 changes: 39 additions & 0 deletions tests/test_code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,42 @@ def test_unsupported_param_types_array(test_project: ProjectBuilder) -> None:
)
with pytest.raises(TypeError, match=r"Unsupported column type: jsonb\[\]"):
test_project.generate_no_import()


def test_generator_is_idempotent(test_project: ProjectBuilder) -> None:
assert test_project.generate_no_import() is True
assert test_project.generate_no_import() is False


def test_generator_valid_explicit_row_type(test_project: ProjectBuilder) -> None:
test_project.set_queries_source(
"""
from typing import Any
def testdb_sql(q: str, **kwargs: Any) -> Any: ...

RT = "UserMini"
q = testdb_sql("SELECT id, username FROM users", row_type="UserMini")
"""
)
assert test_project.generate_no_import() is True


async def test_special_types_params(test_project: ProjectBuilder) -> None:
await test_project.extend_schema(
"""
CREATE TABLE special_types (
id uuid PRIMARY KEY,
d date NOT NULL,
t time NOT NULL,
ts timestamp NOT NULL,
b boolean NOT NULL,
j jsonb
);
"""
)
test_project.add_query(
"insert_special",
"INSERT INTO special_types (id, d, t, ts, b, j) "
"VALUES ($1, $2, $3, $4, $5, $6)",
)
assert test_project.generate_no_import() is True
40 changes: 40 additions & 0 deletions tests/test_runtime_coverage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from collections.abc import AsyncIterator

import pytest

from iron_sql.runtime import ConnectionPool
from iron_sql.runtime import TooManyRowsError
from iron_sql.runtime import get_one_row_or_none
from iron_sql.runtime import typed_scalar_row


@pytest.fixture
async def async_pool(pg_dsn: str) -> AsyncIterator[ConnectionPool]:
p = ConnectionPool(pg_dsn, name="test_pool")
yield p
await p.close()


async def test_pool_check_and_await(async_pool: ConnectionPool) -> None:
await async_pool.check()
await async_pool.await_connections()


async def test_pool_context_manager(pg_dsn: str) -> None:
async with ConnectionPool(pg_dsn) as p:
await p.check()


def test_get_one_row_or_none_too_many() -> None:
with pytest.raises(TooManyRowsError):
get_one_row_or_none([1, 2])


async def test_typed_scalar_row_type_mismatch(async_pool: ConnectionPool) -> None:
async with (
async_pool.connection() as conn,
conn.cursor(row_factory=typed_scalar_row(int, not_null=True)) as cur,
):
await cur.execute("SELECT 'not an int'::text")
with pytest.raises(TypeError, match="Expected scalar of type <class 'int'>"):
await cur.fetchone()
89 changes: 89 additions & 0 deletions tests/test_sqlc_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
from pathlib import Path

import pytest

from iron_sql.sqlc import run_sqlc


def test_run_sqlc_exclusive_args(tmp_path: Path) -> None:
schema = tmp_path / "schema.sql"
schema.touch()
with pytest.raises(
ValueError, match="sqlc_command and sqlc_path are mutually exclusive"
):
run_sqlc(
schema_path=schema,
queries=[("q", "SELECT 1")],
dsn=None,
sqlc_path=Path("/bin/sqlc"),
sqlc_command=["docker", "run"],
)


def test_run_sqlc_empty_command(tmp_path: Path) -> None:
schema = tmp_path / "schema.sql"
schema.touch()
with pytest.raises(ValueError, match="sqlc_command must not be empty"):
run_sqlc(
schema_path=schema,
queries=[("q", "SELECT 1")],
dsn=None,
sqlc_path=None,
sqlc_command=[],
)


def test_run_sqlc_not_found_in_path(tmp_path: Path) -> None:
schema = tmp_path / "schema.sql"
schema.touch()
original_path = os.environ.get("PATH", "")
os.environ["PATH"] = ""
try:
with pytest.raises(FileNotFoundError, match="sqlc not found in PATH"):
run_sqlc(
schema_path=schema,
queries=[("q", "SELECT 1")],
dsn=None,
sqlc_path=None,
sqlc_command=None,
)
finally:
os.environ["PATH"] = original_path


def test_run_sqlc_explicit_path_not_exists(tmp_path: Path) -> None:
schema = tmp_path / "schema.sql"
schema.touch()
with pytest.raises(FileNotFoundError, match="sqlc not found at /does/not/exist"):
run_sqlc(
schema_path=schema,
queries=[("q", "SELECT 1")],
dsn=None,
sqlc_path=Path("/does/not/exist"),
sqlc_command=None,
)


def test_run_sqlc_missing_schema() -> None:
with pytest.raises(ValueError, match="Schema file not found"):
run_sqlc(
schema_path=Path("nonexistent.sql"),
queries=[],
dsn="postgres://",
)


def test_run_sqlc_no_queries() -> None:
schema_path = Path("schema.sql")
schema_path.touch()
try:
result = run_sqlc(
schema_path=schema_path,
queries=[],
dsn="postgres://",
)
assert result.queries == []
assert result.catalog.schemas == []
finally:
schema_path.unlink()