From 29acf343a9be75284ad180a3a635c5026d0fab83 Mon Sep 17 00:00:00 2001 From: Teut2711 Date: Sun, 4 May 2025 17:44:20 +0530 Subject: [PATCH] added pagination --- fastapi_filter/base/filter.py | 48 ++++++++++++++++++--- fastapi_filter/contrib/sqlalchemy/filter.py | 42 +++++++++++++++--- 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/fastapi_filter/base/filter.py b/fastapi_filter/base/filter.py index ae00e774..fc3663bf 100644 --- a/fastapi_filter/base/filter.py +++ b/fastapi_filter/base/filter.py @@ -6,7 +6,14 @@ from fastapi import Depends from fastapi.exceptions import RequestValidationError -from pydantic import BaseModel, ConfigDict, ValidationError, ValidationInfo, create_model, field_validator +from pydantic import ( + BaseModel, + ConfigDict, + ValidationError, + ValidationInfo, + create_model, + field_validator, +) from pydantic.fields import FieldInfo UNION_TYPES: list = [Union] @@ -17,6 +24,16 @@ UNION_TYPES.append(UnionType) +class PaginationLimitOffsetModel(BaseModel): + limit_field: str = "limit" # Default field names + offset_field: str = "offset" + + +class PaginationPageNumberPageSizeModel(BaseModel): + page_field: str = "page" + size_field: str = "size" + + class BaseFilterModel(BaseModel, extra="forbid"): """Abstract base filter class. @@ -51,19 +68,26 @@ class Constants: # pragma: no cover search_field_name: str = "search" prefix: str original_filter: type["BaseFilterModel"] + pagination_field_model: Union[ + PaginationLimitOffsetModel, PaginationPageNumberPageSizeModel, None + ] = None def filter(self, query): # pragma: no cover ... @property def filtering_fields(self): - fields = self.model_dump(exclude_none=True, exclude_unset=True) + fields = self.model_dump(exclude_none=True, exclude_unset=False) fields.pop(self.Constants.ordering_field_name, None) return fields.items() def sort(self, query): # pragma: no cover ... + @property + def pagination_field_model_value(self): + return self.pagination_field_model + @property def ordering_values(self): """Check that the ordering field is present on the class definition.""" @@ -172,7 +196,9 @@ class MainFilter(BaseModel): """ class NestedFilter(Filter): # type: ignore[misc, valid-type] - model_config = ConfigDict(extra="forbid", alias_generator=lambda string: f"{prefix}__{string}") + model_config = ConfigDict( + extra="forbid", alias_generator=lambda string: f"{prefix}__{string}" + ) class Constants(Filter.Constants): # type: ignore[name-defined] ... @@ -212,7 +238,9 @@ def _list_to_str_fields(Filter: type[BaseFilterModel]): return ret -def FilterDepends(Filter: type[BaseFilterModel], *, by_alias: bool = False, use_cache: bool = True) -> Any: +def FilterDepends( + Filter: type[BaseFilterModel], *, by_alias: bool = False, use_cache: bool = True +) -> Any: """Use a hack to support lists in filters. FastAPI doesn't support it yet: https://github.com/tiangolo/fastapi/issues/50 @@ -224,14 +252,20 @@ def FilterDepends(Filter: type[BaseFilterModel], *, by_alias: bool = False, use_ and formatted as a list of ?) """ fields = _list_to_str_fields(Filter) - GeneratedFilter: type[BaseFilterModel] = create_model(Filter.__class__.__name__, **fields) + GeneratedFilter: type[BaseFilterModel] = create_model( + Filter.__class__.__name__, **fields + ) class FilterWrapper(GeneratedFilter): # type: ignore[misc,valid-type] def __new__(cls, *args, **kwargs): try: instance = GeneratedFilter(*args, **kwargs) - data = instance.model_dump(exclude_unset=True, exclude_defaults=True, by_alias=by_alias) - if original_filter := getattr(Filter.Constants, "original_filter", None): + data = instance.model_dump( + exclude_unset=False, exclude_defaults=True, by_alias=by_alias + ) + if original_filter := getattr( + Filter.Constants, "original_filter", None + ): prefix = f"{Filter.Constants.prefix}__" stripped = {k.removeprefix(prefix): v for k, v in data.items()} return original_filter(**stripped) diff --git a/fastapi_filter/contrib/sqlalchemy/filter.py b/fastapi_filter/contrib/sqlalchemy/filter.py index 9d5d4d99..b16057bb 100644 --- a/fastapi_filter/contrib/sqlalchemy/filter.py +++ b/fastapi_filter/contrib/sqlalchemy/filter.py @@ -2,12 +2,16 @@ from typing import Union from warnings import warn -from pydantic import ValidationInfo, field_validator +from pydantic import BaseModel, ValidationInfo, field_validator from sqlalchemy import or_ from sqlalchemy.orm import Query from sqlalchemy.sql.selectable import Select -from ...base.filter import BaseFilterModel +from ...base.filter import ( + BaseFilterModel, + PaginationLimitOffsetModel, + PaginationPageNumberPageSizeModel, +) def _backward_compatible_value_for_like_and_ilike(value: str): @@ -38,8 +42,14 @@ def _backward_compatible_value_for_like_and_ilike(value: str): "isnull": lambda value: ("is_", None) if value is True else ("is_not", None), "lt": lambda value: ("__lt__", value), "lte": lambda value: ("__le__", value), - "like": lambda value: ("like", _backward_compatible_value_for_like_and_ilike(value)), - "ilike": lambda value: ("ilike", _backward_compatible_value_for_like_and_ilike(value)), + "like": lambda value: ( + "like", + _backward_compatible_value_for_like_and_ilike(value), + ), + "ilike": lambda value: ( + "ilike", + _backward_compatible_value_for_like_and_ilike(value), + ), # XXX(arthurio): Mysql excludes None values when using `in` or `not in` filters. "not": lambda value: ("is_not", value), "not_in": lambda value: ("not_in", value), @@ -115,7 +125,9 @@ def filter(self, query: Union[Query, Select]): else: operator = "__eq__" - if field_name == self.Constants.search_field_name and hasattr(self.Constants, "search_model_fields"): + if field_name == self.Constants.search_field_name and hasattr( + self.Constants, "search_model_fields" + ): search_filters = [ getattr(self.Constants.model, field).ilike(f"%{value}%") for field in self.Constants.search_model_fields @@ -142,3 +154,23 @@ def sort(self, query: Union[Query, Select]): query = query.order_by(getattr(order_by_field, direction)()) return query + + def paginate(self, query: Union[Query, Select]): + pagination = self.pagination_field_model_value + if not pagination: + return query + if isinstance(pagination, PaginationLimitOffsetModel): + limit_field = pagination.limit_field + offset_field = pagination.offset_field + limit = getattr(self, limit_field) + offset = getattr(self, offset_field) + elif isinstance(pagination, PaginationPageNumberPageSizeModel): + page_field = pagination.page_field + size_field = pagination.size_field + + page = getattr(self, page_field) + size = getattr(self, size_field) + offset = (page - 1) * size + limit = size + query = query.offset(offset).limit(limit) + return query