Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions examples/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
version: "3.3"

services:

elasticsearch:
image: elasticsearch:7.6.2
# restart: always
ports:
- 9200:9200
environment:
- node.name=fastapi-filter-es
- cluster.name=fastapi-filter-es-docker-cluster
- discovery.type=single-node
- bootstrap.memory_lock=true
- "ES_JAVA_OPTS=-Xms512m -Xmx512m"
ulimits:
memlock:
soft: -1
hard: -1
volumes:
- elasticsearch-data:/usr/share/elasticsearch/data

volumes:
elasticsearch-data:
189 changes: 189 additions & 0 deletions examples/fastapi_filter_elasticsearch_dsl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import logging
from typing import Any, List, Optional

import uvicorn
from faker import Faker
from fastapi import FastAPI
from pydantic import BaseModel, ConfigDict, EmailStr

from fastapi_filter import FilterDepends, with_prefix
from fastapi_filter.contrib.elasticsearch_dsl import Filter

fake = Faker()

logger = logging.getLogger("uvicorn")
from datetime import datetime
from fnmatch import fnmatch

from elasticsearch_dsl import Document, Keyword, connections, Integer, Nested, SearchAsYouType, InnerDoc


ALIAS = "address"
PATTERN = ALIAS + "-*"


class Address(InnerDoc):
street = Keyword()
city = SearchAsYouType()
country = Keyword()
number = Integer()


class User(Document):
name = SearchAsYouType()
email = Keyword()
age = Integer()
address = Nested(Address)

@classmethod
def _matches(cls, hit):
return fnmatch(hit["_index"], PATTERN)

class Index:
name = ALIAS
settings = {"number_of_shards": 1, "number_of_replicas": 0}


def setup():
index_template = User._index.as_template(ALIAS, PATTERN)
index_template.save()

if not User._index.exists():
migrate(move_data=False)


def migrate(move_data=True, update_alias=True):
# construct a new index name by appending current timestamp
next_index = PATTERN.replace("*", datetime.now().strftime("%Y%m%d%H%M%S%f"))
es = connections.get_connection()
# create new index, it will use the settings from the template
es.indices.create(index=next_index)
if move_data:
# move data from current alias to the new index
es.reindex(
body={"source": {"index": ALIAS}, "dest": {"index": next_index}},
request_timeout=3600,
)
# refresh the index to make the changes visible
es.indices.refresh(index=next_index)

if update_alias:
# repoint the alias to point to the newly created index
es.indices.update_aliases(
body={
"actions": [
{"remove": {"alias": ALIAS, "index": PATTERN}},
{"add": {"alias": ALIAS, "index": next_index}},
]
}
)


class AddressOut(BaseModel):
street: Optional[str] = None
city: str
number: int
country: str

class Config:
orm_mode = True


class UserIn(BaseModel):
name: str
email: EmailStr
age: int


class UserOut(UserIn):
model_config = ConfigDict(from_attributes=True)

name: str
email: EmailStr
age: int
address: Optional[AddressOut] = None


class AddressFilter(Filter):
street: Optional[str] = None
number: Optional[int] = None
number__gt: Optional[int] = None
number__gte: Optional[int] = None
number__lt: Optional[int] = None
number__lte: Optional[int] = None
street__isnull: Optional[bool] = None
country: Optional[str] = None
country_not: Optional[str] = None
city: Optional[str] = None
city__in: Optional[List[str]] = None
city__not_in: Optional[List[str]] = ["city"]
custom_order_by: Optional[List[str]] = None
custom_search: Optional[str] = None
order_by: List[str] = ["-street"]

class Constants(Filter.Constants):
model = Address
# ordering_field_name = "street"
search_field_name = "custom_search"
search_model_fields = ["street", "country", "city"]


class UserFilter(Filter):
name: Optional[str] = None
address: Optional[AddressFilter] = FilterDepends(with_prefix("address", AddressFilter))
age__lt: Optional[int] = None
# age__gte: int = Field(Query(description="this is a nice description"))
"""Required field with a custom description.

See: https://github.com/tiangolo/fastapi/issues/4700 for why we need to wrap `Query` in `Field`.
"""
order_by: List[str] = ["-age"]
search: Optional[str] = None

class Constants(Filter.Constants):
model = User
search_model_fields = ["name"]


app = FastAPI()


@app.on_event("startup")
async def on_startup() -> None:
connections.create_connection(hosts="http://localhost:9200")

setup()
migrate()

for i in range(100):
if i % 5 == 0:
address = Address(
street=fake.street_address(),
city=fake.city(),
country=fake.country(),
number=fake.random_int(min=5, max=100),
)
else:
address = Address(city=fake.city(), country=fake.country(), number=fake.random_int(min=5, max=100))
user = User(name=fake.name(), email=fake.email(), age=fake.random_int(min=5, max=120), address=address)
user.save()


@app.on_event("shutdown")
async def on_shutdown() -> None:
s = Address.search().query("match_all")
s.delete()


