Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ jobs:
Test:
runs-on: ubuntu-latest
strategy:
max-parallel: 4
max-parallel: 5
fail-fast: false
matrix:
python-version: [ 3.7, 3.8, 3.9 ]
python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11" ]

steps:
- uses: actions/checkout@v3
Expand Down
3 changes: 0 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,9 @@ djangorestframework==3.12.1
entrypoints==0.3
flake8==3.7.7
mccabe==0.6.1
mypy==0.910
mypy-extensions==0.4.3
pycodestyle==2.5.0
pyflakes==2.1.1
pytz==2019.1
sqlparse==0.3.0
toml==0.10.0
typed-ast==1.4.3
typing-extensions==3.10.0.0
25 changes: 22 additions & 3 deletions rest_flex_fields/filter_backends.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import itertools
from collections.abc import Sequence
from functools import lru_cache
from typing import Optional

Expand All @@ -9,6 +11,7 @@
from rest_framework.request import Request
from rest_framework.viewsets import GenericViewSet

from rest_flex_fields.utils import get_model_from_dot_path
from rest_flex_fields import (
FIELDS_PARAM,
EXPAND_PARAM,
Expand Down Expand Up @@ -47,10 +50,20 @@ def _get_expandable_fields(serializer_class: FlexFieldsModelSerializer) -> list:
expand_list = []
while expandable_fields:
key, cls = expandable_fields.pop()
cls = cls[0] if hasattr(cls, '__iter__') else cls
cls = cls[0] if isinstance(cls, Sequence) else cls

try:
cls = get_model_from_dot_path(cls) if isinstance(cls, str) else cls
except (ValueError, AttributeError):
pass

skip_next_level = key.rsplit('.', 1)[-1] in list(itertools.chain.from_iterable([i.split('.') for i in expand_list]))
expand_list.append(key)

# Skip node, already visited
if skip_next_level:
continue

if hasattr(cls, "Meta") and issubclass(cls, FlexFieldsSerializerMixin) and hasattr(cls.Meta, "expandable_fields"):
next_layer = getattr(cls.Meta, 'expandable_fields')
expandable_fields.extend([(f"{key}.{k}", cls) for k, cls in list(next_layer.items())])
Expand Down Expand Up @@ -227,13 +240,19 @@ def filter_queryset(

queryset = queryset.prefetch_related(*(
model_field.name for model_field in nested_model_fields if
(model_field.is_relation and not model_field.many_to_one) or
(model_field.is_relation and model_field.many_to_one and not model_field.concrete) # Include GenericForeignKey)
(
(model_field.is_relation and not model_field.many_to_one) or
(model_field.is_relation and model_field.many_to_one and not model_field.concrete) # Include GenericForeignKey)
) and
(model_field.name not in self._get_prefetches_ignores())
)
)

return queryset

def _get_prefetches_ignores(self):
return []

@staticmethod
@lru_cache()
def _get_field(field_name: str, model: models.Model) -> Optional[models.Field]:
Expand Down
13 changes: 13 additions & 0 deletions rest_flex_fields/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import logging
import importlib
from collections.abc import Iterable
from django.db.models import Model
from typing import Optional

from rest_flex_fields import EXPAND_PARAM, FIELDS_PARAM, OMIT_PARAM, WILDCARD_VALUES

Expand Down Expand Up @@ -72,3 +76,12 @@ def split_levels(fields):

first_level_fields = list(set(first_level_fields))
return first_level_fields, next_level_fields


def get_model_from_dot_path(dot_path: str) -> Optional[Model]:
"""Given a dot path such as 'testapp.models.Person', return the model class."""
module_path, attribute_name = dot_path.rsplit('.', 1)
module = importlib.import_module(module_path)
model = getattr(module, attribute_name)

return model