From 87f3cc79032cb82617c20ba624033247fa1e4144 Mon Sep 17 00:00:00 2001 From: agusmdev Date: Tue, 14 May 2024 21:04:32 -0300 Subject: [PATCH 1/3] Add declarative joins --- fastapi_filter/contrib/sqlalchemy/filter.py | 9 ++++++++- tests/test_sqlalchemy/conftest.py | 13 ++++++++++--- tests/test_sqlalchemy/test_filter.py | 8 ++++---- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/fastapi_filter/contrib/sqlalchemy/filter.py b/fastapi_filter/contrib/sqlalchemy/filter.py index 9d5d4d99..b145181f 100644 --- a/fastapi_filter/contrib/sqlalchemy/filter.py +++ b/fastapi_filter/contrib/sqlalchemy/filter.py @@ -107,7 +107,14 @@ def filter(self, query: Union[Query, Select]): for field_name, value in self.filtering_fields: field_value = getattr(self, field_name) if isinstance(field_value, Filter): - query = field_value.filter(query) + if field_value.model_dump(exclude_unset=True, exclude_none=True): + join_kwargs = getattr(self.Constants, "join_kwargs", {}) + if join_kwargs and field_name in join_kwargs: + join_kwargs = join_kwargs[field_name] + join_kwargs["target"] = join_kwargs.pop("target", field_value.Constants.model) + query = query.join(**join_kwargs) + + query = field_value.filter(query) else: if "__" in field_name: field_name, operator = field_name.split("__") diff --git a/tests/test_sqlalchemy/conftest.py b/tests/test_sqlalchemy/conftest.py index 044583f9..ed67248d 100644 --- a/tests/test_sqlalchemy/conftest.py +++ b/tests/test_sqlalchemy/conftest.py @@ -285,6 +285,13 @@ class Constants(Filter.Constants): # type: ignore[name-defined] model = User search_model_fields = ["name"] search_field_name = "search" + join_kwargs = { + "address": { + "target": AddressFilter.Constants.model, + "onclause": AddressFilter.Constants.model.id == User.address_id, + "isouter": True, + }, + } yield UserFilter @@ -345,7 +352,7 @@ async def get_users( user_filter: UserFilter = FilterDepends(UserFilter), # type: ignore[valid-type] db: AsyncSession = Depends(get_db), ): - query = user_filter.filter(select(User).outerjoin(Address)) # type: ignore[attr-defined] + query = user_filter.filter(select(User)) # type: ignore[attr-defined] result = await db.execute(query) return result.scalars().unique().all() @@ -354,7 +361,7 @@ async def get_users_by_alias( user_filter: UserFilter = FilterDepends(UserFilterByAlias, by_alias=True), # type: ignore[valid-type] db: AsyncSession = Depends(get_db), ): - query = user_filter.filter(select(User).outerjoin(Address)) # type: ignore[attr-defined] + query = user_filter.filter(select(User)) # type: ignore[attr-defined] result = await db.execute(query) return result.scalars().unique().all() @@ -363,7 +370,7 @@ async def get_users_with_order_by( user_filter: UserFilterOrderBy = FilterDepends(UserFilterOrderBy), # type: ignore[valid-type] db: AsyncSession = Depends(get_db), ): - query = user_filter.sort(select(User).outerjoin(Address)) # type: ignore[attr-defined] + query = user_filter.sort(select(User)) # type: ignore[attr-defined] query = user_filter.filter(query) # type: ignore[attr-defined] result = await db.execute(query) return result.scalars().unique().all() diff --git a/tests/test_sqlalchemy/test_filter.py b/tests/test_sqlalchemy/test_filter.py index 3217a60b..7202133e 100644 --- a/tests/test_sqlalchemy/test_filter.py +++ b/tests/test_sqlalchemy/test_filter.py @@ -39,8 +39,8 @@ ) @pytest.mark.usefixtures("users") @pytest.mark.asyncio -async def test_filter(session, Address, User, UserFilter, filter_, expected_count): - query = select(User).outerjoin(Address) +async def test_filter(session, User, UserFilter, filter_, expected_count): + query = select(User) query = UserFilter(**filter_).filter(query) result = await session.execute(query) assert len(result.scalars().unique().all()) == expected_count @@ -55,8 +55,8 @@ async def test_filter(session, Address, User, UserFilter, filter_, expected_coun ) @pytest.mark.usefixtures("users") @pytest.mark.asyncio -async def test_filter_deprecation_like_and_ilike(session, Address, User, UserFilter, filter_, expected_count): - query = select(User).outerjoin(Address) +async def test_filter_deprecation_like_and_ilike(session, User, UserFilter, filter_, expected_count): + query = select(User) with pytest.warns(DeprecationWarning, match="like and ilike operators."): query = UserFilter(**filter_).filter(query) result = await session.execute(query) From b63b3a5ddd1daba9b5684e28219ce36a684dae77 Mon Sep 17 00:00:00 2001 From: agusmdev Date: Thu, 16 May 2024 09:59:20 -0300 Subject: [PATCH 2/3] Update join_kwargs field name to be joins --- fastapi_filter/contrib/sqlalchemy/filter.py | 10 +++++----- tests/test_sqlalchemy/conftest.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/fastapi_filter/contrib/sqlalchemy/filter.py b/fastapi_filter/contrib/sqlalchemy/filter.py index b145181f..1d4348f3 100644 --- a/fastapi_filter/contrib/sqlalchemy/filter.py +++ b/fastapi_filter/contrib/sqlalchemy/filter.py @@ -108,11 +108,11 @@ def filter(self, query: Union[Query, Select]): field_value = getattr(self, field_name) if isinstance(field_value, Filter): if field_value.model_dump(exclude_unset=True, exclude_none=True): - join_kwargs = getattr(self.Constants, "join_kwargs", {}) - if join_kwargs and field_name in join_kwargs: - join_kwargs = join_kwargs[field_name] - join_kwargs["target"] = join_kwargs.pop("target", field_value.Constants.model) - query = query.join(**join_kwargs) + joins = getattr(self.Constants, "joins", {}) + if joins and field_name in joins: + table_join = joins[field_name] + table_join["target"] = table_join.pop("target", field_value.Constants.model) + query = query.join(**table_join) query = field_value.filter(query) else: diff --git a/tests/test_sqlalchemy/conftest.py b/tests/test_sqlalchemy/conftest.py index ed67248d..a2f750dd 100644 --- a/tests/test_sqlalchemy/conftest.py +++ b/tests/test_sqlalchemy/conftest.py @@ -285,7 +285,7 @@ class Constants(Filter.Constants): # type: ignore[name-defined] model = User search_model_fields = ["name"] search_field_name = "search" - join_kwargs = { + joins = { "address": { "target": AddressFilter.Constants.model, "onclause": AddressFilter.Constants.model.id == User.address_id, From c01ef72cae223033af95a62aa3d7ee5d641d0cf6 Mon Sep 17 00:00:00 2001 From: agusmdev Date: Thu, 16 May 2024 17:36:09 -0300 Subject: [PATCH 3/3] Add check to avoid unnecessary joins on empty nested filters --- fastapi_filter/contrib/sqlalchemy/filter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fastapi_filter/contrib/sqlalchemy/filter.py b/fastapi_filter/contrib/sqlalchemy/filter.py index 1d4348f3..588bcce2 100644 --- a/fastapi_filter/contrib/sqlalchemy/filter.py +++ b/fastapi_filter/contrib/sqlalchemy/filter.py @@ -107,7 +107,9 @@ def filter(self, query: Union[Query, Select]): for field_name, value in self.filtering_fields: field_value = getattr(self, field_name) if isinstance(field_value, Filter): - if field_value.model_dump(exclude_unset=True, exclude_none=True): + field_value_dump = field_value.model_dump(exclude_unset=True, exclude_none=True) + # Check if the filter has any value set and in case we have a nested filter check if it's not empty + if field_value_dump and field_value_dump and any(field_value_dump.values()): joins = getattr(self.Constants, "joins", {}) if joins and field_name in joins: table_join = joins[field_name]