Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions fastapi_querybuilder/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .dependencies import QueryBuilder
from .params import QueryParams
from .builder import build_query
from .dependencies import QueryBuilder
from .fields import SchemaConfig
from .params import QueryParams, QueryParamsConfigDict

__all__ = ["QueryBuilder", "QueryParams", "build_query"]
__all__ = ["QueryBuilder", "QueryParams", "build_query", "SchemaConfig", "QueryParamsConfigDict"]
62 changes: 31 additions & 31 deletions fastapi_querybuilder/builder.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
from sqlalchemy import cast, select, or_, asc, desc, String, Enum
from fastapi import HTTPException
from .core import parse_filter_query, parse_filters, resolve_and_join_column
from sqlalchemy import Enum, String, asc, cast, desc, or_, select
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.sql.elements import KeyedColumnElement

from .core import parse_filters, resolve_and_join_column
from .params import QueryParams

def build_query(cls, params):
if hasattr(cls, 'deleted_at'):

def build_query(cls: type[DeclarativeBase], params: QueryParams):
if hasattr(cls, "deleted_at"):
query = select(cls).where(cls.deleted_at.is_(None))
else:
query = select(cls)

# Filters
parsed_filters = parse_filter_query(params.filters)
if parsed_filters:
filter_expr, query = parse_filters(cls, parsed_filters, query)
if params.filters:
filter_expr, query = parse_filters(cls, params.filters, query, params.get_schema().filterable_fields)
if filter_expr is not None:
query = query.where(filter_expr)

# Search - ONLY in safe columns
# Search - restricted to schema.searchable_fields
if params.search:
search_expr = []

allowed_search = set(params.get_schema().searchable_fields)

for column in cls.__table__.columns:
if column.key not in allowed_search:
continue
if is_enum_column(column):
search_expr.append(cast(column, String).ilike(f"%{params.search}%"))
elif is_string_column(column):
Expand All @@ -37,42 +43,36 @@ def build_query(cls, params):

# Sorting
if params.sort:
try:
sort_field, sort_dir = params.sort.split(":")
except ValueError:
sort_field, sort_dir = params.sort, "asc"

column = getattr(cls, sort_field, None)
if column is None:
nested_keys = sort_field.split(".")
if len(nested_keys) > 1:
joins = {}
column, query = resolve_and_join_column(
cls, nested_keys, query, joins)
else:
raise HTTPException(
status_code=400, detail=f"Invalid sort field: {sort_field}")

query = query.order_by(
asc(column) if sort_dir.lower() == "asc" else desc(column))
for sort_field in params.sort:
column = getattr(cls, sort_field.field, None)
if column is None:
nested_keys = sort_field.field.split(".")
if len(nested_keys) > 1:
joins = {}
column, query = resolve_and_join_column(cls, nested_keys, query, joins)
else:
raise HTTPException(status_code=400, detail=f"Invalid sort field: {sort_field}")

query = query.order_by(asc(column) if sort_field.direction.lower() == "asc" else desc(column))

return query

def is_enum_column(column):

def is_enum_column(column: KeyedColumnElement):
"""Check if a column is an enum type"""
return isinstance(column.type, Enum)


def is_string_column(column):
def is_string_column(column: KeyedColumnElement):
"""Check if a column is a string type"""
return isinstance(column.type, String)


def is_integer_column(column):
def is_integer_column(column: KeyedColumnElement):
"""Check if a column is an integer type"""
return hasattr(column.type, "python_type") and column.type.python_type is int


def is_boolean_column(column):
def is_boolean_column(column: KeyedColumnElement):
"""Check if a column is a boolean type"""
return hasattr(column.type, "python_type") and column.type.python_type is bool
71 changes: 26 additions & 45 deletions fastapi_querybuilder/core.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
# app/filters/core.py

from typing import Any, Optional

from fastapi import HTTPException
from sqlalchemy.orm import RelationshipProperty, aliased
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute, RelationshipProperty, aliased
from sqlalchemy.sql import Select, and_
from typing import Any, Optional, Dict, Tuple
import json
from .operators import LOGICAL_OPERATORS, COMPARISON_OPERATORS

def resolve_and_join_column(model, nested_keys: list[str], query: Select, joins: dict) -> Tuple[Any, Select]:
from .operators import COMPARISON_OPERATORS, LOGICAL_OPERATORS, Operator
from .params import FilterSchema


def resolve_and_join_column(
model: DeclarativeBase, nested_keys: list[str], query: Select, joins: dict
) -> tuple[InstrumentedAttribute[Any], Select]:
current_model = model
alias = None

for i, attr in enumerate(nested_keys):
relationship = getattr(current_model, attr, None)
relationship: InstrumentedAttribute[Any] = getattr(current_model, attr, None)

