From ff3af20b504379d3200bc8d76d93fb801c4c2fef Mon Sep 17 00:00:00 2001 From: Mohammed Hamid Date: Wed, 5 Mar 2025 06:19:16 -0800 Subject: [PATCH 1/2] drop support for comma separated strings in favor is native lists --- fastapi_filter/base/filter.py | 34 +++++++------------- fastapi_filter/contrib/beanie/filter.py | 23 +------------ fastapi_filter/contrib/mongoengine/filter.py | 18 ----------- fastapi_filter/contrib/sqlalchemy/filter.py | 18 ----------- 4 files changed, 12 insertions(+), 81 deletions(-) diff --git a/fastapi_filter/base/filter.py b/fastapi_filter/base/filter.py index ae00e774..94fe1469 100644 --- a/fastapi_filter/base/filter.py +++ b/fastapi_filter/base/filter.py @@ -1,10 +1,9 @@ import sys from collections import defaultdict -from collections.abc import Iterable from copy import deepcopy from typing import Any, Optional, Union, get_args, get_origin -from fastapi import Depends +from fastapi import Depends, Query from fastapi.exceptions import RequestValidationError from pydantic import BaseModel, ConfigDict, ValidationError, ValidationInfo, create_model, field_validator from pydantic.fields import FieldInfo @@ -183,31 +182,20 @@ class Constants(Filter.Constants): # type: ignore[name-defined] return NestedFilter -def _list_to_str_fields(Filter: type[BaseFilterModel]): +def _list_to_query_fields(Filter: type[BaseFilterModel]): ret: dict[str, tuple[Union[object, type], Optional[FieldInfo]]] = {} for name, f in Filter.model_fields.items(): field_info = deepcopy(f) annotation = f.annotation - if get_origin(annotation) in UNION_TYPES: - annotation_args: list = list(get_args(f.annotation)) - if type(None) in annotation_args: - annotation_args.remove(type(None)) - if len(annotation_args) == 1: - annotation = annotation_args[0] - # NOTE: This doesn't support union types which contain list and other types at the - # same time like `list[str] | str` or `list[str] | str | None`. The list type inside - # union will not be converted to string which means that the filter will not work in - # such cases. - # We cannot raise exception here because we still want to support union types in - # filter for example `int | float | None` is valid type and should not be transformed. - - if annotation is list or get_origin(annotation) is list: - if isinstance(field_info.default, Iterable): - field_info.default = ",".join(field_info.default) - ret[name] = (str if f.is_required() else Optional[str], field_info) - else: - ret[name] = (f.annotation, field_info) + if ( + annotation is list + or get_origin(annotation) is list + or any(get_origin(a) is list for a in get_args(annotation)) + ): + field_info.default = Query(default=field_info.default) + + ret[name] = (f.annotation, field_info) return ret @@ -223,7 +211,7 @@ def FilterDepends(Filter: type[BaseFilterModel], *, by_alias: bool = False, use_ When we apply the filter, we build the original filter to properly validate the data (i.e. can the string be parsed and formatted as a list of ?) """ - fields = _list_to_str_fields(Filter) + fields = _list_to_query_fields(Filter) GeneratedFilter: type[BaseFilterModel] = create_model(Filter.__class__.__name__, **fields) class FilterWrapper(GeneratedFilter): # type: ignore[misc,valid-type] diff --git a/fastapi_filter/contrib/beanie/filter.py b/fastapi_filter/contrib/beanie/filter.py index a17ef275..384c5d3c 100644 --- a/fastapi_filter/contrib/beanie/filter.py +++ b/fastapi_filter/contrib/beanie/filter.py @@ -1,9 +1,8 @@ from collections.abc import Callable, Mapping -from typing import Any, Optional, Union +from typing import Any, Optional from beanie.odm.interfaces.find import FindType from beanie.odm.queries.find import FindMany -from pydantic import ValidationInfo, field_validator from fastapi_filter.base.filter import BaseFilterModel @@ -53,26 +52,6 @@ def sort(self, query: FindMany[FindType]) -> FindMany[FindType]: return query return query.sort(*self.ordering_values) - @field_validator("*", mode="before") - @classmethod - def split_str( - cls: type["BaseFilterModel"], value: Optional[str], field: ValidationInfo - ) -> Optional[Union[list[str], str]]: - if ( - field.field_name is not None - and ( - field.field_name == cls.Constants.ordering_field_name - or field.field_name.endswith("__in") - or field.field_name.endswith("__nin") - ) - and isinstance(value, str) - ): - if not value: - # Empty string should return [] not [''] - return [] - return list(value.split(",")) - return value - def _get_filter_conditions(self, nesting_depth: int = 1) -> list[tuple[Mapping[str, Any], Mapping[str, Any]]]: filter_conditions: list[tuple[Mapping[str, Any], Mapping[str, Any]]] = [] for field_name, value in self.filtering_fields: diff --git a/fastapi_filter/contrib/mongoengine/filter.py b/fastapi_filter/contrib/mongoengine/filter.py index c67cf801..01b5ba7e 100644 --- a/fastapi_filter/contrib/mongoengine/filter.py +++ b/fastapi_filter/contrib/mongoengine/filter.py @@ -1,6 +1,5 @@ from mongoengine import QuerySet from mongoengine.queryset.visitor import Q -from pydantic import ValidationInfo, field_validator from ...base.filter import BaseFilterModel @@ -33,23 +32,6 @@ def sort(self, query: QuerySet) -> QuerySet: return query return query.order_by(*self.ordering_values) - @field_validator("*", mode="before") - def split_str(cls, value, field: ValidationInfo): - if ( - field.field_name is not None - and ( - field.field_name == cls.Constants.ordering_field_name - or field.field_name.endswith("__in") - or field.field_name.endswith("__nin") - ) - and isinstance(value, str) - ): - if not value: - # Empty string should return [] not [''] - return [] - return list(value.split(",")) - return value - def filter(self, query: QuerySet) -> QuerySet: for field_name, value in self.filtering_fields: field_value = getattr(self, field_name) diff --git a/fastapi_filter/contrib/sqlalchemy/filter.py b/fastapi_filter/contrib/sqlalchemy/filter.py index 9d5d4d99..05ea3dff 100644 --- a/fastapi_filter/contrib/sqlalchemy/filter.py +++ b/fastapi_filter/contrib/sqlalchemy/filter.py @@ -2,7 +2,6 @@ from typing import Union from warnings import warn -from pydantic import ValidationInfo, field_validator from sqlalchemy import or_ from sqlalchemy.orm import Query from sqlalchemy.sql.selectable import Select @@ -86,23 +85,6 @@ class Direction(str, Enum): asc = "asc" desc = "desc" - @field_validator("*", mode="before") - def split_str(cls, value, field: ValidationInfo): - if ( - field.field_name is not None - and ( - field.field_name == cls.Constants.ordering_field_name - or field.field_name.endswith("__in") - or field.field_name.endswith("__not_in") - ) - and isinstance(value, str) - ): - if not value: - # Empty string should return [] not [''] - return [] - return list(value.split(",")) - return value - def filter(self, query: Union[Query, Select]): for field_name, value in self.filtering_fields: field_value = getattr(self, field_name) From 704992e0c2377b87aebf97d85fc729df67a146d0 Mon Sep 17 00:00:00 2001 From: Mohammed Hamid Date: Wed, 5 Mar 2025 07:20:31 -0800 Subject: [PATCH 2/2] retain default value for fields already assigned to Query --- fastapi_filter/base/filter.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/fastapi_filter/base/filter.py b/fastapi_filter/base/filter.py index 94fe1469..5399c3aa 100644 --- a/fastapi_filter/base/filter.py +++ b/fastapi_filter/base/filter.py @@ -3,7 +3,7 @@ from copy import deepcopy from typing import Any, Optional, Union, get_args, get_origin -from fastapi import Depends, Query +from fastapi import Depends, Query, params from fastapi.exceptions import RequestValidationError from pydantic import BaseModel, ConfigDict, ValidationError, ValidationInfo, create_model, field_validator from pydantic.fields import FieldInfo @@ -192,7 +192,7 @@ def _list_to_query_fields(Filter: type[BaseFilterModel]): annotation is list or get_origin(annotation) is list or any(get_origin(a) is list for a in get_args(annotation)) - ): + ) and type(field_info.default) is not params.Query: field_info.default = Query(default=field_info.default) ret[name] = (f.annotation, field_info) @@ -201,12 +201,10 @@ def _list_to_query_fields(Filter: type[BaseFilterModel]): def FilterDepends(Filter: type[BaseFilterModel], *, by_alias: bool = False, use_cache: bool = True) -> Any: - """Use a hack to support lists in filters. + """Use a hack to treat lists as query parameters. - FastAPI doesn't support it yet: https://github.com/tiangolo/fastapi/issues/50 - - What we do is loop through the fields of a filter and change any `list` field to a `str` one so that it won't be - excluded from the possible query parameters. + What we do is loop through the fields of a filter and assign any `list` field a default value of `Query` so that + FastAPI knows it should be treated a query parameter and not body. When we apply the filter, we build the original filter to properly validate the data (i.e. can the string be parsed and formatted as a list of ?)