Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ iron_sql keeps SQL close to Python call sites while giving you typed, async quer
- Safe-by-default: helper methods enforce expected row counts instead of returning silent `None`.

## Quick start
1. Install `iron_sql`, `psycopg`, `psycopg-pool`, `orjson`, and `pydantic`.
2. Install [`sqlc` v2](https://docs.sqlc.dev/en/latest/overview/install.html) and ensure `/usr/local/bin/sqlc` is in PATH.
3. Add a Postgres schema dump, for example `db/adept_schema.sql`.
4. Call `generate_sql_package(schema_path=..., package_full_name=..., dsn_import=...)` from a small script or task. The generator scans your code, runs `sqlc`, and writes a module such as `adept/db/adept.py`.
1. Install `iron_sql`, `psycopg`, `psycopg-pool`, and `pydantic`.
2. Install [`sqlc` v2](https://docs.sqlc.dev/en/latest/overview/install.html) and ensure it is available in your PATH.
3. Add a Postgres schema dump, for example `db/mydatabase_schema.sql`.
4. Call `generate_sql_package(schema_path=..., package_full_name=..., dsn_import=...)` from a small script or task. The generator scans your code (defaults to current directory), runs `sqlc`, and writes a module such as `myapp/db/mydatabase.py`.

## Authoring queries
- Use the package helper for your DB, e.g. `adept_sql("select ...")`. The SQL string must be a literal so the generator can find it.
- Use the package helper for your DB, e.g. `mydatabase_sql("select ...")`. The SQL string must be a literal so the generator can find it.
- Named parameters:
- Required: `@param`
- Optional: `@param?` (expands to `sqlc.narg('param')`)
Expand All @@ -31,7 +31,10 @@ iron_sql keeps SQL close to Python call sites while giving you typed, async quer

## Adding another database package
Provide the schema file and DSN import string, then call `generate_sql_package()` with:
- `schema_path`: path to the schema SQL file.
- `package_full_name`: target module, e.g. `adept.db.analytics`.
- `dsn_import`: import path to a DSN string, e.g. `adept.config:CONFIG.analytics_db_url.value`.
- `schema_path`: path to the schema SQL file (relative to `src_path`).
- `package_full_name`: target module, e.g. `myapp.db`.
- `dsn_import`: import path to a DSN string, e.g. `myapp.config:CONFIG.db_url.value`.
- `src_path`: optional base source path for scanning queries (defaults current directory).
- `sqlc_path`: optional path to the sqlc binary if not in PATH (e.g., `Path("/custom/bin/sqlc")`).
- `tempdir_path`: optional path for temporary file generation (useful for Docker mounts).
- Optional `application_name`, `debug_path`, and `to_pascal_fn` if you need naming overrides or want to keep `sqlc` inputs for inspection.
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dev = [
"pytest-asyncio>=1.2.0",
"pytest-cov>=7.0.0",
"ruff>=0.14.1",
"testcontainers>=4",
]

[tool.pyright]
Expand Down Expand Up @@ -85,10 +86,11 @@ ignore = [
"RUF003",
"TC006",
"TD",
"TRY300",
]

[tool.ruff.lint.per-file-ignores]
"{*_test.py,conftest.py}" = ["A002", "PLR2004", "S", "FBT"]
"test_*.py" = ["A002", "PLR2004", "S", "FBT"]

[tool.ruff.lint.isort]
force-single-line = true
Expand All @@ -100,11 +102,16 @@ max-args = 10
ban-relative-imports = "all"

[tool.pytest.ini_options]
filterwarnings = ["error"]
testpaths = ["tests"]
filterwarnings = [
"error",
"ignore:The @wait_container_is_ready decorator is deprecated:DeprecationWarning",
]
addopts = ["--no-cov-on-fail", "--cov-report=term-missing:skip-covered"]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function"


[tool.coverage.run]
branch = true
omit = ["*_test.py"]
Expand Down
2 changes: 1 addition & 1 deletion src/iron_sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""iron_gql: Typed GraphQL client generator for Python."""
"""iron_sql: Typed SQL client generator for Python."""

from iron_sql.generator import generate_sql_package

Expand Down
70 changes: 53 additions & 17 deletions src/iron_sql/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,30 @@ def generate_sql_package( # noqa: PLR0914
application_name: str | None = None,
to_pascal_fn=alias_generators.to_pascal,
debug_path: Path | None = None,
src_path: Path,
src_path: Path = Path(),
sqlc_path: Path | None = None,
tempdir_path: Path | None = None,
sqlc_command: list[str] | None = None,
) -> bool:
"""Generate a typed SQL package from schema and queries.

Args:
schema_path: Path to the Postgres schema SQL file (relative to src_path)
package_full_name: Target module name (e.g., "myapp.mydatabase")
dsn_import: Import path to DSN string (e.g.,
"myapp.config:CONFIG.db_url")
application_name: Optional application name for connection pool
to_pascal_fn: Function to convert names to PascalCase (default:
pydantic's to_pascal)
debug_path: Optional path to save sqlc inputs for inspection
src_path: Base source path for scanning queries (default: Path())
sqlc_path: Optional path to sqlc binary if not in PATH
tempdir_path: Optional path for temporary file generation
sqlc_command: Optional command prefix to run sqlc

Returns:
True if the package was generated or modified, False otherwise
"""
dsn_import_package, dsn_import_path = dsn_import.split(":")

package_name = package_full_name.split(".")[-1] # noqa: PLC0207
Expand All @@ -58,6 +80,9 @@ def generate_sql_package( # noqa: PLR0914
[(q.name, q.stmt) for q in queries],
dsn=dsn,
debug_path=debug_path,
sqlc_path=sqlc_path,
tempdir_path=tempdir_path,
sqlc_command=sqlc_command,
)

if sqlc_res.error:
Expand Down Expand Up @@ -97,7 +122,7 @@ def generate_sql_package( # noqa: PLR0914
render_query_overload(sql_fn_name, q.name, q.stmt, q.row_type) for q in queries
]

query_cases = [render_query_case(q.name, q.stmt) for q in queries]
query_dict_entries = [render_query_dict_entry(q.name, q.stmt) for q in queries]

new_content = render_package(
dsn_import_package,
Expand All @@ -107,7 +132,7 @@ def generate_sql_package( # noqa: PLR0914
sorted(entities),
sorted(query_classes),
sorted(query_overloads),
sorted(query_cases),
sorted(query_dict_entries),
application_name,
)
changed = write_if_changed(target_package_path, new_content + "\n")
Expand All @@ -124,7 +149,7 @@ def render_package(
entities: list[str],
query_classes: list[str],
query_overloads: list[str],
query_cases: list[str],
query_dict_entries: list[str],
application_name: str | None = None,
):
return f"""
Expand Down Expand Up @@ -170,7 +195,7 @@ def render_package(
{package_name.upper()}_POOL = runtime.ConnectionPool(
{dsn_import_path},
name="{package_name}",
application_name="{application_name}",
application_name={application_name!r},
)

_{package_name}_connection = ContextVar[psycopg.AsyncConnection | None](
Expand Down Expand Up @@ -203,14 +228,21 @@ class Query:
{"\n\n\n".join(query_classes)}


_QUERIES: dict[str, type[Query]] = {{
{("," + chr(10) + " ").join(query_dict_entries)},
}}


{"\n".join(query_overloads)}
@overload
def {sql_fn_name}(stmt: str) -> Query: ...


def {sql_fn_name}(stmt: str, row_type: str | None = None) -> Query:
{indent_block("\n".join(query_cases), " ")}
return Query()
if stmt in _QUERIES:
return _QUERIES[stmt]()
msg = f"Unknown statement: {{stmt!r}}"
raise KeyError(msg)

""".strip()

Expand Down Expand Up @@ -308,11 +340,11 @@ async def query_all_rows({", ".join(query_fn_params)}) -> list[{result}]:

async def query_single_row({", ".join(query_fn_params)}) -> {result}:
async with self._execute({params_arg}) as cur:
return runtime.get_one_row(await cur.fetchall())
return runtime.get_one_row(await cur.fetchmany(2))

async def query_optional_row({", ".join(query_fn_params)}) -> {base_result} | None:
async with self._execute({params_arg}) as cur:
return runtime.get_one_row_or_none(await cur.fetchall())
return runtime.get_one_row_or_none(await cur.fetchmany(2))

""".strip()
else:
Expand Down Expand Up @@ -357,13 +389,8 @@ def {sql_fn_name}(stmt: Literal[{stmt!r}]{result_arg}) -> {query_name}: ...
""".strip()


def render_query_case(query_name: str, stmt: str) -> str:
return f"""

if stmt == {stmt!r}:
return {query_name}()

""".strip()
def render_query_dict_entry(query_name: str, stmt: str) -> str:
return f"{stmt!r}: {query_name}"


@dataclass(kw_only=True)
Expand Down Expand Up @@ -483,7 +510,16 @@ def column_py_spec( # noqa: C901, PLR0912
match db_type:
case "bool" | "boolean":
py_type = "bool"
case "int2" | "int4" | "int8" | "smallint" | "integer" | "bigint":
case (
"int2"
| "int4"
| "int8"
| "smallint"
| "integer"
| "bigint"
| "serial"
| "bigserial"
):
py_type = "int"
case "float4" | "float8":
py_type = "float"
Expand Down
2 changes: 0 additions & 2 deletions src/iron_sql/integration_test.py

This file was deleted.

43 changes: 36 additions & 7 deletions src/iron_sql/sqlc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@

import pydantic

SQLC_QUERY_TPL = """
-- name: ${name} :exec
${stmt};
"""


class CatalogReference(pydantic.BaseModel):
catalog: str
Expand Down Expand Up @@ -117,12 +112,41 @@ def used_schemas(self) -> list[str]:
return list(result)


def _resolve_sqlc_command(
sqlc_path: Path | None,
sqlc_command: list[str] | None,
) -> list[str]:
if sqlc_command is not None:
if sqlc_path is not None:
msg = "sqlc_command and sqlc_path are mutually exclusive"
raise ValueError(msg)
if not sqlc_command:
msg = "sqlc_command must not be empty"
raise ValueError(msg)
return sqlc_command

if sqlc_path is None:
discovered_path = shutil.which("sqlc")
if discovered_path is None:
msg = "sqlc not found in PATH"
raise FileNotFoundError(msg)
sqlc_path = Path(discovered_path)
if not sqlc_path.exists():
msg = f"sqlc not found at {sqlc_path}"
raise FileNotFoundError(msg)

return [str(sqlc_path)]


def run_sqlc(
schema_path: Path,
queries: list[tuple[str, str]],
*,
dsn: str | None,
debug_path: Path | None = None,
sqlc_path: Path | None = None,
tempdir_path: Path | None = None,
sqlc_command: list[str] | None = None,
) -> SQLCResult:
if not schema_path.exists():
msg = f"Schema file not found: {schema_path}"
Expand All @@ -135,8 +159,11 @@ def run_sqlc(
)

queries = list({q[0]: q for q in queries}.values())
cmd_prefix = _resolve_sqlc_command(sqlc_path, sqlc_command)

with tempfile.TemporaryDirectory() as tempdir:
with tempfile.TemporaryDirectory(
dir=str(tempdir_path) if tempdir_path else None
) as tempdir:
queries_path = Path(tempdir) / "queries.sql"
queries_path.write_text(
"\n\n".join(
Expand All @@ -163,8 +190,10 @@ def run_sqlc(
}
config_path.write_text(json.dumps(sqlc_config, indent=2), encoding="utf-8")

cmd = [*cmd_prefix, "generate", "--file", str(config_path.resolve())]

sqlc_run_result = subprocess.run( # noqa: S603
["/usr/local/bin/sqlc", "generate", "--file", str(config_path)],
cmd,
capture_output=True,
check=False,
)
Expand Down
1 change: 1 addition & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Loading