if relationship is not None and isinstance(relationship.property, RelationshipProperty):
related_model = relationship.property.mapper.class_
Expand All @@ -30,27 +35,22 @@ def resolve_and_join_column(model, nested_keys: list[str], query: Select, joins:
raise HTTPException(
status_code=400,
detail=f"Invalid filter key: {'.'.join(nested_keys)}. "
f"Could not resolve attribute '{attr}' in model '{current_model.__name__}'."
f"Could not resolve attribute '{attr}' in model '{current_model.__name__}'.",
)
raise HTTPException(
status_code=400,
detail=f"Could not resolve relationship for {'.'.join(nested_keys)}."
)
raise HTTPException(status_code=400, detail=f"Could not resolve relationship for {'.'.join(nested_keys)}.")


def parse_filters(model, filters: dict, query: Select) -> Tuple[Optional[Any], Select]:
def parse_filters(
model, filters: FilterSchema, query: Select, allowed_fields: set[str]
) -> tuple[Optional[Any], Select]:
expressions = []
joins = {}

if not isinstance(filters, dict):
raise HTTPException(
status_code=400, detail="Filters must be a dictionary")

for key, value in filters.items():
filter_dict = filters.root
for key, value in filter_dict.items():
if key in LOGICAL_OPERATORS:
if not isinstance(value, list):
raise HTTPException(
status_code=400, detail=f"Logical operator '{key}' must be a list")
raise HTTPException(status_code=400, detail=f"Logical operator '{key}' must be a list")
sub_expressions = []
for sub_filter in value:
sub_expr, query = parse_filters(model, sub_filter, query)
Expand All @@ -61,37 +61,18 @@ def parse_filters(model, filters: dict, query: Select) -> Tuple[Optional[Any], S

elif isinstance(value, dict):
nested_keys = key.split(".")
column, query = resolve_and_join_column(
model, nested_keys, query, joins)
column, query = resolve_and_join_column(model, nested_keys, query, joins)
for operator, operand in value.items():
if operator not in COMPARISON_OPERATORS:
raise HTTPException(
status_code=400, detail=f"Invalid operator '{operator}' for field '{key}'")
raise HTTPException(status_code=400, detail=f"Invalid operator '{operator}' for field '{key}'")
try:
if operator in ["$isempty", "$isnotempty"]:
expressions.append(
COMPARISON_OPERATORS[operator](column))
if operator in {Operator.ISNOTEMPTY, Operator.ISEMPTY}:
expressions.append(COMPARISON_OPERATORS[operator](column))
else:
expressions.append(
COMPARISON_OPERATORS[operator](column, operand))
expressions.append(COMPARISON_OPERATORS[operator](column, operand))
except Exception as e:
raise HTTPException(
status_code=400, detail=f"Error filtering '{key}': {e}")
raise HTTPException(status_code=400, detail=f"Error filtering '{key}': {e}")
else:
raise HTTPException(
status_code=400, detail=f"Invalid filter format for key '{key}': {value}")
raise HTTPException(status_code=400, detail=f"Invalid filter format for key '{key}': {value}")

return and_(*expressions) if expressions else None, query


def parse_filter_query(filters: Optional[str]) -> Optional[Dict]:
if not filters:
return None
try:
parsed = json.loads(filters)
if not isinstance(parsed, dict):
raise ValueError("Filters must be a JSON object")
return parsed
except Exception as e:
raise HTTPException(
status_code=400, detail=f"Invalid filter JSON: {e}")
56 changes: 48 additions & 8 deletions fastapi_querybuilder/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,55 @@
# fastapi_querybuilder/dependencies.py

from fastapi import Depends, Request
from fastapi import Depends, HTTPException, Query, Request
from pydantic import ValidationError

from .fields import SchemaConfig
from .params import QueryParams

from .builder import build_query
from typing import Type


def QueryBuilder(model: Type):
def wrapper(
request: Request,
params: QueryParams = Depends()
):
return build_query(model, params)
def _parse_errors(e: ValidationError):
# Ensure the error structure is compatible with FastAPI's expected format
return [
{"type": error["type"], "loc": error["loc"], "msg": error["msg"]}
for error in e.errors(include_context=False, include_url=False, include_input=False)
]


def filter_params(cls: type[QueryParams]):
def get(
filters: str | None = Query(
default=None,
description="Filtro en formato de string JSON.",
example='{"$and": [{"active": {"$eq": true}}, {"role": {"$eq": "admin"}}, {"age": {"$gte": 30, "$lt": 50}}]}',
),
sort: str | None = Query(
default=None,
description="e.g. name:asc,phone:desc or user.email:desc",
example="name:asc,age:desc",
),
search: str | None = Query(
default=None,
description="Una cadena para búsqueda global a través de campos de cadena.",
example="developer",
),
) -> type[QueryParams]:
try:
return cls(filters=filters, sort=sort, search=search)
except ValidationError as e:
raise HTTPException(status_code=422, detail=_parse_errors(e))

return get


def QueryBuilder(query_params: type[QueryParams] = QueryParams):
def wrapper(request: Request, params: QueryParams = Depends(filter_params(query_params))):
schema_config: SchemaConfig = query_params.model_config.get("schema_config")
model = query_params.model_config.get("sqla_model")
return build_query(
model or schema_config.model,
params,
)

return Depends(wrapper)
Loading