Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions fastapi_filter/base/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,14 @@

from fastapi import Depends
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel, ConfigDict, ValidationError, ValidationInfo, create_model, field_validator
from pydantic import (
BaseModel,
ConfigDict,
ValidationError,
ValidationInfo,
create_model,
field_validator,
)
from pydantic.fields import FieldInfo

UNION_TYPES: list = [Union]
Expand All @@ -17,6 +24,16 @@
UNION_TYPES.append(UnionType)


class PaginationLimitOffsetModel(BaseModel):
limit_field: str = "limit" # Default field names
offset_field: str = "offset"


class PaginationPageNumberPageSizeModel(BaseModel):
page_field: str = "page"
size_field: str = "size"


class BaseFilterModel(BaseModel, extra="forbid"):
"""Abstract base filter class.

Expand Down Expand Up @@ -51,19 +68,26 @@ class Constants: # pragma: no cover
search_field_name: str = "search"
prefix: str
original_filter: type["BaseFilterModel"]
pagination_field_model: Union[
PaginationLimitOffsetModel, PaginationPageNumberPageSizeModel, None
] = None

def filter(self, query): # pragma: no cover
...

@property
def filtering_fields(self):
fields = self.model_dump(exclude_none=True, exclude_unset=True)
fields = self.model_dump(exclude_none=True, exclude_unset=False)
fields.pop(self.Constants.ordering_field_name, None)
return fields.items()

def sort(self, query): # pragma: no cover
...

@property
def pagination_field_model_value(self):
return self.pagination_field_model

@property
def ordering_values(self):
"""Check that the ordering field is present on the class definition."""
Expand Down Expand Up @@ -172,7 +196,9 @@ class MainFilter(BaseModel):
"""

class NestedFilter(Filter): # type: ignore[misc, valid-type]
model_config = ConfigDict(extra="forbid", alias_generator=lambda string: f"{prefix}__{string}")
model_config = ConfigDict(
extra="forbid", alias_generator=lambda string: f"{prefix}__{string}"
)

class Constants(Filter.Constants): # type: ignore[name-defined]
...
Expand Down Expand Up @@ -212,7 +238,9 @@ def _list_to_str_fields(Filter: type[BaseFilterModel]):
return ret


def FilterDepends(Filter: type[BaseFilterModel], *, by_alias: bool = False, use_cache: bool = True) -> Any:
def FilterDepends(
Filter: type[BaseFilterModel], *, by_alias: bool = False, use_cache: bool = True
) -> Any:
"""Use a hack to support lists in filters.

FastAPI doesn't support it yet: https://github.com/tiangolo/fastapi/issues/50
Expand All @@ -224,14 +252,20 @@ def FilterDepends(Filter: type[BaseFilterModel], *, by_alias: bool = False, use_
and formatted as a list of <type>?)
"""
fields = _list_to_str_fields(Filter)
GeneratedFilter: type[BaseFilterModel] = create_model(Filter.__class__.__name__, **fields)
GeneratedFilter: type[BaseFilterModel] = create_model(
Filter.__class__.__name__, **fields
)

class FilterWrapper(GeneratedFilter): # type: ignore[misc,valid-type]
def __new__(cls, *args, **kwargs):
try:
instance = GeneratedFilter(*args, **kwargs)
data = instance.model_dump(exclude_unset=True, exclude_defaults=True, by_alias=by_alias)
if original_filter := getattr(Filter.Constants, "original_filter", None):
data = instance.model_dump(
exclude_unset=False, exclude_defaults=True, by_alias=by_alias
)
if original_filter := getattr(
Filter.Constants, "original_filter", None
):
prefix = f"{Filter.Constants.prefix}__"
stripped = {k.removeprefix(prefix): v for k, v in data.items()}
return original_filter(**stripped)
Expand Down
42 changes: 37 additions & 5 deletions fastapi_filter/contrib/sqlalchemy/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
from typing import Union
from warnings import warn

from pydantic import ValidationInfo, field_validator
from pydantic import BaseModel, ValidationInfo, field_validator
from sqlalchemy import or_
from sqlalchemy.orm import Query
from sqlalchemy.sql.selectable import Select

from ...base.filter import BaseFilterModel
from ...base.filter import (
BaseFilterModel,
PaginationLimitOffsetModel,
PaginationPageNumberPageSizeModel,
)


def _backward_compatible_value_for_like_and_ilike(value: str):
Expand Down Expand Up @@ -38,8 +42,14 @@ def _backward_compatible_value_for_like_and_ilike(value: str):
"isnull": lambda value: ("is_", None) if value is True else ("is_not", None),
"lt": lambda value: ("__lt__", value),
"lte": lambda value: ("__le__", value),
"like": lambda value: ("like", _backward_compatible_value_for_like_and_ilike(value)),
"ilike": lambda value: ("ilike", _backward_compatible_value_for_like_and_ilike(value)),
"like": lambda value: (
"like",
_backward_compatible_value_for_like_and_ilike(value),
),
"ilike": lambda value: (
"ilike",
_backward_compatible_value_for_like_and_ilike(value),
),
# XXX(arthurio): Mysql excludes None values when using `in` or `not in` filters.
"not": lambda value: ("is_not", value),
"not_in": lambda value: ("not_in", value),
Expand Down Expand Up @@ -115,7 +125,9 @@ def filter(self, query: Union[Query, Select]):
else:
operator = "__eq__"

if field_name == self.Constants.search_field_name and hasattr(self.Constants, "search_model_fields"):
if field_name == self.Constants.search_field_name and hasattr(
self.Constants, "search_model_fields"
):
search_filters = [
getattr(self.Constants.model, field).ilike(f"%{value}%")
for field in self.Constants.search_model_fields
Expand All @@ -142,3 +154,23 @@ def sort(self, query: Union[Query, Select]):
query = query.order_by(getattr(order_by_field, direction)())

return query

def paginate(self, query: Union[Query, Select]):
pagination = self.pagination_field_model_value
if not pagination:
return query
if isinstance(pagination, PaginationLimitOffsetModel):
limit_field = pagination.limit_field
offset_field = pagination.offset_field
limit = getattr(self, limit_field)
offset = getattr(self, offset_field)
elif isinstance(pagination, PaginationPageNumberPageSizeModel):
page_field = pagination.page_field
size_field = pagination.size_field

page = getattr(self, page_field)
size = getattr(self, size_field)
offset = (page - 1) * size
limit = size
query = query.offset(offset).limit(limit)
return query
Loading