From 0253c049d3f02b3b39d759adf8b95197c325b6b5 Mon Sep 17 00:00:00 2001 From: Eleftherios Zisis Date: Thu, 3 Jul 2025 14:07:37 +0200 Subject: [PATCH 1/5] Change in/not_in separator to | --- app/filters/base.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/app/filters/base.py b/app/filters/base.py index a7747ee06..9ccbfd385 100644 --- a/app/filters/base.py +++ b/app/filters/base.py @@ -3,7 +3,7 @@ from fastapi_filter.contrib.sqlalchemy import Filter from fastapi_filter.contrib.sqlalchemy.filter import _orm_operator_transformer # noqa: PLC2701 -from pydantic import field_validator +from pydantic import ValidationInfo, field_validator from sqlalchemy import Select, or_ from sqlalchemy.orm import DeclarativeBase @@ -18,6 +18,25 @@ class CustomFilter[T: DeclarativeBase](Filter): class Constants(Filter.Constants): ordering_model_fields: list[str] + @field_validator("*", mode="before") + @classmethod + def split_str(cls, value, field: ValidationInfo): + if ( + field.field_name is not None + and ( + field.field_name == cls.Constants.ordering_field_name + or field.field_name.endswith("__in") + or field.field_name.endswith("__not_in") + ) + and isinstance(value, str) + ): + if not value: + # Empty string should return [] not [''] + return [] + + return list(value.split("|")) + return value + @field_validator("order_by", check_fields=False) @classmethod def restrict_sortable_fields(cls, value: list[str]): From 9db2e1e8ec76d37153df6420f4a991b23eb66cc9 Mon Sep 17 00:00:00 2001 From: Eleftherios Zisis Date: Thu, 3 Jul 2025 20:19:36 +0200 Subject: [PATCH 2/5] Support native list params for filters --- app/dependencies/filter.py | 55 +++++++++++++++++++ app/filters/activity.py | 3 +- app/filters/base.py | 19 +------ app/filters/brain_atlas.py | 3 +- app/filters/brain_region.py | 2 +- app/filters/brain_region_hierarchy.py | 3 +- app/filters/cell_composition.py | 3 +- app/filters/circuit.py | 3 +- app/filters/common.py | 3 +- app/filters/density.py | 3 +- app/filters/electrical_cell_recording.py | 3 +- app/filters/emodel.py | 3 +- app/filters/entity.py | 3 +- app/filters/ion_channel_model.py | 3 +- app/filters/measurement_annotation.py | 3 +- app/filters/memodel.py | 3 +- app/filters/memodel_calibration_result.py | 3 +- app/filters/morphology.py | 3 +- app/filters/person.py | 3 +- app/filters/simulation.py | 3 +- app/filters/simulation_campaign.py | 3 +- app/filters/simulation_execution.py | 3 +- app/filters/simulation_generation.py | 3 +- app/filters/simulation_result.py | 3 +- app/filters/single_neuron_simulation.py | 3 +- app/filters/single_neuron_synaptome.py | 3 +- .../single_neuron_synaptome_simulation.py | 3 +- app/filters/validation_result.py | 3 +- tests/test_circuit.py | 2 +- tests/test_electrical_cell_recording.py | 15 ++++- tests/test_etype.py | 2 +- tests/test_morphology.py | 6 +- tests/test_mtype.py | 2 +- tests/test_simulation_execution.py | 6 +- tests/test_simulation_generation.py | 6 +- 35 files changed, 120 insertions(+), 70 deletions(-) create mode 100644 app/dependencies/filter.py diff --git a/app/dependencies/filter.py b/app/dependencies/filter.py new file mode 100644 index 000000000..66acc4675 --- /dev/null +++ b/app/dependencies/filter.py @@ -0,0 +1,55 @@ +from copy import deepcopy +from typing import Any, get_args, get_origin + +from fastapi import Depends, Query, params +from fastapi.exceptions import RequestValidationError +from fastapi_filter.base.filter import BaseFilterModel +from pydantic import ValidationError, create_model + + +def _list_to_query_fields(filter_model: type[BaseFilterModel]): + fields = {} + for name, f in filter_model.model_fields.items(): + field_info = deepcopy(f) + annotation = f.annotation + + if ( + annotation is list + or get_origin(annotation) is list + or any(get_origin(a) is list for a in get_args(annotation)) + ) and type(field_info.default) is not params.Query: + field_info.default = Query(default=field_info.default) + + fields[name] = (f.annotation, field_info) + + return fields + + +def FilterDepends(filter_model: type[BaseFilterModel], *, by_alias: bool = False, **_) -> Any: + """Use a hack to treat lists as query parameters. + + What we do is loop through the fields of a filter and assign any `list` field a default value of + `Query` so that FastAPI knows it should be treated a query parameter and not body. + + When we apply the filter, we build the original filter to properly validate the data (i.e. can + the string be parsed and formatted as a list of ?) + """ + fields = _list_to_query_fields(filter_model) + GeneratedFilter = create_model(filter_model.__class__.__name__, **fields) # noqa: N806 + + class FilterWrapper(GeneratedFilter): # type: ignore[misc,valid-type] + def __new__(cls, *args, **kwargs): + try: + instance = GeneratedFilter(*args, **kwargs) + data = instance.model_dump( + exclude_unset=True, exclude_defaults=True, by_alias=by_alias + ) + if original_filter := getattr(filter_model.Constants, "original_filter", None): + prefix = f"{filter_model.Constants.prefix}__" + stripped = {k.removeprefix(prefix): v for k, v in data.items()} + return original_filter(**stripped) + return filter_model(**data) + except ValidationError as e: + raise RequestValidationError(e.errors()) from e + + return Depends(FilterWrapper) diff --git a/app/filters/activity.py b/app/filters/activity.py index cf4806e86..8b039a896 100644 --- a/app/filters/activity.py +++ b/app/filters/activity.py @@ -1,8 +1,9 @@ from datetime import datetime from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix +from app.dependencies.filter import FilterDepends from app.filters.common import ( CreationFilterMixin, CreatorFilterMixin, diff --git a/app/filters/base.py b/app/filters/base.py index 9ccbfd385..a2084cd17 100644 --- a/app/filters/base.py +++ b/app/filters/base.py @@ -3,7 +3,7 @@ from fastapi_filter.contrib.sqlalchemy import Filter from fastapi_filter.contrib.sqlalchemy.filter import _orm_operator_transformer # noqa: PLC2701 -from pydantic import ValidationInfo, field_validator +from pydantic import field_validator from sqlalchemy import Select, or_ from sqlalchemy.orm import DeclarativeBase @@ -20,21 +20,8 @@ class Constants(Filter.Constants): @field_validator("*", mode="before") @classmethod - def split_str(cls, value, field: ValidationInfo): - if ( - field.field_name is not None - and ( - field.field_name == cls.Constants.ordering_field_name - or field.field_name.endswith("__in") - or field.field_name.endswith("__not_in") - ) - and isinstance(value, str) - ): - if not value: - # Empty string should return [] not [''] - return [] - - return list(value.split("|")) + def split_str(cls, value, field): # noqa: ARG003 # pyright: ignore reportIncompatibleMethodOverride + """Prevent splitting field logic from parent class.""" return value @field_validator("order_by", check_fields=False) diff --git a/app/filters/brain_atlas.py b/app/filters/brain_atlas.py index 89d51a985..aca331356 100644 --- a/app/filters/brain_atlas.py +++ b/app/filters/brain_atlas.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import BrainAtlas, BrainAtlasRegion +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import IdFilterMixin, NameFilterMixin, SpeciesFilterMixin diff --git a/app/filters/brain_region.py b/app/filters/brain_region.py index 1b0728949..09879b204 100644 --- a/app/filters/brain_region.py +++ b/app/filters/brain_region.py @@ -2,10 +2,10 @@ from typing import Annotated import sqlalchemy as sa -from fastapi_filter import FilterDepends from sqlalchemy.orm import aliased from app.db.model import BrainRegion, BrainRegionHierarchy +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import NameFilterMixin diff --git a/app/filters/brain_region_hierarchy.py b/app/filters/brain_region_hierarchy.py index 33c19f209..0a6bafd68 100644 --- a/app/filters/brain_region_hierarchy.py +++ b/app/filters/brain_region_hierarchy.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import BrainRegionHierarchy +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import IdFilterMixin, NameFilterMixin diff --git a/app/filters/cell_composition.py b/app/filters/cell_composition.py index d9e443f0a..348e01e39 100644 --- a/app/filters/cell_composition.py +++ b/app/filters/cell_composition.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import CellComposition +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( EntityFilterMixin, diff --git a/app/filters/circuit.py b/app/filters/circuit.py index 44eb1afca..27156d350 100644 --- a/app/filters/circuit.py +++ b/app/filters/circuit.py @@ -2,9 +2,8 @@ from datetime import datetime from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import Circuit +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/common.py b/app/filters/common.py index 896d40b75..f46e18f45 100644 --- a/app/filters/common.py +++ b/app/filters/common.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import ( Agent, @@ -15,6 +15,7 @@ Strain, Subject, ) +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter diff --git a/app/filters/density.py b/app/filters/density.py index 0d2a91efa..25fcb1953 100644 --- a/app/filters/density.py +++ b/app/filters/density.py @@ -1,12 +1,11 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import ( ExperimentalBoutonDensity, ExperimentalNeuronDensity, ExperimentalSynapsesPerConnection, ) +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilter, diff --git a/app/filters/electrical_cell_recording.py b/app/filters/electrical_cell_recording.py index e50694008..49405b007 100644 --- a/app/filters/electrical_cell_recording.py +++ b/app/filters/electrical_cell_recording.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import ElectricalCellRecording +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/emodel.py b/app/filters/emodel.py index edcd970e7..0e4422e24 100644 --- a/app/filters/emodel.py +++ b/app/filters/emodel.py @@ -1,8 +1,9 @@ from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import EModel +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/entity.py b/app/filters/entity.py index 066ac67aa..024864b2c 100644 --- a/app/filters/entity.py +++ b/app/filters/entity.py @@ -1,9 +1,8 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import Entity from app.db.types import EntityType +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter diff --git a/app/filters/ion_channel_model.py b/app/filters/ion_channel_model.py index 682072f80..310f166a7 100644 --- a/app/filters/ion_channel_model.py +++ b/app/filters/ion_channel_model.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import IonChannelModel +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/measurement_annotation.py b/app/filters/measurement_annotation.py index 97d4c4613..e55d46e6c 100644 --- a/app/filters/measurement_annotation.py +++ b/app/filters/measurement_annotation.py @@ -1,11 +1,12 @@ import uuid from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import MeasurementAnnotation, MeasurementItem, MeasurementKind from app.db.types import MeasurementStatistic, MeasurementUnit, StructuralDomain from app.db.utils import MeasurableEntityType +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import CreationFilterMixin diff --git a/app/filters/memodel.py b/app/filters/memodel.py index bdc0089d9..5594f1d40 100644 --- a/app/filters/memodel.py +++ b/app/filters/memodel.py @@ -1,8 +1,9 @@ from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import MEModel, ValidationStatus +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/memodel_calibration_result.py b/app/filters/memodel_calibration_result.py index 3fc629746..0358c0063 100644 --- a/app/filters/memodel_calibration_result.py +++ b/app/filters/memodel_calibration_result.py @@ -1,9 +1,8 @@ import uuid from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import MEModelCalibrationResult +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import EntityFilterMixin diff --git a/app/filters/morphology.py b/app/filters/morphology.py index 2853486e3..a4ffd947d 100644 --- a/app/filters/morphology.py +++ b/app/filters/morphology.py @@ -1,8 +1,9 @@ from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import ReconstructionMorphology +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/person.py b/app/filters/person.py index 04278d94b..7a82594fa 100644 --- a/app/filters/person.py +++ b/app/filters/person.py @@ -1,9 +1,8 @@ import uuid from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import Person +from app.dependencies.filter import FilterDepends from app.filters.common import AgentFilter, CreatorFilterMixin diff --git a/app/filters/simulation.py b/app/filters/simulation.py index e9fe2197b..805e6871f 100644 --- a/app/filters/simulation.py +++ b/app/filters/simulation.py @@ -1,9 +1,10 @@ import uuid from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import Simulation +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( ContributionFilterMixin, diff --git a/app/filters/simulation_campaign.py b/app/filters/simulation_campaign.py index c49298f8a..f8f93f1bc 100644 --- a/app/filters/simulation_campaign.py +++ b/app/filters/simulation_campaign.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import SimulationCampaign +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import EntityFilterMixin, NameFilterMixin from app.filters.simulation import NestedSimulationFilter, NestedSimulationFilterDep diff --git a/app/filters/simulation_execution.py b/app/filters/simulation_execution.py index aacfb39a7..9dc6f19a9 100644 --- a/app/filters/simulation_execution.py +++ b/app/filters/simulation_execution.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import SimulationExecution +from app.dependencies.filter import FilterDepends from app.filters.activity import ActivityFilterMixin from app.filters.base import CustomFilter diff --git a/app/filters/simulation_generation.py b/app/filters/simulation_generation.py index 1dddd40ba..bb6d3da72 100644 --- a/app/filters/simulation_generation.py +++ b/app/filters/simulation_generation.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import SimulationGeneration +from app.dependencies.filter import FilterDepends from app.filters.activity import ActivityFilterMixin from app.filters.base import CustomFilter diff --git a/app/filters/simulation_result.py b/app/filters/simulation_result.py index d9f478e24..66a1f563c 100644 --- a/app/filters/simulation_result.py +++ b/app/filters/simulation_result.py @@ -1,8 +1,9 @@ from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import SimulationResult +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( ContributionFilterMixin, diff --git a/app/filters/single_neuron_simulation.py b/app/filters/single_neuron_simulation.py index e441a3df1..d37ab3e14 100644 --- a/app/filters/single_neuron_simulation.py +++ b/app/filters/single_neuron_simulation.py @@ -1,9 +1,8 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import SingleNeuronSimulation from app.db.types import SingleNeuronSimulationStatus +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/single_neuron_synaptome.py b/app/filters/single_neuron_synaptome.py index e70c343bc..2335a01a7 100644 --- a/app/filters/single_neuron_synaptome.py +++ b/app/filters/single_neuron_synaptome.py @@ -1,8 +1,9 @@ from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import SingleNeuronSynaptome +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/single_neuron_synaptome_simulation.py b/app/filters/single_neuron_synaptome_simulation.py index cf5e66694..a77cde781 100644 --- a/app/filters/single_neuron_synaptome_simulation.py +++ b/app/filters/single_neuron_synaptome_simulation.py @@ -1,9 +1,8 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import SingleNeuronSynaptomeSimulation from app.db.types import SingleNeuronSimulationStatus +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/validation_result.py b/app/filters/validation_result.py index 632fcd61e..86a5a33e2 100644 --- a/app/filters/validation_result.py +++ b/app/filters/validation_result.py @@ -1,9 +1,8 @@ import uuid from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import ValidationResult +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import EntityFilterMixin, NameFilterMixin diff --git a/tests/test_circuit.py b/tests/test_circuit.py index 7279efb96..5ee987ad9 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -181,7 +181,7 @@ def test_filtering(client, root_circuit, models): url=ROUTE, params={ "root_circuit_id": str(root_circuit.id), - "scale__in": "single,whole_brain", + "scale__in": ["single", "whole_brain"], }, ).json()["data"] assert len(data) == 4 diff --git a/tests/test_electrical_cell_recording.py b/tests/test_electrical_cell_recording.py index 8860bf46c..404aeeac3 100644 --- a/tests/test_electrical_cell_recording.py +++ b/tests/test_electrical_cell_recording.py @@ -211,7 +211,7 @@ def test_filtering(db, client, electrical_cell_recording_json_data, person_id): db, [ Subject( - name=f"my-subject-{i}", + name=f"subject-{i}", description="my-description", species_id=sp.id, strain_id=None, @@ -244,7 +244,6 @@ def test_filtering(db, client, electrical_cell_recording_json_data, person_id): for i, subject in enumerate(subjects) ], ) - data = assert_request(client.get, url=ROUTE).json()["data"] assert len(data) == len(models) @@ -257,3 +256,15 @@ def test_filtering(db, client, electrical_cell_recording_json_data, person_id): "data" ] assert len(data) == 2 + + data = assert_request(client.get, url=ROUTE, params="subject__species__name=species-2").json()[ + "data" + ] + assert len(data) == 2 + + data = assert_request( + client.get, + url=ROUTE, + params={"name__in": ["e-1", "e-2"]}, + ).json()["data"] + assert [d["name"] for d in data] == ["e-1", "e-2"] diff --git a/tests/test_etype.py b/tests/test_etype.py index 0848e27ab..4c6662608 100644 --- a/tests/test_etype.py +++ b/tests/test_etype.py @@ -45,7 +45,7 @@ def test_retrieve(db, client, person_id): assert data[0] == with_creation_fields(items[5]) # test filter (in) - response = client.get(ROUTE, params={"pref_label__in": "pref_label_5,pref_label_6"}) + response = client.get(ROUTE, params={"pref_label__in": ["pref_label_5", "pref_label_6"]}) assert response.status_code == 200 data = response.json()["data"] assert len(data) == 2 diff --git a/tests/test_morphology.py b/tests/test_morphology.py index b96e5d5d7..0a1107484 100644 --- a/tests/test_morphology.py +++ b/tests/test_morphology.py @@ -582,7 +582,7 @@ def test_filter_by_id__in(db, client, brain_region_id, person_id): # filtering by multiple IDs selected_ids = [morphology_ids[1], morphology_ids[3]] - response = client.get(ROUTE, params={"id__in": ",".join(selected_ids)}) + response = client.get(ROUTE, params={"id__in": selected_ids}) assert response.status_code == 200 data = response.json()["data"] assert len(data) == 2 @@ -590,7 +590,7 @@ def test_filter_by_id__in(db, client, brain_region_id, person_id): assert set(returned_ids) == set(selected_ids) # filtering by all IDs - response = client.get(ROUTE, params={"id__in": ",".join(morphology_ids)}) + response = client.get(ROUTE, params={"id__in": morphology_ids}) assert response.status_code == 200 data = response.json()["data"] assert len(data) == 5 @@ -606,7 +606,7 @@ def test_filter_by_id__in(db, client, brain_region_id, person_id): # combining id__in with other filters response = client.get( ROUTE, - params={"id__in": ",".join(morphology_ids), "name__ilike": "%Filter Test Morphology 2%"}, + params={"id__in": morphology_ids, "name__ilike": "%Filter Test Morphology 2%"}, ) assert response.status_code == 200 data = response.json()["data"] diff --git a/tests/test_mtype.py b/tests/test_mtype.py index 01d3f8a1b..00aedac58 100644 --- a/tests/test_mtype.py +++ b/tests/test_mtype.py @@ -46,7 +46,7 @@ def test_retrieve(db, client, person_id): assert data[0] == with_creation_fields(items[5]) # test filter (in) - response = client.get(ROUTE, params={"pref_label__in": "pref_label_5,pref_label_6"}) + response = client.get(ROUTE, params={"pref_label__in": ["pref_label_5", "pref_label_6"]}) assert response.status_code == 200 data = response.json()["data"] assert len(data) == 2 diff --git a/tests/test_simulation_execution.py b/tests/test_simulation_execution.py index 174bb03d7..94ff2c03a 100644 --- a/tests/test_simulation_execution.py +++ b/tests/test_simulation_execution.py @@ -248,14 +248,16 @@ def test_filtering(client, models, root_circuit, simulation_result): assert len(data) == 2 data = assert_request( - client.get, url=ROUTE, params={"used__id__in": f"{root_circuit.id},{simulation_result.id}"} + client.get, + url=ROUTE, + params={"used__id__in": [str(root_circuit.id), str(simulation_result.id)]}, ).json()["data"] assert len(data) == 5 data = assert_request( client.get, url=ROUTE, - params={"generated__id__in": f"{root_circuit.id},{simulation_result.id}"}, + params={"generated__id__in": [str(root_circuit.id), str(simulation_result.id)]}, ).json()["data"] assert len(data) == 4 diff --git a/tests/test_simulation_generation.py b/tests/test_simulation_generation.py index f0a40ef66..a2c79f97a 100644 --- a/tests/test_simulation_generation.py +++ b/tests/test_simulation_generation.py @@ -227,14 +227,16 @@ def test_filtering(client, models, root_circuit, simulation_result): assert len(data) == 2 data = assert_request( - client.get, url=ROUTE, params={"used__id__in": f"{root_circuit.id},{simulation_result.id}"} + client.get, + url=ROUTE, + params={"used__id__in": [str(root_circuit.id), str(simulation_result.id)]}, ).json()["data"] assert len(data) == 5 data = assert_request( client.get, url=ROUTE, - params={"generated__id__in": f"{root_circuit.id},{simulation_result.id}"}, + params={"generated__id__in": [str(root_circuit.id), str(simulation_result.id)]}, ).json()["data"] assert len(data) == 4 From 941606997342c94972eedbd392ec52a1eb81f7fd Mon Sep 17 00:00:00 2001 From: Eleftherios Zisis Date: Thu, 3 Jul 2025 20:26:39 +0200 Subject: [PATCH 3/5] Fix --- tests/test_electrical_cell_recording.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_electrical_cell_recording.py b/tests/test_electrical_cell_recording.py index 404aeeac3..ef8858e05 100644 --- a/tests/test_electrical_cell_recording.py +++ b/tests/test_electrical_cell_recording.py @@ -267,4 +267,4 @@ def test_filtering(db, client, electrical_cell_recording_json_data, person_id): url=ROUTE, params={"name__in": ["e-1", "e-2"]}, ).json()["data"] - assert [d["name"] for d in data] == ["e-1", "e-2"] + assert {d["name"] for d in data} == {"e-1", "e-2"} From a4914eae75c5eb67256f90019c8fb9fe22e7a1b6 Mon Sep 17 00:00:00 2001 From: Eleftherios Zisis Date: Mon, 7 Jul 2025 13:29:48 +0200 Subject: [PATCH 4/5] Make change backwards compatible --- app/filters/base.py | 14 +++++++++++- app/filters/common.py | 6 ++++- tests/test_electrical_cell_recording.py | 8 +++++++ tests/test_etype.py | 8 +++++++ tests/test_morphology.py | 29 +++++++++++++++++++++++++ tests/test_mtype.py | 7 ++++++ tests/test_simulation_execution.py | 14 ++++++++++++ tests/test_simulation_generation.py | 14 ++++++++++++ 8 files changed, 98 insertions(+), 2 deletions(-) diff --git a/app/filters/base.py b/app/filters/base.py index a2084cd17..447f8ee47 100644 --- a/app/filters/base.py +++ b/app/filters/base.py @@ -20,8 +20,20 @@ class Constants(Filter.Constants): @field_validator("*", mode="before") @classmethod - def split_str(cls, value, field): # noqa: ARG003 # pyright: ignore reportIncompatibleMethodOverride + def split_str(cls, value, field): # pyright: ignore reportIncompatibleMethodOverride """Prevent splitting field logic from parent class.""" + # backwards compatibility by splitting only comma separated single list elements that do not + # have space directly after the comma. e.g "a,b,c" will be split but not 'a, b, c'. + if field.field_name is not None and ( + field.field_name == cls.Constants.ordering_field_name + or field.field_name.endswith("__in") + or field.field_name.endswith("__not_in") + ): + return ( + value[0].split(",") + if value and len(value) == 1 and isinstance(value[0], str) and ", " not in value[0] + else value + ) return value @field_validator("order_by", check_fields=False) diff --git a/app/filters/common.py b/app/filters/common.py index f46e18f45..46538c431 100644 --- a/app/filters/common.py +++ b/app/filters/common.py @@ -21,7 +21,11 @@ class IdFilterMixin: id: uuid.UUID | None = None - id__in: list[uuid.UUID] | None = None + + # id__in needs to be a str for backwards compatibility when instead of a native list a comma + # separated string is provided, e.g. 'id1,id2' . With list[UUID] backwards compatibility would + # fail because of validation of the field which would be expected to be a UUID. + id__in: list[str] | None = None class NameFilterMixin: diff --git a/tests/test_electrical_cell_recording.py b/tests/test_electrical_cell_recording.py index ef8858e05..a7c32143c 100644 --- a/tests/test_electrical_cell_recording.py +++ b/tests/test_electrical_cell_recording.py @@ -268,3 +268,11 @@ def test_filtering(db, client, electrical_cell_recording_json_data, person_id): params={"name__in": ["e-1", "e-2"]}, ).json()["data"] assert {d["name"] for d in data} == {"e-1", "e-2"} + + # backwards compat + data = assert_request( + client.get, + url=ROUTE, + params={"name__in": "e-1,e-2"}, + ).json()["data"] + assert {d["name"] for d in data} == {"e-1", "e-2"} diff --git a/tests/test_etype.py b/tests/test_etype.py index 4c6662608..7801c57a4 100644 --- a/tests/test_etype.py +++ b/tests/test_etype.py @@ -45,6 +45,14 @@ def test_retrieve(db, client, person_id): assert data[0] == with_creation_fields(items[5]) # test filter (in) + + # backwards compat + response = client.get(ROUTE, params={"pref_label__in": "pref_label_5,pref_label_6"}) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + assert data == [with_creation_fields(items[5]), with_creation_fields(items[6])] + response = client.get(ROUTE, params={"pref_label__in": ["pref_label_5", "pref_label_6"]}) assert response.status_code == 200 data = response.json()["data"] diff --git a/tests/test_morphology.py b/tests/test_morphology.py index 0a1107484..48f84752a 100644 --- a/tests/test_morphology.py +++ b/tests/test_morphology.py @@ -582,6 +582,15 @@ def test_filter_by_id__in(db, client, brain_region_id, person_id): # filtering by multiple IDs selected_ids = [morphology_ids[1], morphology_ids[3]] + + # backwards compat + response = client.get(ROUTE, params={"id__in": ",".join(selected_ids)}) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + returned_ids = [item["id"] for item in data] + assert set(returned_ids) == set(selected_ids) + response = client.get(ROUTE, params={"id__in": selected_ids}) assert response.status_code == 200 data = response.json()["data"] @@ -590,6 +599,15 @@ def test_filter_by_id__in(db, client, brain_region_id, person_id): assert set(returned_ids) == set(selected_ids) # filtering by all IDs + + # backwards compat + response = client.get(ROUTE, params={"id__in": ",".join(morphology_ids)}) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 5 + returned_ids = [item["id"] for item in data] + assert set(returned_ids) == set(morphology_ids) + response = client.get(ROUTE, params={"id__in": morphology_ids}) assert response.status_code == 200 data = response.json()["data"] @@ -604,6 +622,17 @@ def test_filter_by_id__in(db, client, brain_region_id, person_id): assert len(data) == 0 # combining id__in with other filters + + # backwards compat + response = client.get( + ROUTE, + params={"id__in": ",".join(morphology_ids), "name__ilike": "%Filter Test Morphology 2%"}, + ) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 1 + assert data[0]["id"] == morphology_ids[2] + response = client.get( ROUTE, params={"id__in": morphology_ids, "name__ilike": "%Filter Test Morphology 2%"}, diff --git a/tests/test_mtype.py b/tests/test_mtype.py index 00aedac58..180262089 100644 --- a/tests/test_mtype.py +++ b/tests/test_mtype.py @@ -46,6 +46,13 @@ def test_retrieve(db, client, person_id): assert data[0] == with_creation_fields(items[5]) # test filter (in) + # backwards compat + response = client.get(ROUTE, params={"pref_label__in": "pref_label_5,pref_label_6"}) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + assert data == [with_creation_fields(items[5]), with_creation_fields(items[6])] + response = client.get(ROUTE, params={"pref_label__in": ["pref_label_5", "pref_label_6"]}) assert response.status_code == 200 data = response.json()["data"] diff --git a/tests/test_simulation_execution.py b/tests/test_simulation_execution.py index 94ff2c03a..c73896c87 100644 --- a/tests/test_simulation_execution.py +++ b/tests/test_simulation_execution.py @@ -247,6 +247,12 @@ def test_filtering(client, models, root_circuit, simulation_result): ).json()["data"] assert len(data) == 2 + # backwards compat + data = assert_request( + client.get, url=ROUTE, params={"used__id__in": f"{root_circuit.id},{simulation_result.id}"} + ).json()["data"] + assert len(data) == 5 + data = assert_request( client.get, url=ROUTE, @@ -254,6 +260,14 @@ def test_filtering(client, models, root_circuit, simulation_result): ).json()["data"] assert len(data) == 5 + # backwards compat + data = assert_request( + client.get, + url=ROUTE, + params={"generated__id__in": f"{root_circuit.id},{simulation_result.id}"}, + ).json()["data"] + assert len(data) == 4 + data = assert_request( client.get, url=ROUTE, diff --git a/tests/test_simulation_generation.py b/tests/test_simulation_generation.py index a2c79f97a..5fe8bcd81 100644 --- a/tests/test_simulation_generation.py +++ b/tests/test_simulation_generation.py @@ -226,6 +226,12 @@ def test_filtering(client, models, root_circuit, simulation_result): ).json()["data"] assert len(data) == 2 + # backwards compat + data = assert_request( + client.get, url=ROUTE, params={"used__id__in": f"{root_circuit.id},{simulation_result.id}"} + ).json()["data"] + assert len(data) == 5 + data = assert_request( client.get, url=ROUTE, @@ -233,6 +239,14 @@ def test_filtering(client, models, root_circuit, simulation_result): ).json()["data"] assert len(data) == 5 + # backwards compat + data = assert_request( + client.get, + url=ROUTE, + params={"generated__id__in": f"{root_circuit.id},{simulation_result.id}"}, + ).json()["data"] + assert len(data) == 4 + data = assert_request( client.get, url=ROUTE, From dd965a9aaf097f60b6d896a82295d7c3c6f19674 Mon Sep 17 00:00:00 2001 From: Eleftherios Zisis Date: Mon, 7 Jul 2025 14:18:25 +0200 Subject: [PATCH 5/5] Add warning --- app/filters/base.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/app/filters/base.py b/app/filters/base.py index 447f8ee47..d32744fcd 100644 --- a/app/filters/base.py +++ b/app/filters/base.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import DeclarativeBase from app.db.model import Identifiable +from app.logger import L Aliases = dict[type[Identifiable], type[Identifiable] | dict[str, type[Identifiable]]] @@ -24,16 +25,25 @@ def split_str(cls, value, field): # pyright: ignore reportIncompatibleMethodOve """Prevent splitting field logic from parent class.""" # backwards compatibility by splitting only comma separated single list elements that do not # have space directly after the comma. e.g "a,b,c" will be split but not 'a, b, c'. - if field.field_name is not None and ( - field.field_name == cls.Constants.ordering_field_name - or field.field_name.endswith("__in") - or field.field_name.endswith("__not_in") + if ( + field.field_name is not None # noqa: PLR0916 + and ( + field.field_name == cls.Constants.ordering_field_name + or field.field_name.endswith("__in") + or field.field_name.endswith("__not_in") + ) + and value + and len(value) == 1 + and isinstance(value[0], str) + and "," in value[0] + and ", " not in value[0] ): - return ( - value[0].split(",") - if value and len(value) == 1 and isinstance(value[0], str) and ", " not in value[0] - else value + msg = ( + "Deprecated comma separated single-string IN query used instead of native list. " + f"Filter: field.config['title'] Field name: {field.field_name} Value: {value}" ) + L.warning(msg) + return value[0].split(",") return value @field_validator("order_by", check_fields=False)