Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 93 additions & 12 deletions tests/migrations/test_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
class FakeConnection:
def __init__(self, dialect: str) -> None:
self.capabilities = Capabilities(dialect)
self.executed: list[str] = []
self.queries: list[str] = []
self.executed_scripts: list[str] = []
self.inserts: list[tuple[str, list]] = []
self.queries: list[tuple[str, list | None]] = []

async def execute_script(self, query: str) -> None:
self.executed.append(query)
self.executed_scripts.append(query)

async def execute_query(self, query: str):
self.queries.append(query)
async def execute_insert(self, query: str, values: list) -> int:
self.inserts.append((query, values))
return 1

async def execute_query(self, query: str, values: list | None = None):
self.queries.append((query, values))
return None, []


Expand All @@ -26,11 +31,17 @@ async def test_recorder_quotes_mysql_identifiers() -> None:
await recorder.record_applied("app", "0001_initial")
await recorder.applied_migrations()

assert connection.executed
assert "INSERT INTO `tortoise_migrations`" in connection.executed[0]
assert "(`app`, `name`, `applied_at`)" in connection.executed[0]
assert connection.inserts
insert_query, insert_values = connection.inserts[0]
assert "INSERT INTO `tortoise_migrations`" in insert_query
assert "(`app`, `name`, `applied_at`)" in insert_query
assert "%s" in insert_query
assert insert_values[0] == "app"
assert insert_values[1] == "0001_initial"

assert connection.queries
assert "SELECT `app`, `name`" in connection.queries[0]
select_query = connection.queries[0][0]
assert "SELECT `app`, `name`" in select_query


@pytest.mark.asyncio
Expand All @@ -40,6 +51,76 @@ async def test_recorder_quotes_mssql_identifiers() -> None:

await recorder.record_unapplied("app", "0001_initial")

assert connection.executed
assert "DELETE FROM [tortoise_migrations]" in connection.executed[0]
assert "WHERE [app] = 'app'" in connection.executed[0]
assert connection.queries
delete_query, delete_values = connection.queries[0]
assert "DELETE FROM [tortoise_migrations]" in delete_query
assert "[app] = ?" in delete_query
assert "[name] = ?" in delete_query
assert delete_values == ["app", "0001_initial"]


@pytest.mark.asyncio
async def test_recorder_uses_parameterized_insert() -> None:
"""Ensure record_applied uses parameterized queries instead of string interpolation.

This is critical for MariaDB compatibility \u2014 MariaDB rejects ISO 8601
datetime strings with timezone info (e.g. '2026-03-04T18:06:51+00:00').
See https://github.com/tortoise/tortoise-orm/issues/2132
"""
from datetime import datetime

connection = FakeConnection("mysql")
recorder = MigrationRecorder(connection)

await recorder.record_applied("models", "0001_init")

assert len(connection.inserts) == 1
query, values = connection.inserts[0]
# Query must use placeholders, not inline values
assert "VALUES (%s, %s, %s)" in query
assert values[0] == "models"
assert values[1] == "0001_init"
assert isinstance(values[2], datetime)


@pytest.mark.asyncio
async def test_recorder_uses_parameterized_delete() -> None:
"""Ensure record_unapplied uses parameterized queries."""
connection = FakeConnection("mysql")
recorder = MigrationRecorder(connection)

await recorder.record_unapplied("models", "0001_init")

assert len(connection.queries) == 1
query, values = connection.queries[0]
assert "WHERE `app` = %s" in query
assert "`name` = %s" in query
assert values == ["models", "0001_init"]


@pytest.mark.asyncio
async def test_recorder_postgres_placeholders() -> None:
connection = FakeConnection("postgres")
recorder = MigrationRecorder(connection)

await recorder.record_applied("app", "0001_initial")

assert len(connection.inserts) == 1
query, values = connection.inserts[0]
assert "VALUES ($1, $2, $3)" in query


@pytest.mark.asyncio
async def test_recorder_sqlite_placeholders() -> None:
connection = FakeConnection("sqlite")
recorder = MigrationRecorder(connection)

await recorder.record_applied("app", "0001_initial")
await recorder.record_unapplied("app", "0001_initial")

insert_query = connection.inserts[0][0]
assert "VALUES (?, ?, ?)" in insert_query

delete_query = connection.queries[0][0]
assert '"app" = ?' in delete_query
assert '"name" = ?' in delete_query
31 changes: 25 additions & 6 deletions tortoise/migrations/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
from tortoise.migrations.graph import MigrationKey
from tortoise.models import Model

# Parameter placeholder per dialect.
_DIALECT_PLACEHOLDER: dict[str, str] = {
"mysql": "%s",
"sqlite": "?",
"postgres": "${}",
"mssql": "?",
"oracle": "?",
}


class MigrationRecorder:
def __init__(self, connection, *, table_name: str = "tortoise_migrations") -> None:
Expand All @@ -25,6 +34,16 @@ def _quote(self, name: str) -> str:
return f"[{name}]"
return f'"{name}"'

def _placeholder(self, pos: int) -> str:
"""Return a positional parameter placeholder suitable for the current dialect.

``pos`` is 1-based (first parameter is 1).
"""
template = _DIALECT_PLACEHOLDER.get(self._dialect, "?")
if "{}" in template:
return template.format(pos)
return template

def _make_model(self, table_name: str) -> type[Model]:
class MigrationRecord(Model):
id = fields.IntField(pk=True)
Expand Down Expand Up @@ -74,21 +93,21 @@ async def applied_migrations(self) -> list[MigrationKey]:
return [MigrationKey(app_label=row["app"], name=row["name"]) for row in rows]

async def record_applied(self, app: str, name: str) -> None:
applied_at = datetime.now(timezone.utc).isoformat()
applied_at = datetime.now(timezone.utc)
query = (
f"INSERT INTO {self._quote(self.table_name)} " # nosec B608
f"({self._quote('app')}, {self._quote('name')}, {self._quote('applied_at')}) "
f"VALUES ('{self._escape(app)}', '{self._escape(name)}', '{applied_at}')"
f"VALUES ({self._placeholder(1)}, {self._placeholder(2)}, {self._placeholder(3)})"
)
await self.connection.execute_script(query)
await self.connection.execute_insert(query, [app, name, applied_at])

async def record_unapplied(self, app: str, name: str) -> None:
query = (
f"DELETE FROM {self._quote(self.table_name)} " # nosec B608
f"WHERE {self._quote('app')} = '{self._escape(app)}' "
f"AND {self._quote('name')} = '{self._escape(name)}'"
f"WHERE {self._quote('app')} = {self._placeholder(1)} "
f"AND {self._quote('name')} = {self._placeholder(2)}"
)
await self.connection.execute_script(query)
await self.connection.execute_query(query, [app, name])

@staticmethod
def _escape(value: str) -> str:
Expand Down
Loading