diff --git a/pyproject.toml b/pyproject.toml
index 5a2b6ab..f1e3b16 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,7 +88,6 @@ ignore = [
     "RUF003",
     "TC006",
     "TD",
-    "TRY300",
 ]
 
 [tool.ruff.lint.per-file-ignores]
@@ -104,18 +103,23 @@ max-args = 10
 ban-relative-imports = "all"
 
 [tool.pytest.ini_options]
+strict = true
 testpaths = ["tests"]
 filterwarnings = [
     "error",
     "ignore:.*wait_container_is_ready:DeprecationWarning",
 ]
-addopts = ["--no-cov-on-fail", "--cov-report=term-missing:skip-covered"]
+addopts = [
+    "--import-mode=importlib",
+    "--no-cov-on-fail",
+    "--cov-report=term-missing:skip-covered",
+]
 asyncio_mode = "auto"
 asyncio_default_fixture_loop_scope = "function"
 
 [tool.coverage.run]
 branch = true
-omit = ["test_*.py", "testdb.py"]
+omit = ["tests/*.py", "testdb.py"]
 data_file = ".coverage/db.sqlite"
 
 [tool.coverage.html]
diff --git a/src/iron_sql/generator.py b/src/iron_sql/generator.py
index 5ead545..f320b94 100644
--- a/src/iron_sql/generator.py
+++ b/src/iron_sql/generator.py
@@ -34,27 +34,21 @@ class ColumnPySpec:
 
 
 def _collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
