diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 68d944bea..824928207 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,6 +13,7 @@ Changelog Added ^^^^^ +- ``QuerySet.union()`` — SQL UNION query support for combining results from multiple QuerySets, including support for union across different models, ``union(all=True)`` for duplicates, ``order_by()``, ``limit()``, and ``count()``. - Tests for model validators. (#2137) Fixed diff --git a/docs/query.rst b/docs/query.rst index 75b9dee6f..7c6350bc7 100644 --- a/docs/query.rst +++ b/docs/query.rst @@ -349,3 +349,23 @@ You can view full example here: :ref:`example_prefetching` .. autoclass:: tortoise.query_utils.Prefetch :members: + +.. _union: + +Union +===== + +Tortoise ORM supports SQL ``UNION`` queries to combine results from multiple QuerySets. + +Example usage: + +.. code-block:: python3 + + qs1 = Tournament.filter(name__in=["T1", "T2"]).only("id", "name") + qs2 = Reporter.filter(name__in=["R1", "R2"]).only("id", "name") + + result = await qs1.union(qs2) + +.. 
@pytest.mark.asyncio
async def test_union_basic(db):
    """Union of two filters on the same model returns every matching row once."""
    t1 = await Tournament.create(name="T1")
    t2 = await Tournament.create(name="T2")
    t3 = await Tournament.create(name="T3")
    await Tournament.create(name="T4")

    qs1 = Tournament.filter(name__in=["T1", "T2"])
    qs2 = Tournament.filter(name="T3")

    result = await qs1.union(qs2)
    assert set(result) == {t1, t2, t3}


@pytest.mark.asyncio
async def test_union_all(db):
    """``all=True`` keeps the duplicate row produced by two identical filters."""
    t1 = await Tournament.create(name="T1")
    await Tournament.create(name="T2")

    qs1 = Tournament.filter(name="T1")
    qs2 = Tournament.filter(name="T1")

    result = await qs1.union(qs2, all=True)
    assert list(result) == [t1, t1]


@pytest.mark.asyncio
async def test_union_mixed_models(db):
    """Rows from different models come back as instances of their own model."""
    r1 = await Reporter.create(name="R1")
    r2 = await Reporter.create(name="R2")
    await Reporter.create(name="R3")
    t1 = await Tournament.create(name="T1")
    await Tournament.create(name="T2")

    qs1 = Tournament.filter(name="T1").only("id", "name")
    qs2 = Reporter.filter(name__in=["R1", "R2"]).only("id", "name")

    result = await qs1.union(qs2)
    assert set(result) == {t1, r1, r2}


@pytest.mark.parametrize(
    "orderings,expected_instances",
    [
        ("name", ["t2", "t1", "r1"]),
        ("-name", ["r1", "t1", "t2"]),
    ],
)
@pytest.mark.asyncio
async def test_union_order_by(db, orderings, expected_instances):
    """Ordering is applied to the combined result, ascending and descending."""
    t1 = await Tournament.create(name="C")
    await Reporter.create(name="A")
    t2 = await Tournament.create(name="B")
    await Reporter.create(name="D")
    await Tournament.create(name="E")
    r1 = await Reporter.create(name="F")

    qs1 = Tournament.filter(id__in=[t1.id, t2.id]).only("id", "name")
    qs2 = Reporter.filter(id=r1.id).only("id", "name")

    result = await qs1.union(qs2).order_by(*orderings.split(","))

    instance_map = {"t1": t1, "t2": t2, "r1": r1}
    expected = [instance_map[k] for k in expected_instances]

    assert result == expected


@pytest.mark.asyncio
async def test_union_order_by_multiple_fields(db):
    """Ordering by several fields uses the later fields as tie-breakers.

    ``t1`` and ``r1`` share the name "C", so their relative order is decided
    by ``id``. Both rows live in different tables and may well have the same
    id; in that case their relative order is unspecified, so the assertions
    below are written to hold (and to actually run) in every case instead of
    silently returning early.
    """
    t1 = await Tournament.create(name="C")
    t2 = await Tournament.create(name="B")
    r1 = await Reporter.create(name="C")
    await Tournament.create(name="Z")
    await Reporter.create(name="Z")

    qs1 = Tournament.filter(id__in=[t1.id, t2.id]).only("id", "name")
    qs2 = Reporter.filter(id=r1.id).only("id", "name")

    result = await qs1.union(qs2).order_by("name", "id")

    expected = sorted((t1, t2, r1), key=lambda obj: (obj.name, obj.id))

    # The (name, id) sequence is fully determined even when ids collide.
    assert [(obj.name, obj.id) for obj in result] == [(obj.name, obj.id) for obj in expected]
    # "B" always sorts first; the two "C" rows follow in some order.
    assert result[0] == t2
    assert set(result[1:]) == {t1, r1}

    if r1.id != t1.id:
        # Distinct ids make the full ordering deterministic.
        assert result == expected


@requireCapability(dialect=NotEQ("mssql"))
@pytest.mark.asyncio
async def test_union_limit(db):
    """``limit`` truncates the ordered, combined result."""
    r1 = await Reporter.create(name="B")
    t1 = await Tournament.create(name="A")
    await Reporter.create(name="D")
    await Tournament.create(name="C")

    qs1 = Tournament.all().only("id", "name")
    qs2 = Reporter.all().only("id", "name")

    result = await qs1.union(qs2).order_by("name").limit(2)
    assert list(result) == [t1, r1]


@requireCapability(dialect=NotEQ("mssql"))
@pytest.mark.asyncio
async def test_union_offset(db):
    """``offset`` skips leading rows of the ordered, combined result."""
    await Tournament.create(name="T1")
    await Tournament.create(name="T2")
    t3 = await Tournament.create(name="T3")
    t4 = await Tournament.create(name="T4")

    qs1 = Tournament.filter(name__in=["T1", "T2"]).only("id", "name")
    qs2 = Tournament.filter(name__in=["T3", "T4"]).only("id", "name")

    result = await qs1.union(qs2).order_by("name").limit(4).offset(2)
    assert list(result) == [t3, t4]


@pytest.mark.asyncio
async def test_union_offset_negative_raises(db):
    """A negative offset is rejected before any query is executed."""
    qs1 = Tournament.all().only("id", "name")
    qs2 = Tournament.all().only("id", "name")

    with pytest.raises(ParamsError, match="Offset should be non-negative number"):
        await qs1.union(qs2).offset(-1)


@pytest.mark.asyncio
async def test_union_chained(db):
    """``union`` can be chained to combine more than two querysets."""
    t1 = await Tournament.create(name="T1")
    t2 = await Tournament.create(name="T2")
    await Tournament.create(name="T3")
    r1 = await Reporter.create(name="R1")
    await Reporter.create(name="R2")

    qs1 = Tournament.filter(name="T1").only("id", "name")
    qs2 = Tournament.filter(name="T2").only("id", "name")
    qs3 = Reporter.filter(name="R1").only("id", "name")

    result = await qs1.union(qs2).union(qs3)
    assert set(result) == {t1, t2, r1}


@pytest.mark.asyncio
async def test_union_count(db):
    """``count`` returns the number of rows in the combined result."""
    await Tournament.create(name="T1")
    await Reporter.create(name="R1")
    await Tournament.create(name="T2")
    await Reporter.create(name="R2")

    qs1 = Tournament.filter(name="T1").only("id")
    qs2 = Reporter.filter(name="R1").only("id")

    assert await qs1.union(qs2).count() == 2


@pytest.mark.asyncio
async def test_union_different_select_fields_raises(db):
    """Querysets selecting different columns cannot be combined."""
    await Tournament.create(name="T1")

    qs1 = Tournament.filter(name="T1").only("name")
    qs2 = Tournament.filter(name="T1").only("desc")

    with pytest.raises(ParamsError, match="Union queries must have the same select fields"):
        await qs1.union(qs2)


@pytest.mark.asyncio
async def test_union_different_fields__in_different_models_raises(db):
    """Full-model selects of models with different columns cannot be combined."""
    await Tournament.create(name="T1")
    await Reporter.create(name="R1")

    qs1 = Tournament.all()
    qs2 = Reporter.all()

    with pytest.raises(ParamsError, match="Union queries must have the same select fields"):
        await qs1.union(qs2)
async def execute_union(
    self, sql: str, app_field: str, model_field: str, models: set[type[Model]]
) -> list:
    """
    Execute a UNION statement and build model instances from its rows.

    Every row is expected to carry two marker columns naming the model it
    belongs to; the matching class from ``models`` is used to build the
    instance. Rows whose markers match none of ``models`` are skipped.

    :param sql: The SQL statement to execute.
    :param app_field: Name of the row column holding the model's app label.
    :param model_field: Name of the row column holding the model's class name.
    :param models: Candidate model classes that rows may belong to.
    :return: Instances in the order the rows were returned by the database.
    """
    _, raw_results = await self.db.execute_query(sql)

    # Resolve the (app, class name) -> model mapping once instead of scanning
    # every candidate model for every row. ``setdefault`` keeps the first model
    # seen for a key, mirroring the previous first-match behaviour.
    models_by_key: dict = {}
    for model in models:
        models_by_key.setdefault((model._meta.app, model._meta._model.__name__), model)

    instance_list = []
    for row_idx, row in enumerate(raw_results):
        if row_idx != 0 and row_idx % CHUNK_SIZE == 0:
            # Forcibly yield to the event loop to avoid blocking the event loop
            # when selecting a large number of rows
            await asyncio.sleep(0)

        model = models_by_key.get((row[app_field], row[model_field]))
        if model is not None:
            instance_list.append(model._init_from_db(**row))

    return instance_list
def union(self, *other_qs: QuerySet[Model], all: bool = False) -> UnionQuery[MODEL]:
    """
    Return the union of this QuerySet with one or more other QuerySets.

    The returned object is built from this QuerySet's model and connection,
    with this QuerySet first and ``other_qs`` following in the given order.

    :param other_qs: QuerySet(s) to union with.
    :param all: Passed through to the UnionQuery as its ``all`` flag; defaults to ``False``.
    :return: A new UnionQuery representing the union of all given QuerySets.
    """
    return UnionQuery(self.model, self._db, self, *other_qs, all=all)  # type: ignore[arg-type]
class UnionCountQuery(AwaitableQuery):
    """Awaitable that counts the rows produced by a :class:`UnionQuery`."""

    __slots__ = ("_union_query", "_db")

    def __init__(
        self,
        model: type[MODEL],
        db: BaseDBAsyncClient,
        union_query: UnionQuery[MODEL],
    ) -> None:
        super().__init__(model)
        self._union_query = union_query
        self._db = db

    def _make_query(self) -> None:
        # Build the inner union first, then select COUNT(*) from it.
        self._union_query._make_query()
        self.query = self.query.QUERY_CLS.from_(self._union_query._union_query).select(  # type:ignore[arg-type]
            Count(Star())
        )

    def __await__(self) -> Generator[Any, None, int]:
        self._choose_db_if_not_chosen()
        self._make_query()
        return self._execute().__await__()

    async def _execute(self) -> int:
        _, result = await self._db.execute_query(self.query.get_sql())
        if not result:
            return 0
        # The single result row holds one value: the count.
        return list(dict(result[0]).values())[0]


class UnionQuery(AwaitableQuery[MODEL]):
    """
    Awaitable combining several QuerySets with SQL ``UNION`` / ``UNION ALL``.

    Each queryset is annotated with two constant marker columns (the model's
    app and class name) so that every returned row can be turned back into an
    instance of the model it came from. All querysets must select the same
    field names, and annotated querysets are rejected.
    """

    __slots__ = (
        "model",
        "_models",
        "_union_query",
        "_selects",
        "_db",
        "_qs",
        "_all",
        "_orderings",
        "_limit",
        "_offset",
    )

    # Aliases of the marker columns added to every member query.
    TORTOISE_APP_FIELD = "tortoise_app"
    TORTOISE_MODEL_FIELD = "tortoise_model"

    def __init__(
        self,
        model: type[MODEL],
        db: BaseDBAsyncClient,
        *querysets: QuerySet[Model],
        all: bool = False,
    ):
        super().__init__(model)
        self._models: set[type[Model]] = {model, *(qs.model for qs in querysets)}
        self._union_query: QueryBuilder | _SetOperation | None = None
        self._selects: list[str] = []
        self._db = db
        self._qs = querysets
        self._all = all
        self._orderings: list[tuple[str, Order]] | None = None
        self._limit: int | None = None
        self._offset: int | None = None

    @classmethod
    def _get_selects(cls, qs: QuerySet[Model] | UnionQuery[Model]) -> list[str]:
        """Return the selected field names of ``qs``, ignoring the marker columns."""
        marker_aliases = (cls.TORTOISE_APP_FIELD, cls.TORTOISE_MODEL_FIELD)
        return [
            select.name for select in qs.query._selects if select.alias not in marker_aliases
        ]

    def _make_query(self) -> None:
        """
        Build the combined query from scratch.

        :raises ParamsError: If a queryset has annotations, the querysets select
            different fields, or an ordering field is not in the select list.
        """
        # Always start from a clean state: without this reset, awaiting the same
        # UnionQuery a second time would union every queryset onto the previously
        # built query again and re-apply ordering/limit/offset.
        self._union_query = None
        self._selects = []

        for qs in self._qs:
            if qs._annotations:
                raise ParamsError("Union queries do not support annotations")
            model_annotations = {
                self.TORTOISE_APP_FIELD: Value(qs.model._meta.app),
                self.TORTOISE_MODEL_FIELD: Value(qs.model._meta._model.__name__),
            }
            # ``annotate`` returns a new queryset, so the caller's one is untouched.
            qs = qs.annotate(**model_annotations)
            qs._make_query()
            qs.query.wrap_set_operation_queries = False
            if self._union_query is None:
                # The first queryset defines the select list all others must match.
                self._union_query = qs.query
                self._selects = self._get_selects(qs)
            else:
                if self._get_selects(qs) != self._selects:
                    raise ParamsError("Union queries must have the same select fields")
                self._union_query = (
                    self._union_query.union_all(qs.query)
                    if self._all
                    else self._union_query.union(qs.query)
                )

        if self._union_query is None:
            return

        if self._orderings:
            for field_name, order in self._orderings:
                if field_name not in self._selects:
                    raise ParamsError("Order by field must be in the select list for union queries")

                self._union_query = self._union_query.orderby(field_name, order=order)

        if self._limit is not None:
            self._union_query = self._union_query.limit(self._limit)

        if self._offset is not None:
            self._union_query = self._union_query.offset(self._offset)

    def __await__(self) -> Generator[Any, None, Sequence[MODEL]]:
        self._choose_db_if_not_chosen()
        self._make_query()
        return self._execute().__await__()

    async def __aiter__(self: UnionQuery[Any]) -> AsyncIterator[Any]:
        for val in await self:
            yield val

    async def _execute(self) -> Sequence[MODEL]:
        if self._union_query is None:
            return []

        sql = self._union_query.get_sql(self._qs[0].query.QUERY_CLS.SQL_CONTEXT)
        instance_list = await self._db.executor_class(
            model=self.model,
            db=self._db,
        ).execute_union(sql, self.TORTOISE_APP_FIELD, self.TORTOISE_MODEL_FIELD, self._models)
        return instance_list

    def _clone(self) -> UnionQuery[MODEL]:
        """Return a copy sharing this query's settings but with no built query."""
        union = self.__class__.__new__(self.__class__)
        union.model = self.model
        union._models = self._models
        union._union_query = None
        union._selects = self._selects
        union._db = self._db
        union._qs = self._qs
        union._all = self._all
        union._orderings = self._orderings
        union._limit = self._limit
        union._offset = self._offset
        return union

    @classmethod
    def _parse_orderings(cls, orderings: tuple[str, ...]) -> list[tuple[str, Order]]:
        """Convert ordering strings into ``(field name, order)`` pairs."""
        return [QuerySet._resolve_ordering_string(ordering) for ordering in orderings]

    def union(self, *other_qs: QuerySet[Model], all: bool = False) -> UnionQuery[MODEL]:
        """
        Return the union of QuerySets.

        :param other_qs: Another QuerySet(s) to union with.
        :param all: Keep duplicate rows. The flag is shared by the whole chain:
            once set by any ``union`` call it applies to every member query.
        :return: A new UnionQuery representing the union of all QuerySets.
        """
        union = self._clone()
        union._models = {*union._models, *(qs.model for qs in other_qs)}
        union._qs = union._qs + other_qs
        union._all = union._all or all
        return union

    def order_by(self, *orderings: str) -> UnionQuery[MODEL]:
        """
        Accept args to order by in format like this:

        .. code-block:: python3

            .order_by('name', '-id')

        A '-' before the name will result in descending sort order, default is ascending.
        Calling ``order_by`` again replaces any previous ordering.

        :raises ParamsError: When the query is built, if a field is not in the select list.
        """
        union = self._clone()
        union._orderings = self._parse_orderings(orderings)
        return union

    def limit(self, limit: int) -> UnionQuery[MODEL]:
        """
        Limits UnionQuery to given length.

        :raises ParamsError: Limit should be non-negative number.
        """
        if limit < 0:
            raise ParamsError("Limit should be non-negative number")

        union = self._clone()
        union._limit = limit
        return union

    def offset(self, offset: int) -> UnionQuery[MODEL]:
        """
        Query offset for UnionQuery.

        :raises ParamsError: Offset should be non-negative number.
        """
        if offset < 0:
            raise ParamsError("Offset should be non-negative number")

        union = self._clone()
        union._offset = offset
        return union

    def count(self) -> UnionCountQuery:
        """
        Return count of objects in union query.

        The count is taken over a clone of this query, so any limit and offset
        set here are included in the counted query.
        """
        self._choose_db_if_not_chosen()
        union_query_clone = self._clone()

        return UnionCountQuery(
            model=self.model,
            db=self._db,
            union_query=union_query_clone,
        )