Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Changelog

Added
^^^^^
- ``QuerySet.union()`` — SQL ``UNION`` support for combining the results of multiple QuerySets, including unions across different models, ``union(all=True)`` to keep duplicate rows, and ``order_by()``, ``limit()``, ``offset()``, and ``count()`` on the combined query.
- Tests for model validators. (#2137)

Fixed
Expand Down
20 changes: 20 additions & 0 deletions docs/query.rst
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,23 @@ You can view full example here: :ref:`example_prefetching`

.. autoclass:: tortoise.query_utils.Prefetch
:members:

.. _union:

Union
=====

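Tortoise ORM supports SQL ``UNION`` queries to combine results from multiple QuerySets, including QuerySets over different models. All QuerySets in a union must select the same fields (for example via ``only()``), and annotated QuerySets are not supported.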

Example usage:

.. code-block:: python3

qs1 = Tournament.filter(name__in=["T1", "T2"]).only("id", "name")
qs2 = Reporter.filter(name__in=["R1", "R2"]).only("id", "name")

result = await qs1.union(qs2)

.. autoclass:: tortoise.queryset.UnionQuery
:members:
:inherited-members:
217 changes: 216 additions & 1 deletion tests/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
NotExistOrMultiple,
ParamsError,
)
from tortoise.expressions import F, RawSQL, Subquery
from tortoise.expressions import F, RawSQL, Subquery, Value
from tortoise.functions import Avg

# TODO: Test the many exceptions in QuerySet
Expand Down Expand Up @@ -901,3 +901,218 @@ def test_multiple_objects_returned():
exp_cls: type[NotExistOrMultiple] = MultipleObjectsReturned
assert str(exp_cls("old format")) == "old format"
assert str(exp_cls(Tournament)) == exp_cls.TEMPLATE.format(Tournament.__name__)


@pytest.mark.asyncio
async def test_union_basic(db):
    """Union of two Tournament querysets yields the rows matched by either one."""
    first = await Tournament.create(name="T1")
    second = await Tournament.create(name="T2")
    third = await Tournament.create(name="T3")
    # T4 matches neither filter, so it must not show up in the union.
    await Tournament.create(name="T4")

    combined = Tournament.filter(name__in=["T1", "T2"]).union(Tournament.filter(name="T3"))
    rows = await combined

    assert {first, second, third} == set(rows)


@pytest.mark.asyncio
async def test_union_all(db):
    """``union(all=True)`` keeps duplicates: a row selected twice appears twice."""
    duplicated = await Tournament.create(name="T1")
    await Tournament.create(name="T2")

    left = Tournament.filter(name="T1")
    right = Tournament.filter(name="T1")
    rows = await left.union(right, all=True)

    assert list(rows) == [duplicated] * 2


@pytest.mark.asyncio
async def test_union_mixed_models(db):
    """Querysets over different models can be unioned when they select the same fields."""
    reporter_1 = await Reporter.create(name="R1")
    reporter_2 = await Reporter.create(name="R2")
    await Reporter.create(name="R3")
    tournament = await Tournament.create(name="T1")
    await Tournament.create(name="T2")

    selected = ("id", "name")
    tournaments = Tournament.filter(name="T1").only(*selected)
    reporters = Reporter.filter(name__in=["R1", "R2"]).only(*selected)

    rows = await tournaments.union(reporters)
    assert {tournament, reporter_1, reporter_2} == set(rows)


@pytest.mark.parametrize(
    "orderings,expected_instances",
    [
        ("name", ["t2", "t1", "r1"]),
        ("-name", ["r1", "t1", "t2"]),
    ],
)
@pytest.mark.asyncio
async def test_union_order_by(db, orderings, expected_instances):
    """``order_by()`` on a union sorts the combined rows, ascending or descending."""
    # Creation is interleaved across both models; only C, B and F are selected below.
    tournament_c = await Tournament.create(name="C")
    await Reporter.create(name="A")
    tournament_b = await Tournament.create(name="B")
    await Reporter.create(name="D")
    await Tournament.create(name="E")
    reporter_f = await Reporter.create(name="F")

    tournaments = Tournament.filter(id__in=[tournament_c.id, tournament_b.id]).only("id", "name")
    reporters = Reporter.filter(id=reporter_f.id).only("id", "name")
    order_fields = orderings.split(",")
    rows = await tournaments.union(reporters).order_by(*order_fields)

    # The parametrized expectations refer to the instances by these short keys.
    by_key = {"t1": tournament_c, "t2": tournament_b, "r1": reporter_f}
    assert rows == [by_key[key] for key in expected_instances]


@pytest.mark.asyncio
async def test_union_order_by_multiple_fields(db):
    """Ordering a union by several fields uses the later fields to break ties.

    ``t1`` and ``r1`` share the name "C", so their relative order is decided by
    ``id``; ``t2`` ("B") must always come first.
    """
    t1 = await Tournament.create(name="C")
    t2 = await Tournament.create(name="B")
    r1 = await Reporter.create(name="C")
    await Tournament.create(name="Z")
    await Reporter.create(name="Z")

    qs1 = Tournament.filter(id__in=[t1.id, t2.id]).only("id", "name")
    qs2 = Reporter.filter(id=r1.id).only("id", "name")

    result = await qs1.union(qs2).order_by("name", "id")

    if r1.id == t1.id:
        # Both sort keys tie, so the two "C" rows may come back in either
        # order. Still check everything that is deterministic instead of
        # returning early and asserting nothing.
        assert result in ([t2, t1, r1], [t2, r1, t1])
    elif r1.id > t1.id:
        assert result == [t2, t1, r1]
    else:
        assert result == [t2, r1, t1]


@requireCapability(dialect=NotEQ("mssql"))
@pytest.mark.asyncio
async def test_union_limit(db):
    """``limit()`` on an ordered union returns only the first rows of the combined result."""
    # TODO: MSSQL uses a different LIMIT syntax, and pypika-tortoise's
    # _SetOperation has its own _limit_sql() instead of using the dialect SQL,
    # so this test is skipped on MSSQL for now. See:
    # https://github.com/tortoise/pypika-tortoise/blob/378f6145f7529d88fa58510851d80454b742406e/pypika_tortoise/queries.py#L672
    r1 = await Reporter.create(name="B")
    t1 = await Tournament.create(name="A")
    await Reporter.create(name="D")
    await Tournament.create(name="C")

    qs1 = Tournament.all().only("id", "name")
    qs2 = Reporter.all().only("id", "name")

    result = await qs1.union(qs2).order_by("name").limit(2)
    assert list(result) == [t1, r1]


@requireCapability(dialect=NotEQ("mssql"))
@pytest.mark.asyncio
async def test_union_offset(db):
    """``offset()`` skips the leading rows of an ordered, limited union."""
    # NOTE(review): skipped on MSSQL, presumably because limit/offset SQL for
    # set operations is not dialect-aware yet -- confirm.
    await Tournament.create(name="T1")
    await Tournament.create(name="T2")
    t3 = await Tournament.create(name="T3")
    t4 = await Tournament.create(name="T4")

    qs1 = Tournament.filter(name__in=["T1", "T2"]).only("id", "name")
    qs2 = Tournament.filter(name__in=["T3", "T4"]).only("id", "name")

    # Ordered by name the union is T1..T4; skipping two leaves T3 and T4.
    result = await qs1.union(qs2).order_by("name").limit(4).offset(2)
    assert list(result) == [t3, t4]


@pytest.mark.asyncio
async def test_union_offset_negative_raises(db):
    """A negative ``offset()`` on a union is rejected with ``ParamsError``."""
    left = Tournament.all().only("id", "name")
    right = Tournament.all().only("id", "name")
    expected_message = "Offset should be non-negative number"

    with pytest.raises(ParamsError, match=expected_message):
        shifted = left.union(right).offset(-1)
        await shifted


@pytest.mark.asyncio
async def test_union_chained(db):
    """A union can itself be unioned with a further queryset, across models."""
    tournament_1 = await Tournament.create(name="T1")
    tournament_2 = await Tournament.create(name="T2")
    await Tournament.create(name="T3")
    reporter_1 = await Reporter.create(name="R1")
    await Reporter.create(name="R2")

    fields = ("id", "name")
    chained = (
        Tournament.filter(name="T1")
        .only(*fields)
        .union(Tournament.filter(name="T2").only(*fields))
        .union(Reporter.filter(name="R1").only(*fields))
    )

    rows = await chained
    assert {tournament_1, tournament_2, reporter_1} == set(rows)


@pytest.mark.asyncio
async def test_union_count(db):
    """``count()`` on a union reports the number of rows in the combined result."""
    await Tournament.create(name="T1")
    await Reporter.create(name="R1")
    await Tournament.create(name="T2")
    await Reporter.create(name="R2")

    tournaments = Tournament.filter(name="T1").only("id")
    reporters = Reporter.filter(name="R1").only("id")

    total = await tournaments.union(reporters).count()
    assert total == 2


@pytest.mark.asyncio
async def test_union_different_select_fields_raises(db):
    """Unioning querysets of one model that select different fields raises ``ParamsError``."""
    await Tournament.create(name="T1")

    by_name = Tournament.filter(name="T1").only("name")
    by_desc = Tournament.filter(name="T1").only("desc")
    expected_message = "Union queries must have the same select fields"

    with pytest.raises(ParamsError, match=expected_message):
        mismatched = by_name.union(by_desc)
        await mismatched


@pytest.mark.asyncio
async def test_union_different_fields__in_different_models_raises(db):
    """Unioning unrestricted querysets of two different models raises ``ParamsError``."""
    await Tournament.create(name="T1")
    await Reporter.create(name="R1")

    tournaments = Tournament.all()
    reporters = Reporter.all()
    expected_message = "Union queries must have the same select fields"

    with pytest.raises(ParamsError, match=expected_message):
        mismatched = tournaments.union(reporters)
        await mismatched


@pytest.mark.asyncio
async def test_union_order_by_field_not_in_select_raises(db):
    """Ordering a union by a field outside the selected columns raises ``ParamsError``."""
    await Tournament.create(name="T1")

    left = Tournament.filter(name="T1").only("id", "name")
    right = Tournament.filter(name="T1").only("id", "name")

    # The union itself is valid; only the order_by on "desc" is rejected.
    union_qs = left.union(right)
    expected_message = "Order by field must be in the select list"
    with pytest.raises(ParamsError, match=expected_message):
        await union_qs.order_by("desc")


@pytest.mark.asyncio
async def test_union_with_annotate_raises(db):
    """Unioning annotated querysets raises ``ParamsError``."""
    await Tournament.create(name="T1")
    await Reporter.create(name="R1")

    selected = ("id", "name", "annotated_value")
    annotated_tournaments = (
        Tournament.filter(name="T1").annotate(annotated_value=Value(1)).only(*selected)
    )
    annotated_reporters = (
        Reporter.filter(name="R1").annotate(annotated_value=Value(1)).only(*selected)
    )
    expected_message = "Union queries do not support annotations"

    with pytest.raises(ParamsError, match=expected_message):
        rejected = annotated_tournaments.union(annotated_reporters)
        await rejected
22 changes: 22 additions & 0 deletions tortoise/backends/base/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,28 @@ async def execute_select(
await self._execute_prefetch_queries(instance_list)
return instance_list

async def execute_union(
    self, sql: str, app_field: str, model_field: str, models: set[type[Model]]
) -> list:
    """Run a union query and build one model instance per recognised row.

    Each row is expected to carry two discriminator values, read from
    ``row[app_field]`` and ``row[model_field]``, which are matched against the
    app label and class name of every entry in ``models`` to pick the class
    that instantiates the row.

    :param sql: SQL text passed to ``self.db.execute_query``.
    :param app_field: Row key holding the app label of the originating model.
    :param model_field: Row key holding the class name of the originating model.
    :param models: Candidate model classes for the returned rows.
    :return: Instances in row order. Rows whose discriminators match none of
        ``models`` are skipped.
    """
    _, raw_results = await self.db.execute_query(sql)

    # Resolve the (app, class name) -> model mapping once instead of scanning
    # ``models`` and recomputing the same keys for every row.
    models_by_key = {
        (model._meta.app, model._meta._model.__name__): model for model in models
    }

    instance_list = []
    for row_idx, row in enumerate(raw_results):
        if row_idx != 0 and row_idx % CHUNK_SIZE == 0:
            # Forcibly yield to the event loop to avoid blocking the event loop
            # when selecting a large number of rows
            await asyncio.sleep(0)

        model = models_by_key.get((row[app_field], row[model_field]))
        if model is not None:
            instance_list.append(model._init_from_db(**row))

    return instance_list

def _prepare_insert_columns(
self, include_generated: bool = False
) -> tuple[list[str], list[str]]:
Expand Down
Loading
Loading