-    used = set()
-    catalog = sqlc_res.catalog
-    default_schema = catalog.default_schema
-
-    def check_column(col: Column) -> None:
-        schema_name = col.type.schema_name or default_schema
-        if schema_name and catalog.schema_by_name(schema_name).has_enum(col.type.name):
-            used.add((schema_name, col.type.name))
-
-    for q in sqlc_res.queries:
-        for c in q.columns:
-            check_column(c)
-        for p in q.params:
-            check_column(p.column)
-
-    for schema_name in sqlc_res.used_schemas():
-        for table in catalog.schema_by_name(schema_name).tables:
-            for c in table.columns:
-                check_column(c)
-
-    return used
+    return {
+        (schema.name, col.type.name)
+        for col in (
+            *(c for q in sqlc_res.queries for c in q.columns),
+            *(p.column for q in sqlc_res.queries for p in q.params),
+            *(
+                c
+                for schema_name in sqlc_res.used_schemas()
+                for table in sqlc_res.catalog.schema_by_name(schema_name).tables
+                for c in table.columns
+            ),
+        )
+        for schema in (sqlc_res.catalog.schema_by_ref(col.type),)
+        if schema.has_enum(col.type.name)
+    }
 
 
 def generate_sql_package(  # noqa: PLR0913, PLR0914
@@ -639,9 +633,7 @@ def column_py_spec(  # noqa: C901, PLR0912
             py_type = "uuid.UUID"
         case "any" | "anyelement":
             py_type = "object"
-        case enum if (
-            sch := column.type.schema_name or catalog.default_schema
-        ) and catalog.schema_by_name(sch).has_enum(enum):
+        case enum if catalog.schema_by_ref(column.type).has_enum(enum):
             py_type = to_pascal_fn(f"{package_name}_{enum}") if package_name else "str"
         case _:
             logger.warning(f"Unknown SQL type: {column.type.name} ({column.name})")
diff --git a/src/iron_sql/sqlc.py b/src/iron_sql/sqlc.py
index c88c9f3..786f371 100644
--- a/src/iron_sql/sqlc.py
+++ b/src/iron_sql/sqlc.py
@@ -74,11 +74,8 @@ def schema_by_name(self, name: str) -> Schema:
         msg = f"Schema not found: {name}"
         raise ValueError(msg)
 
-    def schema_by_ref(self, ref: CatalogReference | None) -> Schema:
-        schema = self.default_schema
-        if ref and ref.schema_name:
-            schema = ref.schema_name
-        return self.schema_by_name(schema)
+    def schema_by_ref(self, ref: CatalogReference) -> Schema:
+        return self.schema_by_name(ref.schema_name or self.default_schema)
 
 
 class QueryParameter(pydantic.BaseModel):
diff --git a/tests/test_code_generation.py b/tests/test_code_generation.py
index 9ed47a7..2405d46 100644
--- a/tests/test_code_generation.py
+++ b/tests/test_code_generation.py
@@ -77,3 +77,42 @@ def test_unsupported_param_types_array(test_project: ProjectBuilder) -> None:
     )
     with pytest.raises(TypeError, match=r"Unsupported column type: jsonb\[\]"):
         test_project.generate_no_import()
+
+
+def test_generator_is_idempotent(test_project: ProjectBuilder) -> None:
+    assert test_project.generate_no_import() is True
+    assert test_project.generate_no_import() is False
+
+
+def test_generator_valid_explicit_row_type(test_project: ProjectBuilder) -> None:
+    test_project.set_queries_source(
+        """
+        from typing import Any
+        def testdb_sql(q: str, **kwargs: Any) -> Any: ...
+
+        RT = "UserMini"
+        q = testdb_sql("SELECT id, username FROM users", row_type="UserMini")
+        """
+    )
+    assert test_project.generate_no_import() is True
+
+
+async def test_special_types_params(test_project: ProjectBuilder) -> None:
+    await test_project.extend_schema(
+        """
+        CREATE TABLE special_types (
+            id uuid PRIMARY KEY,
+            d date NOT NULL,
+            t time NOT NULL,
+            ts timestamp NOT NULL,
+            b boolean NOT NULL,
+            j jsonb
+        );
+        """
+    )
+    test_project.add_query(
+        "insert_special",
+        "INSERT INTO special_types (id, d, t, ts, b, j) "
+        "VALUES ($1, $2, $3, $4, $5, $6)",
+    )
+    assert test_project.generate_no_import() is True
diff --git a/tests/test_runtime_coverage.py b/tests/test_runtime_coverage.py
new file mode 100644
index 0000000..462e772
--- /dev/null
+++ b/tests/test_runtime_coverage.py
@@ -0,0 +1,40 @@
+from collections.abc import AsyncIterator
+
+import pytest
+
+from iron_sql.runtime import ConnectionPool
+from iron_sql.runtime import TooManyRowsError
+from iron_sql.runtime import get_one_row_or_none
+from iron_sql.runtime import typed_scalar_row
+
+
+@pytest.fixture
+async def async_pool(pg_dsn: str) -> AsyncIterator[ConnectionPool]:
+    p = ConnectionPool(pg_dsn, name="test_pool")
+    yield p
+    await p.close()
+
+
+async def test_pool_check_and_await(async_pool: ConnectionPool) -> None:
+    await async_pool.check()
+    await async_pool.await_connections()
+
+
+async def test_pool_context_manager(pg_dsn: str) -> None:
+    async with ConnectionPool(pg_dsn) as p:
+        await p.check()
+
+
+def test_get_one_row_or_none_too_many() -> None:
+    with pytest.raises(TooManyRowsError):
+        get_one_row_or_none([1, 2])
+
+
+async def test_typed_scalar_row_type_mismatch(async_pool: ConnectionPool) -> None:
+    async with (
+        async_pool.connection() as conn,
+        conn.cursor(row_factory=typed_scalar_row(int, not_null=True)) as cur,
+    ):
+        await cur.execute("SELECT 'not an int'::text")
+        with pytest.raises(TypeError, match="Expected scalar of type "):
+            await cur.fetchone()
diff --git a/tests/test_sqlc_config.py b/tests/test_sqlc_config.py
new file mode 100644
index 0000000..89a9eb5
--- /dev/null
+++ b/tests/test_sqlc_config.py
@@ -0,0 +1,89 @@
+import os
+from pathlib import Path
+
+import pytest
+
+from iron_sql.sqlc import run_sqlc
+
+
+def test_run_sqlc_exclusive_args(tmp_path: Path) -> None:
+    schema = tmp_path / "schema.sql"
+    schema.touch()
+    with pytest.raises(
+        ValueError, match="sqlc_command and sqlc_path are mutually exclusive"
+    ):
+        run_sqlc(
+            schema_path=schema,
+            queries=[("q", "SELECT 1")],
+            dsn=None,
+            sqlc_path=Path("/bin/sqlc"),
+            sqlc_command=["docker", "run"],
+        )
+
+
+def test_run_sqlc_empty_command(tmp_path: Path) -> None:
+    schema = tmp_path / "schema.sql"
+    schema.touch()
+    with pytest.raises(ValueError, match="sqlc_command must not be empty"):
+        run_sqlc(
+            schema_path=schema,
+            queries=[("q", "SELECT 1")],
+            dsn=None,
+            sqlc_path=None,
+            sqlc_command=[],
+        )
+
+
+def test_run_sqlc_not_found_in_path(tmp_path: Path) -> None:
+    schema = tmp_path / "schema.sql"
+    schema.touch()
+    original_path = os.environ.get("PATH", "")
+    os.environ["PATH"] = ""
+    try:
+        with pytest.raises(FileNotFoundError, match="sqlc not found in PATH"):
+            run_sqlc(
+                schema_path=schema,
+                queries=[("q", "SELECT 1")],
+                dsn=None,
+                sqlc_path=None,
+                sqlc_command=None,
+            )
+    finally:
+        os.environ["PATH"] = original_path
+
+
+def test_run_sqlc_explicit_path_not_exists(tmp_path: Path) -> None:
+    schema = tmp_path / "schema.sql"
+    schema.touch()
+    with pytest.raises(FileNotFoundError, match="sqlc not found at /does/not/exist"):
+        run_sqlc(
+            schema_path=schema,
+            queries=[("q", "SELECT 1")],
+            dsn=None,
+            sqlc_path=Path("/does/not/exist"),
+            sqlc_command=None,
+        )
+
+
+def test_run_sqlc_missing_schema() -> None:
+    with pytest.raises(ValueError, match="Schema file not found"):
+        run_sqlc(
+            schema_path=Path("nonexistent.sql"),
+            queries=[],
+            dsn="postgres://",
+        )
+
+
+def test_run_sqlc_no_queries() -> None:
+    schema_path = Path("schema.sql")
+    schema_path.touch()
+    try:
+        result = run_sqlc(
+            schema_path=schema_path,
+            queries=[],
+            dsn="postgres://",
+        )
+        assert result.queries == []
+        assert result.catalog.schemas == []
+    finally:
+        schema_path.unlink()