diff --git a/fastapi_filter/contrib/sqlalchemy/filter.py b/fastapi_filter/contrib/sqlalchemy/filter.py index 9d5d4d99..588bcce2 100644 --- a/fastapi_filter/contrib/sqlalchemy/filter.py +++ b/fastapi_filter/contrib/sqlalchemy/filter.py @@ -107,7 +107,16 @@ def filter(self, query: Union[Query, Select]): for field_name, value in self.filtering_fields: field_value = getattr(self, field_name) if isinstance(field_value, Filter): - query = field_value.filter(query) + field_value_dump = field_value.model_dump(exclude_unset=True, exclude_none=True) + # Check if the filter has any value set and in case we have a nested filter check if it's not empty + if field_value_dump and field_value_dump and any(field_value_dump.values()): + joins = getattr(self.Constants, "joins", {}) + if joins and field_name in joins: + table_join = joins[field_name] + table_join["target"] = table_join.pop("target", field_value.Constants.model) + query = query.join(**table_join) + + query = field_value.filter(query) else: if "__" in field_name: field_name, operator = field_name.split("__") diff --git a/tests/test_sqlalchemy/conftest.py b/tests/test_sqlalchemy/conftest.py index 044583f9..a2f750dd 100644 --- a/tests/test_sqlalchemy/conftest.py +++ b/tests/test_sqlalchemy/conftest.py @@ -285,6 +285,13 @@ class Constants(Filter.Constants): # type: ignore[name-defined] model = User search_model_fields = ["name"] search_field_name = "search" + joins = { + "address": { + "target": AddressFilter.Constants.model, + "onclause": AddressFilter.Constants.model.id == User.address_id, + "isouter": True, + }, + } yield UserFilter @@ -345,7 +352,7 @@ async def get_users( user_filter: UserFilter = FilterDepends(UserFilter), # type: ignore[valid-type] db: AsyncSession = Depends(get_db), ): - query = user_filter.filter(select(User).outerjoin(Address)) # type: ignore[attr-defined] + query = user_filter.filter(select(User)) # type: ignore[attr-defined] result = await db.execute(query) return result.scalars().unique().all() @@ -354,7 +361,7 @@ async def get_users_by_alias( user_filter: UserFilter = FilterDepends(UserFilterByAlias, by_alias=True), # type: ignore[valid-type] db: AsyncSession = Depends(get_db), ): - query = user_filter.filter(select(User).outerjoin(Address)) # type: ignore[attr-defined] + query = user_filter.filter(select(User)) # type: ignore[attr-defined] result = await db.execute(query) return result.scalars().unique().all() @@ -363,7 +370,7 @@ async def get_users_with_order_by( user_filter: UserFilterOrderBy = FilterDepends(UserFilterOrderBy), # type: ignore[valid-type] db: AsyncSession = Depends(get_db), ): - query = user_filter.sort(select(User).outerjoin(Address)) # type: ignore[attr-defined] + query = user_filter.sort(select(User)) # type: ignore[attr-defined] query = user_filter.filter(query) # type: ignore[attr-defined] result = await db.execute(query) return result.scalars().unique().all() diff --git a/tests/test_sqlalchemy/test_filter.py b/tests/test_sqlalchemy/test_filter.py index 3217a60b..7202133e 100644 --- a/tests/test_sqlalchemy/test_filter.py +++ b/tests/test_sqlalchemy/test_filter.py @@ -39,8 +39,8 @@ ) @pytest.mark.usefixtures("users") @pytest.mark.asyncio -async def test_filter(session, Address, User, UserFilter, filter_, expected_count): - query = select(User).outerjoin(Address) +async def test_filter(session, User, UserFilter, filter_, expected_count): + query = select(User) query = UserFilter(**filter_).filter(query) result = await session.execute(query) assert len(result.scalars().unique().all()) == expected_count @@ -55,8 +55,8 @@ async def test_filter(session, Address, User, UserFilter, filter_, expected_coun ) @pytest.mark.usefixtures("users") @pytest.mark.asyncio -async def test_filter_deprecation_like_and_ilike(session, Address, User, UserFilter, filter_, expected_count): - query = select(User).outerjoin(Address) +async def test_filter_deprecation_like_and_ilike(session, User, UserFilter, filter_, expected_count): + query = select(User) with pytest.warns(DeprecationWarning, match="like and ilike operators."): query = UserFilter(**filter_).filter(query) result = await session.execute(query)