diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index d7c4f5f..260c068 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -8,9 +8,10 @@ jobs: Test: runs-on: ubuntu-latest strategy: - max-parallel: 4 + max-parallel: 5 + fail-fast: false matrix: - python-version: [ 3.7, 3.8, 3.9 ] + python-version: [ "3.7", "3.8", "3.9", "3.10", "3.11" ] steps: - uses: actions/checkout@v3 diff --git a/requirements.txt b/requirements.txt index 73853c2..2089838 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,12 +8,9 @@ djangorestframework==3.12.1 entrypoints==0.3 flake8==3.7.7 mccabe==0.6.1 -mypy==0.910 -mypy-extensions==0.4.3 pycodestyle==2.5.0 pyflakes==2.1.1 pytz==2019.1 sqlparse==0.3.0 toml==0.10.0 -typed-ast==1.4.3 typing-extensions==3.10.0.0 diff --git a/rest_flex_fields/filter_backends.py b/rest_flex_fields/filter_backends.py index 63e3402..ca4fd93 100644 --- a/rest_flex_fields/filter_backends.py +++ b/rest_flex_fields/filter_backends.py @@ -1,3 +1,5 @@ +import itertools +from collections.abc import Sequence from functools import lru_cache from typing import Optional @@ -9,6 +11,7 @@ from rest_framework.request import Request from rest_framework.viewsets import GenericViewSet +from rest_flex_fields.utils import get_model_from_dot_path from rest_flex_fields import ( FIELDS_PARAM, EXPAND_PARAM, @@ -47,10 +50,20 @@ def _get_expandable_fields(serializer_class: FlexFieldsModelSerializer) -> list: expand_list = [] while expandable_fields: key, cls = expandable_fields.pop() - cls = cls[0] if hasattr(cls, '__iter__') else cls + cls = cls[0] if isinstance(cls, Sequence) else cls + try: + cls = get_model_from_dot_path(cls) if isinstance(cls, str) else cls + except (ValueError, AttributeError): + pass + + skip_next_level = key.rsplit('.', 1)[-1] in list(itertools.chain.from_iterable([i.split('.') for i in expand_list])) expand_list.append(key) + # Skip node, already visited + if skip_next_level: + continue + if hasattr(cls, "Meta") and issubclass(cls, FlexFieldsSerializerMixin) and hasattr(cls.Meta, "expandable_fields"): next_layer = getattr(cls.Meta, 'expandable_fields') expandable_fields.extend([(f"{key}.{k}", cls) for k, cls in list(next_layer.items())]) @@ -227,13 +240,19 @@ def filter_queryset( queryset = queryset.prefetch_related(*( model_field.name for model_field in nested_model_fields if - (model_field.is_relation and not model_field.many_to_one) or - (model_field.is_relation and model_field.many_to_one and not model_field.concrete) # Include GenericForeignKey) + ( + (model_field.is_relation and not model_field.many_to_one) or + (model_field.is_relation and model_field.many_to_one and not model_field.concrete) # Include GenericForeignKey) + ) and + (model_field.name not in self._get_prefetches_ignores()) ) ) return queryset + def _get_prefetches_ignores(self): + return [] + @staticmethod @lru_cache() def _get_field(field_name: str, model: models.Model) -> Optional[models.Field]: diff --git a/rest_flex_fields/utils.py b/rest_flex_fields/utils.py index 3f2c49a..c4af147 100644 --- a/rest_flex_fields/utils.py +++ b/rest_flex_fields/utils.py @@ -1,4 +1,8 @@ +import logging +import importlib from collections.abc import Iterable +from django.db.models import Model +from typing import Optional from rest_flex_fields import EXPAND_PARAM, FIELDS_PARAM, OMIT_PARAM, WILDCARD_VALUES @@ -72,3 +76,12 @@ def split_levels(fields): first_level_fields = list(set(first_level_fields)) return first_level_fields, next_level_fields + + +def get_model_from_dot_path(dot_path: str) -> Optional[Model]: + """Given a dot path such as 'testapp.models.Person', return the model class.""" + module_path, attribute_name = dot_path.rsplit('.', 1) + module = importlib.import_module(module_path) + model = getattr(module, attribute_name) + + return model \ No newline at end of file