From 27b6a17fc6283a17dd29ddac0878ce1a65b747f9 Mon Sep 17 00:00:00 2001 From: robot Date: Sat, 30 Nov 2024 13:26:42 +0300 Subject: [PATCH 1/3] tortoise support --- fastapi_filter/contrib/tortoise/__init__.py | 0 fastapi_filter/contrib/tortoise/filter.py | 110 ++++++ pyproject.toml | 4 +- tests/test_tortoise/__init__.py | 0 tests/test_tortoise/conftest.py | 382 ++++++++++++++++++++ tests/test_tortoise/test_filter.py | 111 ++++++ tests/test_tortoise/test_order_by.py | 269 ++++++++++++++ 7 files changed, 875 insertions(+), 1 deletion(-) create mode 100644 fastapi_filter/contrib/tortoise/__init__.py create mode 100644 fastapi_filter/contrib/tortoise/filter.py create mode 100644 tests/test_tortoise/__init__.py create mode 100644 tests/test_tortoise/conftest.py create mode 100644 tests/test_tortoise/test_filter.py create mode 100644 tests/test_tortoise/test_order_by.py diff --git a/fastapi_filter/contrib/tortoise/__init__.py b/fastapi_filter/contrib/tortoise/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastapi_filter/contrib/tortoise/filter.py b/fastapi_filter/contrib/tortoise/filter.py new file mode 100644 index 00000000..0f281270 --- /dev/null +++ b/fastapi_filter/contrib/tortoise/filter.py @@ -0,0 +1,110 @@ +from functools import reduce +from typing import Union +from warnings import warn +from operator import or_ + +from pydantic import ValidationInfo, field_validator +from tortoise.queryset import QuerySet, Q + + +from ...base.filter import BaseFilterModel + + + + +_orm_operator_transformer = { + "neq": lambda value: ("__not", value), + "gt": lambda value: ("__gt", value), + "gte": lambda value: ("__gte", value), + "in": lambda value: ("__in", value), + "isnull": lambda value: ("__isnull", True), + "lt": lambda value: ("__lt", value), + "lte": lambda value: ("__lte", value), + "like": lambda value: ("__contains", value), + "ilike": lambda value: ("__icontains", value), + "not": lambda value: ("__not", value), + "not_in": lambda value: ("__not_in", value), +} +"""Operators à la Django. + +Examples: + my_datetime__gte + count__lt + name__isnull + user_id__in +""" + + +class Filter(BaseFilterModel): + """Base filter for orm related filters. + + All children must set: + ```python + class Constants(Filter.Constants): + model = MyModel + ``` + + It can handle regular field names and Django style operators. + + Example: + ```python + class MyModel: + id: PrimaryKey() + name: StringField(nullable=True) + count: IntegerField() + created_at: DatetimeField() + + class MyModelFilter(Filter): + id: Optional[int] + id__in: Optional[str] + count: Optional[int] + count__lte: Optional[int] + created_at__gt: Optional[datetime] + name__isnull: Optional[bool] + """ + + @field_validator("*", mode="before") + def split_str(cls, value, field: ValidationInfo): + if ( + field.field_name is not None + and ( + field.field_name == cls.Constants.ordering_field_name + or field.field_name.endswith("__in") + or field.field_name.endswith("__not_in") + ) + and isinstance(value, str) + ): + if not value: + # Empty string should return [] not [''] + return [] + return list(value.split(",")) + return value + + def filter(self, query: QuerySet): + for field_name, value in self.filtering_fields: + field_value = getattr(self, field_name) + if isinstance(field_value, Filter): + query = field_value.filter(query) + else: + if "__" in field_name: + field_name, operator = field_name.split("__") + operator, value = _orm_operator_transformer[operator](value) + else: + operator = "" + + if field_name == self.Constants.search_field_name and hasattr(self.Constants, "search_model_fields"): + search_filters = [ + {f'{field}__icontains': value} + for field in self.Constants.search_model_fields + ] + query = query.filter(reduce(or_, [Q(**filt) for filt in search_filters])) + else: + query = query.filter(**{f'{field_name}{operator}': value}) + + return query + + def sort(self, query: QuerySet): + if not self.ordering_values: + return query + + return query.order_by(*self.ordering_values) diff --git a/pyproject.toml b/pyproject.toml index e1cd9d91..172878d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,6 +112,7 @@ classifiers = [ SQLAlchemy = {version = ">=1.4.36,<2.1.0", optional = true} fastapi = ">=0.100.0,<1.0" mongoengine = {version = ">=0.24.1,<0.28.0", optional = true} +tortoise-orm = {version = ">=0.22.1", optional = true} pydantic = ">=2.0.0,<3.0.0" python = ">=3.9,<4.0" @@ -145,7 +146,8 @@ pydantic = {extras = ["email"], version="^2.7.1"} [tool.poetry.extras] mongoengine = ["mongoengine"] sqlalchemy = ["SQLAlchemy"] -all = ["mongoengine", "SQLAlchemy"] +tortoise-orm = ["tortoise-orm"] +all = ["mongoengine", "SQLAlchemy", "tortoise-orm"] [tool.poetry.group.dev.dependencies] nox = "^2024.0.0" diff --git a/tests/test_tortoise/__init__.py b/tests/test_tortoise/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_tortoise/conftest.py b/tests/test_tortoise/conftest.py new file mode 100644 index 00000000..3c74b6e0 --- /dev/null +++ b/tests/test_tortoise/conftest.py @@ -0,0 +1,382 @@ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from datetime import datetime +from typing import Optional, AsyncGenerator + +import pytest +import pytest_asyncio +from fastapi import Depends, FastAPI, Query +from pydantic import BaseModel, ConfigDict, Field, field_validator +from tortoise import Model, fields, Tortoise, generate_config +from tortoise.contrib.fastapi import RegisterTortoise + +from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter.contrib.sqlalchemy import Filter as SQLAlchemyFilter + + +@pytest.fixture(scope="session") +def sqlite_file_path(tmp_path_factory): + file_path = tmp_path_factory.mktemp("data") / "fastapi_filter_test.sqlite" + yield file_path + + +@pytest.fixture(scope="session") +def database_url(sqlite_file_path) -> str: + return f"sqlite+aiosqlite:///{sqlite_file_path}" + + +@pytest.fixture(scope="session") +async def init(database_url): + # Here we create a SQLite DB using file "db.sqlite3" + # also specify the app name of "models" + # which contain models from "app.models" + await Tortoise.init( + db_url=database_url, + modules={'models': ['app.models']} + ) + # Generate the schema + await Tortoise.generate_schemas() + + + + +@pytest.fixture(scope="session") +def User(Address, FavoriteSport, Sport): + class User(Model): # type: ignore[misc, valid-type] + id = fields.IntField(primary_key=True) + created_at = fields.DatetimeField(null=True, auto_now_add=True) + updated_at = fields.DatetimeField(null=True, auto_now=True) + name = fields.CharField(max_length=255) + age = fields.IntField(null=False) + address_id = fields.ForeignKeyField('models.Address', related_name="users") + + return User + + +@pytest.fixture(scope="session") +def Address(): + class Address(Model): # type: ignore[misc, valid-type] + id = fields.IntField(primary_key=True) + street = fields.CharField(128, null=False) + city = fields.CharField(128, null=False) + country = fields.CharField(128, null=False) + + return Address + + +@pytest.fixture(scope="session") +def Sport(): + class Sport(Model): # type: ignore[misc, valid-type] + + id = fields.IntField(primary_key=True) + name = fields.CharField(128, null=False) + is_individual = fields.BooleanField(null=False) + + return Sport + + +@pytest.fixture(scope="session") +def FavoriteSport(): + class FavoriteSport(Model): # type: ignore[misc, valid-type] + user_id = fields.ForeignKeyField('models.User') + sport_id = fields.ForeignKeyField('models.Sport') + + return FavoriteSport + + +@pytest_asyncio.fixture(scope="function") +async def users(sports, User, Address): + user_instances = [ + User( + name=None, + age=21, + created_at=datetime(2021, 12, 1), + ), + User( + name="Mr Praline", + age=33, + created_at=datetime(2021, 12, 1), + address=Address(street="22 rue Bellier", city="Nantes", country="France"), + ), + User( + name="The colonel", + age=90, + created_at=datetime(2021, 12, 2), + address=Address(street="Wrench", city="Bathroom", country="Clue"), + ), + User( + name="Mr Creosote", + age=21, + created_at=datetime(2021, 12, 3), + address=Address(city="Nantes", country="France"), + ), + User( + name="Rabbit of Caerbannog", + age=1, + created_at=datetime(2021, 12, 4), + address=Address(street="1234 street", city="San Francisco", country="United States"), + ), + User( + name="Gumbys", + age=50, + created_at=datetime(2021, 12, 4), + address=Address(street="4567 avenue", city="Denver", country="United States"), + ), + ] + await User.bulk_create(user_instances) + yield user_instances + + +@pytest_asyncio.fixture(scope="function") +async def sports(Sport): + sport_instances = [ + Sport( + name="Ice Hockey", + is_individual=False, + ), + Sport( + name="Tennis", + is_individual=True, + ), + ] + await Sport.bulk_create(sport_instances) + yield sports + + +@pytest_asyncio.fixture(scope="function") +async def favorite_sports(sports, users, FavoriteSport): + favorite_sport_instances = [ + FavoriteSport( + user_id=users[0].id, + sport_id=sports[0].id, + ), + FavoriteSport( + user_id=users[0].id, + sport_id=sports[1].id, + ), + FavoriteSport( + user_id=users[1].id, + sport_id=sports[0].id, + ), + FavoriteSport( + user_id=users[2].id, + sport_id=sports[1].id, + ), + ] + await FavoriteSport.bulk_create(favorite_sport_instances) + yield favorite_sport_instances + + +@pytest.fixture(scope="package") +def AddressOut(): + class AddressOut(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: int + street: Optional[str] + city: str + country: str + + return AddressOut + + +@pytest.fixture(scope="package") +def UserOut(AddressOut, SportOut): + class UserOut(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: int + created_at: datetime + updated_at: datetime + name: Optional[str] + age: int + address: Optional[AddressOut] # type: ignore[valid-type] + favorite_sports: Optional[list[SportOut]] # type: ignore[valid-type] + + return UserOut + + +@pytest.fixture(scope="package") +def SportOut(): + class SportOut(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: int + name: str + is_individual: bool + + return SportOut + + +@pytest.fixture(scope="package") +def Filter(): + yield SQLAlchemyFilter + + +@pytest.fixture(scope="package") +def AddressFilter(Address, Filter): + class AddressFilter(Filter): # type: ignore[misc, valid-type] + street__isnull: Optional[bool] = None + city: Optional[str] = None + city__in: Optional[list[str]] = None + country__not_in: Optional[list[str]] = None + + class Constants(Filter.Constants): # type: ignore[name-defined] + model = Address + + yield AddressFilter + + +@pytest.fixture(scope="package") +def UserFilter(User, Filter, AddressFilter): + class UserFilter(Filter): # type: ignore[misc, valid-type] + name: Optional[str] = None + name__neq: Optional[str] = None + name__like: Optional[str] = None + name__ilike: Optional[str] = None + name__in: Optional[list[str]] = None + name__not: Optional[str] = None + name__not_in: Optional[list[str]] = None + name__isnull: Optional[bool] = None + age: Optional[int] = None + age__lt: Optional[int] = None + age__lte: Optional[int] = None + age__gt: Optional[int] = None + age__gte: Optional[int] = None + age__in: Optional[list[int]] = None + address: Optional[AddressFilter] = FilterDepends( # type: ignore[valid-type] + with_prefix("address", AddressFilter), by_alias=True + ) + address_id__isnull: Optional[bool] = None + search: Optional[str] = None + + class Constants(Filter.Constants): # type: ignore[name-defined] + model = User + search_model_fields = ["name"] + search_field_name = "search" + + yield UserFilter + + +@pytest.fixture(scope="package") +def UserFilterByAlias(UserFilter, AddressFilter): + class UserFilterByAlias(UserFilter): # type: ignore[misc, valid-type] + address: Optional[AddressFilter] = FilterDepends( # type: ignore[valid-type] + with_prefix("address", AddressFilter), by_alias=True + ) + + yield UserFilterByAlias + + +@pytest.fixture(scope="package") +def SportFilter(Sport, Filter): + class SportFilter(Filter): # type: ignore[misc, valid-type] + name: Optional[str] = Field(Query(description="Name of the sport", default=None)) + is_individual: bool + bogus_filter: Optional[str] = None + + class Constants(Filter.Constants): # type: ignore[name-defined] + model = Sport + + @field_validator("bogus_filter") + def throw_exception(cls, value): + if value: + raise ValueError("You can't use this bogus filter") + + yield SportFilter + + +@pytest.fixture(scope="package") +def app( + Address, + FavoriteSport, + Sport, + SportFilter, + SportOut, + User, + UserFilter, + UserFilterByAlias, + UserFilterCustomOrderBy, + UserFilterOrderBy, + UserFilterOrderByWithDefault, + UserFilterRestrictedOrderBy, + UserOut, +): + @asynccontextmanager + async def lifespan_test(app: FastAPI) -> AsyncGenerator[None, None]: + config = generate_config( + os.getenv("TORTOISE_TEST_DB", "sqlite://:memory:"), + app_modules={"models": ["models"]}, + testing=True, + connection_label="models", + ) + async with RegisterTortoise( + app=app, + config=config, + generate_schemas=True, + add_exception_handlers=True, + _create_db=True, + ): + # db connected + yield + # app teardown + # db connections closed + await Tortoise._drop_databases() + + + app = FastAPI(lifespan=lifespan_test) + + @app.get("/users", response_model=list[UserOut]) # type: ignore[valid-type] + async def get_users( + user_filter: UserFilter = FilterDepends(UserFilter), # type: ignore[valid-type] + ): + return await user_filter.filter(User.all().select_related('address')) # type: ignore[attr-defined] + + @app.get("/users-by-alias", response_model=list[UserOut]) # type: ignore[valid-type] + async def get_users_by_alias( + user_filter: UserFilter = FilterDepends(UserFilterByAlias, by_alias=True), # type: ignore[valid-type] + ): + return await user_filter.filter(User.all().select_related('address')) # type: ignore[attr-defined] + + @app.get("/users_with_order_by", response_model=list[UserOut]) # type: ignore[valid-type] + async def get_users_with_order_by( + user_filter: UserFilterOrderBy = FilterDepends(UserFilterOrderBy), # type: ignore[valid-type] + ): + query = user_filter.sort(User.all().select_related('address')) # type: ignore[attr-defined] + return await user_filter.filter(query) # type: ignore[attr-defined] + + @app.get("/users_with_no_order_by", response_model=list[UserOut]) # type: ignore[valid-type] + async def get_users_with_no_order_by( + user_filter: UserFilter = FilterDepends(UserFilter), # type: ignore[valid-type] + ): + return await get_users_with_order_by(user_filter) + + @app.get("/users_with_default_order_by", response_model=list[UserOut]) # type: ignore[valid-type] + async def get_users_with_default_order_by( + user_filter: UserFilterOrderByWithDefault = FilterDepends( # type: ignore[valid-type] + UserFilterOrderByWithDefault + ), + ): + return await get_users_with_order_by(user_filter) + + @app.get("/users_with_restricted_order_by", response_model=list[UserOut]) # type: ignore[valid-type] + async def get_users_with_restricted_order_by( + user_filter: UserFilterRestrictedOrderBy = FilterDepends( # type: ignore[valid-type] + UserFilterRestrictedOrderBy + ), + ): + return await get_users_with_order_by(user_filter) + + @app.get("/users_with_custom_order_by", response_model=list[UserOut]) # type: ignore[valid-type] + async def get_users_with_custom_order_by( + user_filter: UserFilterCustomOrderBy = FilterDepends(UserFilterCustomOrderBy), # type: ignore[valid-type] + ): + return await get_users_with_order_by(user_filter) + + @app.get("/sports", response_model=list[SportOut]) # type: ignore[valid-type] + async def get_sports( + sport_filter: SportFilter = FilterDepends(SportFilter), # type: ignore[valid-type] + ): + return await sport_filter.filter(Sport.all()) # type: ignore[attr-defined] + + yield app diff --git a/tests/test_tortoise/test_filter.py b/tests/test_tortoise/test_filter.py new file mode 100644 index 00000000..0e3c082d --- /dev/null +++ b/tests/test_tortoise/test_filter.py @@ -0,0 +1,111 @@ +from urllib.parse import urlencode + +import pytest +from fastapi import status +from sqlalchemy.future import select + + +@pytest.mark.parametrize( + "filter_,expected_count", + [ + [{"name": "Mr Praline"}, 1], + [{"name__neq": "Mr Praline"}, 4], + [{"name__in": ["Mr Praline", "Mr Creosote", "Gumbys", "Knight"]}, 3], + [{"name__in": "Mr Praline,Mr Creosote,Gumbys,Knight"}, 3], + [{"name__like": "%Mr%"}, 2], + [{"name__ilike": "%mr%"}, 2], + [{"name__like": "%colonel"}, 1], + [{"name__like": "Mr %"}, 2], + [{"name__isnull": True}, 1], + [{"name__isnull": False}, 5], + [{"name__not_in": ["Mr Praline", "Mr Creosote", "Gumbys", "Knight"]}, 2], + [{"name__not_in": "Mr Praline,Mr Creosote,Gumbys,Knight"}, 2], + [{"name__not": "Mr Praline"}, 5], + [{"name__not": "Mr Praline", "age__gte": 21, "age__lt": 50}, 2], + [{"age__in": [1]}, 1], + [{"age__in": [21, 33]}, 3], + [{"address": {"country__not_in": ["France"]}}, 3], + [{"age__in": "1"}, 1], + [{"age__in": "21,33"}, 3], + [{"address": {"country__not_in": "France"}}, 3], + [{"address": {"street__isnull": True}}, 2], + [{"address": {"city__in": ["Nantes", "Denver"]}}, 3], + [{"address": {"city__in": "Nantes,Denver"}}, 3], + [{"address": {"city": "San Francisco"}}, 1], + [{"address_id__isnull": True}, 1], + [{"search": "Mr"}, 2], + [{"search": "mr"}, 2], + ], +) +@pytest.mark.usefixtures("users") +@pytest.mark.asyncio +async def test_filter(Address, User, UserFilter, filter_, expected_count): + query = User.all().select_related('address') + query = await UserFilter(**filter_).filter(query) + assert len(query) == expected_count + + +@pytest.mark.parametrize( + "filter_,expected_count", + [ + [{"name__like": "Mr"}, 2], + [{"name__ilike": "mr"}, 2], + ], +) +@pytest.mark.usefixtures("users") +@pytest.mark.asyncio +async def test_filter_deprecation_like_and_ilike(session, Address, User, UserFilter, filter_, expected_count): + query = User.all().select_related('address') + with pytest.warns(DeprecationWarning, match="like and ilike operators."): + query = await UserFilter(**filter_).filter(query) + assert len(query) == expected_count + + +@pytest.mark.parametrize("uri", ["/users", "/users-by-alias"]) +@pytest.mark.parametrize( + "filter_,expected_count", + [ + [{"name": "Mr Praline"}, 1], + [{"name__in": "Mr Praline,Mr Creosote,Gumbys,Knight"}, 3], + [{"name__isnull": True}, 1], + [{"name__isnull": False}, 5], + [{"name__not_in": "Mr Praline,Mr Creosote,Gumbys,Knight"}, 2], + [{"name__not": "Mr Praline"}, 5], + [{"name__not": "Mr Praline", "age__gte": 21, "age__lt": 50}, 2], + [{"age__in": [1]}, 1], + [{"age__in": "1"}, 1], + [{"age__in": "21,33"}, 3], + [{"address__country__not_in": "France"}, 3], + [{"address__street__isnull": True}, 2], + [{"address__city__in": "Nantes,Denver"}, 3], + [{"address__city": "San Francisco"}, 1], + [{"address_id__isnull": True}, 1], + ], +) +@pytest.mark.usefixtures("users") +@pytest.mark.asyncio +async def test_api(test_client, uri, filter_, expected_count): + response = await test_client.get(f"{uri}?{urlencode(filter_)}") + assert len(response.json()) == expected_count + + +@pytest.mark.parametrize( + "filter_,expected_status_code", + [ + [{"is_individual": True}, status.HTTP_200_OK], + [{"is_individual": False}, status.HTTP_200_OK], + [{}, status.HTTP_422_UNPROCESSABLE_ENTITY], + [{"is_individual": None}, status.HTTP_422_UNPROCESSABLE_ENTITY], + [{"is_individual": True, "bogus_filter": "bad"}, status.HTTP_422_UNPROCESSABLE_ENTITY], + ], +) +@pytest.mark.usefixtures("sports") +@pytest.mark.asyncio +async def test_required_filter(test_client, filter_, expected_status_code): + response = await test_client.get(f"/sports?{urlencode(filter_)}") + assert response.status_code == expected_status_code + + if response.is_error: + error_json = response.json() + assert "detail" in error_json + assert isinstance(error_json["detail"], list) diff --git a/tests/test_tortoise/test_order_by.py b/tests/test_tortoise/test_order_by.py new file mode 100644 index 00000000..b2efea57 --- /dev/null +++ b/tests/test_tortoise/test_order_by.py @@ -0,0 +1,269 @@ +import pytest +from fastapi import status +from pydantic import ValidationError + + +@pytest.mark.parametrize( + "order_by,assert_function", + [ + [None, lambda previous_user, user: True], + [[], lambda previous_user, user: True], + [ + ["name"], + lambda previous_user, user: previous_user.name <= user.name if previous_user.name and user.name else True, + ], + [ + ["-created_at"], + lambda previous_user, user: previous_user.created_at >= user.created_at, + ], + [ + ["age", "-created_at"], + lambda previous_user, user: (previous_user.age < user.age) + or (previous_user.age == user.age and previous_user.created_at >= user.created_at), + ], + ], +) +@pytest.mark.asyncio +async def test_order_by(User, UserFilterOrderBy, users, order_by, assert_function): + query = User.all() + query = UserFilterOrderBy(order_by=order_by).sort(query) + previous_user = None + for user in query: + if not previous_user: + previous_user = user + continue + assert assert_function(previous_user, user) + previous_user = user + + +@pytest.mark.asyncio +async def test_order_by_with_default(User, UserFilterOrderByWithDefault, users): + query = User.all() + query = await UserFilterOrderByWithDefault().sort(query) + previous_user = None + for user in query: + if not previous_user: + previous_user = user + continue + assert previous_user.age <= user.age + previous_user = user + + +@pytest.mark.parametrize( + "order_by,assert_function", + [ + [None, lambda previous_user, user: previous_user["age"] <= user["age"]], + ["", lambda previous_user, user: True], + [ + "name", + lambda previous_user, user: previous_user["name"] <= user["name"] + if previous_user["name"] and user["name"] + else True, + ], + [ + "-created_at", + lambda previous_user, user: previous_user["created_at"] >= user["created_at"], + ], + [ + "age,-name", + lambda previous_user, user: (previous_user["age"] < user["age"]) + or ( + previous_user["age"] == user["age"] + and (previous_user["name"] <= user["name"] if previous_user["name"] and user["name"] else True) + ), + ], + ], +) +@pytest.mark.asyncio +async def test_api_order_by_with_default(test_client, users, order_by, assert_function): + endpoint = "/users_with_default" + if order_by is not None: + endpoint = f"{endpoint}?order_by={order_by}" + response = await test_client.get(endpoint) + previous_user = None + for user in response.json(): + if not previous_user: + previous_user = user + continue + assert assert_function(previous_user, user) + previous_user = user + + +def test_invalid_order_by(UserFilterOrderBy): + with pytest.raises(ValidationError): + UserFilterOrderBy(order_by="invalid") + + +def test_missing_order_by_field(User, UserFilterNoOrderBy): + query = User.all() + with pytest.raises(AttributeError): + UserFilterNoOrderBy().sort(query) + + +@pytest.mark.parametrize( + "order_by,assert_function", + [ + [None, lambda previous_user, user: True], + ["", lambda previous_user, user: True], + [ + "name", + lambda previous_user, user: previous_user.name <= user.name if previous_user.name and user.name else True, + ], + [ + "-created_at", + lambda previous_user, user: previous_user.created_at >= user.created_at, + ], + [ + "age,-name", + lambda previous_user, user: (previous_user.age < user.age) + or ( + previous_user.age == user.age + and (previous_user.name <= user.name if previous_user.name and user.name else True) + ), + ], + ], +) +@pytest.mark.asyncio +async def test_custom_order_by(User, users, UserFilterCustomOrderBy, order_by, assert_function): + query = User.all() + query = await UserFilterCustomOrderBy(custom_order_by=order_by).sort(query) + previous_user = None + for user in query: + if not previous_user: + previous_user = user + continue + assert assert_function(previous_user, user) + previous_user = user + + +@pytest.mark.parametrize( + "order_by", + [ + ["age", "name"], + ["name"], + ["created_at", "name"], + ], +) +def test_restricted_order_by_failure(User, UserFilterRestrictedOrderBy, order_by): + with pytest.raises(ValidationError): + UserFilterRestrictedOrderBy(order_by=order_by) + + +@pytest.mark.parametrize( + "order_by", + [ + None, + [], + ["-created_at"], + ["created_at", "+age"], + ], +) +def test_restricted_order_by_success(User, UserFilterRestrictedOrderBy, order_by): + assert UserFilterRestrictedOrderBy(order_by=order_by) + + +@pytest.mark.parametrize( + "order_by,assert_function", + [ + [None, lambda previous_user, user: True], + ["", lambda previous_user, user: True], + [ + "name", + lambda previous_user, user: previous_user["name"] <= user["name"] + if previous_user["name"] and user["name"] + else True, + ], + [ + "-created_at", + lambda previous_user, user: previous_user["created_at"] >= user["created_at"], + ], + [ + "age,-created_at", + lambda previous_user, user: (previous_user["age"] < user["age"]) + or (previous_user["age"] == user["age"] and previous_user["created_at"] >= user["created_at"]), + ], + ], +) +@pytest.mark.asyncio +async def test_api_order_by(test_client, users, order_by, assert_function): + endpoint = "/users_with_order_by" + if order_by is not None: + endpoint = f"{endpoint}?order_by={order_by}" + response = await test_client.get(endpoint) + previous_user = None + for user in response.json(): + if not previous_user: + previous_user = user + continue + assert assert_function(previous_user, user) + previous_user = user + + +@pytest.mark.asyncio +async def test_api_order_by_invalid_field(test_client, session): + endpoint = "/users_with_order_by?order_by=invalid" + response = await test_client.get(endpoint) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY + + +@pytest.mark.asyncio +async def test_api_no_order_by(test_client, session): + endpoint = "/users_with_no_order_by?order_by=age" + with pytest.raises( + AttributeError, match="Ordering field order_by is not defined. Make sure to add it to your filter class." + ): + await test_client.get(endpoint) + + +@pytest.mark.parametrize( + "order_by,assert_function,status_code", + [ + [None, lambda previous_user, user: True, status.HTTP_200_OK], + ["", lambda previous_user, user: True, status.HTTP_200_OK], + ["name", None, status.HTTP_422_UNPROCESSABLE_ENTITY], + ["age,-name", None, status.HTTP_422_UNPROCESSABLE_ENTITY], + ["-age", lambda previous_user, user: previous_user["age"] >= user["age"], status.HTTP_200_OK], + [ + "age,-created_at", + lambda previous_user, user: (previous_user["age"] < user["age"]) + or (previous_user["age"] == user["age"] and previous_user["created_at"] >= user["created_at"]), + status.HTTP_200_OK, + ], + ], +) +@pytest.mark.asyncio +async def test_api_restricted_order_by(test_client, users, order_by, assert_function, status_code): + endpoint = "/users_with_restricted_order_by" + if order_by is not None: + endpoint = f"{endpoint}?order_by={order_by}" + response = await test_client.get(endpoint) + assert response.status_code == status_code + if status_code == status.HTTP_200_OK: + previous_user = None + for user in response.json(): + if not previous_user: + previous_user = user + continue + assert assert_function(previous_user, user) + previous_user = user + + +@pytest.mark.asyncio +async def test_api_custom_order_by(test_client, session): + endpoint = "/users_with_custom_order_by?custom_order_by=age" + response = await test_client.get(endpoint) + assert response.status_code == status.HTTP_200_OK + + +@pytest.mark.parametrize( + "order_by, ambiguous_field_names", + [ + (["age", "age"], "age, age."), + (["-age", "age"], "-age, age."), + (["name", "-age", "-name", "name"], "name, -name, name"), + (["name", "-age", "name", "age"], "-age, age, name, name"), + ], +) +def test_order_by_with_duplicates_fail(UserFilterOrderBy, order_by, ambiguous_field_names): + with pytest.raises(ValidationError, match=f"The following was ambiguous: {ambiguous_field_names}."): + UserFilterOrderBy(order_by=order_by) From 8bc219045818f93836838c80d0f739b80dfff1ce Mon Sep 17 00:00:00 2001 From: robot Date: Sat, 30 Nov 2024 13:28:07 +0300 Subject: [PATCH 2/3] tortoise support --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index d3fe23d3..c003a871 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ - MongoEngine: >=0.24.1, <0.28.0 - SQLAlchemy: >=1.4.36, <2.1.0 +- tortoise-orm: >=0.22.1 ## Installation @@ -30,6 +31,7 @@ pip install fastapi-filter[all] # More selective pip install fastapi-filter[sqlalchemy] pip install fastapi-filter[mongoengine] +pip install fastapi-filter[tortoise-orm] ``` ## Documentation From 8686918797dba478856c187be6eb1027a9fbfa80 Mon Sep 17 00:00:00 2001 From: robot Date: Sat, 30 Nov 2024 13:31:04 +0300 Subject: [PATCH 3/3] tortoise support --- examples/fastapi_filter_tortoise.py | 15 +++++++++++++++ fastapi_filter/contrib/tortoise/__init__.py | 3 +++ 2 files changed, 18 insertions(+) create mode 100644 examples/fastapi_filter_tortoise.py diff --git a/examples/fastapi_filter_tortoise.py b/examples/fastapi_filter_tortoise.py new file mode 100644 index 00000000..859a91a8 --- /dev/null +++ b/examples/fastapi_filter_tortoise.py @@ -0,0 +1,15 @@ +# TODO +import logging +from collections.abc import AsyncIterator +from typing import Any, Optional + +import click +import uvicorn +from faker import Faker +from fastapi import Depends, FastAPI, Query +from pydantic import BaseModel, ConfigDict, Field + +from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter.contrib.tortoise import Filter + +logger = logging.getLogger("uvicorn") diff --git a/fastapi_filter/contrib/tortoise/__init__.py b/fastapi_filter/contrib/tortoise/__init__.py index e69de29b..93429855 100644 --- a/fastapi_filter/contrib/tortoise/__init__.py +++ b/fastapi_filter/contrib/tortoise/__init__.py @@ -0,0 +1,3 @@ +from .filter import Filter + +__all__ = ("Filter",)