Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions src/iron_sql/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ def _collect_used_enums(sqlc_res: SQLCResult) -> set[tuple[str, str]]:
for col in (
*(c for q in sqlc_res.queries for c in q.columns),
*(p.column for q in sqlc_res.queries for p in q.params),
*(
c
for schema_name in sqlc_res.used_schemas()
for table in sqlc_res.catalog.schema_by_name(schema_name).tables
for c in table.columns
),
)
for schema in (sqlc_res.catalog.schema_by_ref(col.type),)
if schema.has_enum(col.type.name)
Expand Down
4 changes: 1 addition & 3 deletions src/iron_sql/sqlc.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,12 @@ class SQLCResult(pydantic.BaseModel):
queries: list[Query]

def used_schemas(self) -> list[str]:
table_schemas = {
result = {
c.table.schema_name
for q in self.queries
for c in q.columns
if c.table is not None
}
type_schemas = {c.type.schema_name for q in self.queries for c in q.columns}
result = {*table_schemas, *type_schemas}
if "" in result:
result.remove("")
result.add(self.catalog.default_schema)
Expand Down
32 changes: 32 additions & 0 deletions tests/test_type_system.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import keyword
from enum import StrEnum

import pytest

from tests.conftest import ProjectBuilder


Expand Down Expand Up @@ -176,3 +178,33 @@ async def test_pg_catalog_type_does_not_break_generation(

row = await mod.testdb_sql(sql).query_single_row()
assert row == 1


def test_pg_catalog_does_not_trigger_warnings(
test_project: ProjectBuilder, caplog: pytest.LogCaptureFixture
) -> None:
test_project.add_query("get_user", "SELECT * FROM users")

test_project.generate()

assert "Unknown SQL type" not in caplog.text


async def test_table_column_enum_not_in_query_is_skipped(
test_project: ProjectBuilder,
) -> None:
extra_schema = """
CREATE TYPE table_only_status AS ENUM ('pending', 'processed');
CREATE TABLE status_log (
id SERIAL PRIMARY KEY,
status table_only_status NOT NULL
);
"""

await test_project.extend_schema(extra_schema)

test_project.add_query("get_users", "SELECT * FROM users")

mod = test_project.generate()

assert not hasattr(mod, "TestdbTableOnlyStatus")