@app.get("/users", response_model=List[UserOut])
async def get_users(
user_filter: UserFilter = FilterDepends(with_prefix("my_custom_prefix", UserFilter), by_alias=True),
) -> Any:
query = user_filter.filter(User.search())
query = user_filter.sort(query)
response = query.execute()
return [UserOut(**user.to_dict()) for user in response]


if __name__ == "__main__":
uvicorn.run("main:app", reload=True)
26 changes: 26 additions & 0 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
annotated-types==0.6.0
anyio==4.3.0
bson==0.5.10
certifi==2024.2.2
click==8.1.7
dnspython==2.6.1
elastic-transport==8.12.0
elasticsearch==7.17.9
elasticsearch-dsl==7.4.1
email-validator==2.1.0.post1
Faker==23.2.1
fastapi==0.109.2
h11==0.14.0
idna==3.6
mongoengine==0.27.0
pydantic==2.6.2
pydantic_core==2.16.3
pymongo==4.6.2
python-dateutil==2.8.2
six==1.16.0
sniffio==1.3.0
starlette==0.36.3
typing_extensions==4.9.0
urllib3==1.26.18
uvicorn==0.27.1

3 changes: 3 additions & 0 deletions fastapi_filter/contrib/elasticsearch_dsl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .filter import Filter

__all__ = ("Filter",)
110 changes: 110 additions & 0 deletions fastapi_filter/contrib/elasticsearch_dsl/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
from elasticsearch_dsl import Q, Search
from elasticsearch_dsl.query import Query
from pydantic import ValidationInfo, field_validator

from ...base.filter import BaseFilterModel


_operator_transformer = {
"neq": lambda value, field_name: ~Q("term", **{field_name: value}),
"gt": lambda value, field_name: Q("range", **{field_name: {"gt": value}}),
"gte": lambda value, field_name: Q("range", **{field_name: {"gte": value}}),
"lt": lambda value, field_name: Q("range", **{field_name: {"lt": value}}),
"lte": lambda value, field_name: Q("range", **{field_name: {"lte": value}}),
"in": lambda value, field_name: Q("terms", **{field_name: value}),
"isnull": lambda value, field_name: ~Q("exists", field=field_name)
if value is True
else Q("exists", field=field_name),
"not": lambda value, field_name: ~Q("term", **{field_name: value}),
"not_in": lambda value, field_name: ~Q("terms", **{field_name: value}),
"nin": lambda value, field_name: ~Q("terms", **{field_name: value}),
}


class Filter(BaseFilterModel):
"""Base filter for elasticsearch_dsl related filters.

Example:
```python

class MyModel(Document):
street = Keyword()
city = Keyword()
country = Keyword()
number = Integer()

class MyModelFilter(Filter):
street: Optional[str] = None
number: Optional[int] = None
number__gt: Optional[int] = None
number__gte: Optional[int] = None
number__lt: Optional[int] = None
number__lte: Optional[int] = None
street__isnull: Optional[bool] = None
country: Optional[str] = None
country_not: Optional[str] = None
city: Optional[str] = None
city__in: Optional[List[str]] = None
city__not_in: Optional[List[str]] = ["city"]
custom_order_by: Optional[List[str]] = None
custom_search: Optional[str] = None
order_by: List[str] = ["-street"]
```
"""

def sort(self, query: Search) -> Search:
if not self.ordering_values:
return query
return query.sort(*self.ordering_values)

@field_validator("*", mode="before")
def split_str(cls, value, field: ValidationInfo):
if (
field.field_name is not None
and (
field.field_name == cls.Constants.ordering_field_name
or field.field_name.endswith("__in")
or field.field_name.endswith("__nin")
or field.field_name.endswith("__not_in")
)
and isinstance(value, str)
):
if not value:
# Empty string should return [] not ['']
return []
return list(value.split(","))
return value

def make_query(self, field_name: str, value) -> Query:
if "__" in field_name:
field_name, operator = field_name.split("__")
query = _operator_transformer[operator](value, field_name)
elif field_name == self.Constants.search_field_name and hasattr(self.Constants, "search_model_fields"):
query = Q(
"multi_match",
type="bool_prefix",
fields=[
field_gram
for field in self.Constants.search_model_fields
for field_gram in [f"{field}", f"{field}._2gram", f"{field}._3gram"]
],
query=value,
)
else:
query = Q("term", **{field_name: value})
return query

def filter(self, search: Search) -> Search:
queries = Q()
for field_name, value in self.filtering_fields:
field_value = getattr(self, field_name)
if isinstance(field_value, Filter):
nested_queries = Q()
for inner_field, inner_value in field_value.filtering_fields:
nested_queries &= self.make_query(f"{field_name}.{inner_field}", inner_value)
search.query("nested", path=field_name, query=nested_queries)
else:
queries &= self.make_query(field_name, value)

return search.query(queries)