diff --git a/docs/index.md b/docs/index.md index 76ea30e5..5062ea6f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -3,7 +3,7 @@ Add querystring filters to your api endpoints and show them in the swagger UI. The supported backends are [SQLAlchemy](https://github.com/sqlalchemy/sqlalchemy), - [MongoEngine](https://github.com/MongoEngine/mongoengine) and [beanie](https://github.com/BeanieODM/beanie). +[MongoEngine](https://github.com/MongoEngine/mongoengine) and [beanie](https://github.com/BeanieODM/beanie). ## Example @@ -133,6 +133,34 @@ class UserFilter(Filter): address__country: Optional[str] ``` +### Extra fields + +Sometimes, you may need to add extra fields to your filter for custom logic. To do this, follow these steps: + +1. Define the extra fields in your filter class. +2. List these fields in the `extra_fields` constant. +3. Access these fields in your endpoint to implement custom filtering logic. + +```python +class UserFilter(Filter): + name: Optional[str] + is_not_active: Optional[bool] + + class Constants(Filter.Constants): + model = User + extra_fields = ["is_not_active"] + +@app.get("/users", response_model=list[UserOut]) +async def get_users(user_filter: UserFilter = FilterDepends(UserFilter), db: AsyncSession = Depends(get_db)) -> Any: + query = user_filter.filter(select(User)) + + if user_filter.is_not_active is not None: + query = query.where(User.is_active.is_(not user_filter.is_not_active)) + + result = await db.execute(query) + return result.scalars().all() +``` + ## Order by There is a specific field on the filter class that can be used for ordering. The default name is `order_by` and it diff --git a/fastapi_filter/base/filter.py b/fastapi_filter/base/filter.py index ae00e774..fccfc2ec 100644 --- a/fastapi_filter/base/filter.py +++ b/fastapi_filter/base/filter.py @@ -49,6 +49,7 @@ class Constants: # pragma: no cover ordering_field_name: str = "order_by" search_model_fields: list[str] search_field_name: str = "search" + extra_fields: list[str] = [] prefix: str original_filter: type["BaseFilterModel"] @@ -59,6 +60,8 @@ def filter(self, query): # pragma: no cover def filtering_fields(self): fields = self.model_dump(exclude_none=True, exclude_unset=True) fields.pop(self.Constants.ordering_field_name, None) + for field_name in self.Constants.extra_fields: + fields.pop(field_name, None) return fields.items() def sort(self, query): # pragma: no cover diff --git a/tests/test_beanie/conftest.py b/tests/test_beanie/conftest.py index b4698103..7fc1bcff 100644 --- a/tests/test_beanie/conftest.py +++ b/tests/test_beanie/conftest.py @@ -31,6 +31,7 @@ class User(Document): name: Optional[str] = None email: Optional[EmailStr] = None age: int + is_active: bool = True address: Optional[Link[Address]] = None favorite_sports: Optional[list[Link[Sport]]] = [] @@ -86,12 +87,14 @@ async def users( await User( name=None, age=21, + is_active=False, created_at=datetime(2021, 12, 1), favorite_sports=sports, ).save(link_rule=WriteRules.WRITE), await User( name="Mr Praline", age=33, + is_active=False, created_at=datetime(2021, 12, 1), address=Address(street="22 rue Bellier", city="Nantes", country="France"), favorite_sports=[sports[0]], @@ -154,6 +157,8 @@ class UserFilter(Filter): # type: ignore[misc, valid-type] age__gt: Optional[int] = None age__gte: Optional[int] = None age__in: Optional[list[int]] = None + gender: Optional[str] = None + is_not_active: Optional[bool] = None address: Optional[AddressFilter] = FilterDepends( # type: ignore[valid-type] with_prefix("address", AddressFilter), ) @@ -164,6 +169,7 @@ class Constants(MongoFilter.Constants): # type: ignore[name-defined] search_model_fields = ["name", "email"] # noqa: RUF012 search_field_name = "search" ordering_field_name = "order_by" + extra_fields = ["is_not_active"] yield UserFilter diff --git a/tests/test_beanie/test_filter.py b/tests/test_beanie/test_filter.py index 4f600a50..174ad606 100644 --- a/tests/test_beanie/test_filter.py +++ b/tests/test_beanie/test_filter.py @@ -22,12 +22,22 @@ [{"address": {"city": "San Francisco"}}, 1], [{"search": "Mr"}, 2], [{"search": "mr"}, 2], + [{"is_not_active": True}, 2], + [{"is_not_active": False}, 4], + [{"is_not_active": None}, 6], + [{"gender": "O"}, 0], ], ) @pytest.mark.usefixtures("sports", "users") @pytest.mark.asyncio async def test_basic_filter(User, UserFilter, AddressFilter, filter_, expected_count): - query = UserFilter(**filter_).filter(User.find({})) + query = User.find({}) + user_filter = UserFilter(**filter_) + + if user_filter.is_not_active is not None: + query = query.find({"is_active": {"$ne": user_filter.is_not_active}}) + + query = user_filter.filter(query) assert await query.count() == expected_count diff --git a/tests/test_mongoengine/conftest.py b/tests/test_mongoengine/conftest.py index 7ea1ead2..b8a4acd3 100644 --- a/tests/test_mongoengine/conftest.py +++ b/tests/test_mongoengine/conftest.py @@ -53,6 +53,7 @@ class User(Document): name = fields.StringField(null=True) email = fields.EmailField() age = fields.IntField() + is_active = fields.BooleanField(default=True) address = fields.ReferenceField(Address) favorite_sports = fields.ListField(fields.ReferenceField(Sport)) @@ -94,12 +95,14 @@ def users(User, Address, sports): User( name=None, age=21, + is_active=False, created_at=datetime(2021, 12, 1), favorite_sports=sports, ).save(), User( name="Mr Praline", age=33, + is_active=False, created_at=datetime(2021, 12, 1), address=Address(street="22 rue Bellier", city="Nantes", country="France").save(), favorite_sports=[sports[0]], @@ -162,6 +165,8 @@ class UserFilter(Filter): # type: ignore[misc, valid-type] age__gt: Optional[int] = None age__gte: Optional[int] = None age__in: Optional[list[int]] = None + is_not_active: Optional[bool] = None + gender: Optional[str] = None address: Optional[AddressFilter] = FilterDepends( # type: ignore[valid-type] with_prefix("address", AddressFilter) ) @@ -172,6 +177,7 @@ class Constants(Filter.Constants): # type: ignore[name-defined] search_model_fields = ["name", "email"] search_field_name = "search" ordering_field_name = "order_by" + extra_fields = ["is_not_active"] yield UserFilter diff --git a/tests/test_mongoengine/test_filter.py b/tests/test_mongoengine/test_filter.py index 18daff04..d93f1d2e 100644 --- a/tests/test_mongoengine/test_filter.py +++ b/tests/test_mongoengine/test_filter.py @@ -1,5 +1,6 @@ from urllib.parse import urlencode +import mongoengine import pytest from fastapi import status @@ -22,11 +23,20 @@ [{"address": {"city": "San Francisco"}}, 1], [{"search": "Mr"}, 2], [{"search": "mr"}, 2], + [{"is_not_active": True}, 2], + [{"is_not_active": False}, 4], + [{"is_not_active": None}, 6], ], ) @pytest.mark.usefixtures("Address", "users") def test_basic_filter(User, UserFilter, filter_, expected_count): - query = UserFilter(**filter_).filter(User.objects()) + query = User.objects() + user_filter = UserFilter(**filter_) + + if user_filter.is_not_active is not None: + query = query.filter(is_active__ne=user_filter.is_not_active) + + query = user_filter.filter(query) assert query.count() == expected_count @@ -76,3 +86,9 @@ async def test_required_filter(test_client, filter_, expected_status_code): error_json = response.json() assert "detail" in error_json assert isinstance(error_json["detail"], list) + + +@pytest.mark.usefixtures("users") +def test_raise_invalid_query_error(User, UserFilter): + with pytest.raises(mongoengine.errors.InvalidQueryError, match='Cannot resolve field "gender"'): + UserFilter(gender="F").filter(User.objects()).count() diff --git a/tests/test_sqlalchemy/conftest.py b/tests/test_sqlalchemy/conftest.py index dec5f4d1..28acbd7b 100644 --- a/tests/test_sqlalchemy/conftest.py +++ b/tests/test_sqlalchemy/conftest.py @@ -63,6 +63,7 @@ class User(Base): # type: ignore[misc, valid-type] updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False) name = Column(String) age = Column(Integer, nullable=False) + is_active = Column(Boolean, nullable=False, default=True) address_id = Column(Integer, ForeignKey("addresses.id")) address: Mapped[Address] = relationship(Address, backref="users", lazy="joined") # type: ignore[valid-type] favorite_sports: Mapped[Sport] = relationship( # type: ignore[valid-type] @@ -117,11 +118,13 @@ async def users(session, User, Address): User( name=None, age=21, + is_active=False, created_at=datetime(2021, 12, 1), ), User( name="Mr Praline", age=33, + is_active=False, created_at=datetime(2021, 12, 1), address=Address(street="22 rue Bellier", city="Nantes", country="France"), ), @@ -274,6 +277,8 @@ class UserFilter(Filter): # type: ignore[misc, valid-type] age__gt: Optional[int] = None age__gte: Optional[int] = None age__in: Optional[list[int]] = None + gender: Optional[str] = None + is_not_active: Optional[bool] = None address: Optional[AddressFilter] = FilterDepends( # type: ignore[valid-type] with_prefix("address", AddressFilter), by_alias=True ) @@ -284,6 +289,7 @@ class Constants(Filter.Constants): # type: ignore[name-defined] model = User search_model_fields = ["name"] search_field_name = "search" + extra_fields = ["is_not_active"] yield UserFilter diff --git a/tests/test_sqlalchemy/test_filter.py b/tests/test_sqlalchemy/test_filter.py index 3217a60b..561085b1 100644 --- a/tests/test_sqlalchemy/test_filter.py +++ b/tests/test_sqlalchemy/test_filter.py @@ -35,13 +35,21 @@ [{"address_id__isnull": True}, 1], [{"search": "Mr"}, 2], [{"search": "mr"}, 2], + [{"is_not_active": True}, 2], + [{"is_not_active": False}, 4], + [{"is_not_active": None}, 6], ], ) @pytest.mark.usefixtures("users") @pytest.mark.asyncio async def test_filter(session, Address, User, UserFilter, filter_, expected_count): query = select(User).outerjoin(Address) - query = UserFilter(**filter_).filter(query) + user_filter = UserFilter(**filter_) + + if user_filter.is_not_active is not None: + query = query.filter(User.is_active.is_(not user_filter.is_not_active)) + + query = user_filter.filter(query) result = await session.execute(query) assert len(result.scalars().unique().all()) == expected_count @@ -111,3 +119,14 @@ async def test_required_filter(test_client, filter_, expected_status_code): error_json = response.json() assert "detail" in error_json assert isinstance(error_json["detail"], list) + + +@pytest.mark.usefixtures("users") +@pytest.mark.asyncio +async def test_raise_attribute_error(session, User, UserFilter): + with pytest.raises(AttributeError, match="type object 'User' has no attribute 'gender'"): + query = select(User) + user_filter = UserFilter(gender="M") + + query = user_filter.filter(query) + await session.execute(query)