diff --git a/app/dependencies/filter.py b/app/dependencies/filter.py new file mode 100644 index 000000000..66acc4675 --- /dev/null +++ b/app/dependencies/filter.py @@ -0,0 +1,55 @@ +from copy import deepcopy +from typing import Any, get_args, get_origin + +from fastapi import Depends, Query, params +from fastapi.exceptions import RequestValidationError +from fastapi_filter.base.filter import BaseFilterModel +from pydantic import ValidationError, create_model + + +def _list_to_query_fields(filter_model: type[BaseFilterModel]): + fields = {} + for name, f in filter_model.model_fields.items(): + field_info = deepcopy(f) + annotation = f.annotation + + if ( + annotation is list + or get_origin(annotation) is list + or any(get_origin(a) is list for a in get_args(annotation)) + ) and type(field_info.default) is not params.Query: + field_info.default = Query(default=field_info.default) + + fields[name] = (f.annotation, field_info) + + return fields + + +def FilterDepends(filter_model: type[BaseFilterModel], *, by_alias: bool = False, **_) -> Any: + """Use a hack to treat lists as query parameters. + + What we do is loop through the fields of a filter and assign any `list` field a default value of + `Query` so that FastAPI knows it should be treated a query parameter and not body. + + When we apply the filter, we build the original filter to properly validate the data (i.e. can + the string be parsed and formatted as a list of ?) + """ + fields = _list_to_query_fields(filter_model) + GeneratedFilter = create_model(filter_model.__class__.__name__, **fields) # noqa: N806 + + class FilterWrapper(GeneratedFilter): # type: ignore[misc,valid-type] + def __new__(cls, *args, **kwargs): + try: + instance = GeneratedFilter(*args, **kwargs) + data = instance.model_dump( + exclude_unset=True, exclude_defaults=True, by_alias=by_alias + ) + if original_filter := getattr(filter_model.Constants, "original_filter", None): + prefix = f"{filter_model.Constants.prefix}__" + stripped = {k.removeprefix(prefix): v for k, v in data.items()} + return original_filter(**stripped) + return filter_model(**data) + except ValidationError as e: + raise RequestValidationError(e.errors()) from e + + return Depends(FilterWrapper) diff --git a/app/filters/activity.py b/app/filters/activity.py index cf4806e86..8b039a896 100644 --- a/app/filters/activity.py +++ b/app/filters/activity.py @@ -1,8 +1,9 @@ from datetime import datetime from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix +from app.dependencies.filter import FilterDepends from app.filters.common import ( CreationFilterMixin, CreatorFilterMixin, diff --git a/app/filters/base.py b/app/filters/base.py index a7747ee06..d32744fcd 100644 --- a/app/filters/base.py +++ b/app/filters/base.py @@ -8,6 +8,7 @@ from sqlalchemy.orm import DeclarativeBase from app.db.model import Identifiable +from app.logger import L Aliases = dict[type[Identifiable], type[Identifiable] | dict[str, type[Identifiable]]] @@ -18,6 +19,33 @@ class CustomFilter[T: DeclarativeBase](Filter): class Constants(Filter.Constants): ordering_model_fields: list[str] + @field_validator("*", mode="before") + @classmethod + def split_str(cls, value, field): # pyright: ignore reportIncompatibleMethodOverride + """Prevent splitting field logic from parent class.""" + # backwards compatibility by splitting only comma separated single list elements that do not + # have space directly after the comma. e.g "a,b,c" will be split but not 'a, b, c'. + if ( + field.field_name is not None # noqa: PLR0916 + and ( + field.field_name == cls.Constants.ordering_field_name + or field.field_name.endswith("__in") + or field.field_name.endswith("__not_in") + ) + and value + and len(value) == 1 + and isinstance(value[0], str) + and "," in value[0] + and ", " not in value[0] + ): + msg = ( + "Deprecated comma separated single-string IN query used instead of native list. " + f"Filter: field.config['title'] Field name: {field.field_name} Value: {value}" + ) + L.warning(msg) + return value[0].split(",") + return value + @field_validator("order_by", check_fields=False) @classmethod def restrict_sortable_fields(cls, value: list[str]): diff --git a/app/filters/brain_atlas.py b/app/filters/brain_atlas.py index 89d51a985..aca331356 100644 --- a/app/filters/brain_atlas.py +++ b/app/filters/brain_atlas.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import BrainAtlas, BrainAtlasRegion +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import IdFilterMixin, NameFilterMixin, SpeciesFilterMixin diff --git a/app/filters/brain_region.py b/app/filters/brain_region.py index 1b0728949..09879b204 100644 --- a/app/filters/brain_region.py +++ b/app/filters/brain_region.py @@ -2,10 +2,10 @@ from typing import Annotated import sqlalchemy as sa -from fastapi_filter import FilterDepends from sqlalchemy.orm import aliased from app.db.model import BrainRegion, BrainRegionHierarchy +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import NameFilterMixin diff --git a/app/filters/brain_region_hierarchy.py b/app/filters/brain_region_hierarchy.py index 33c19f209..0a6bafd68 100644 --- a/app/filters/brain_region_hierarchy.py +++ b/app/filters/brain_region_hierarchy.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import BrainRegionHierarchy +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import IdFilterMixin, NameFilterMixin diff --git a/app/filters/cell_composition.py b/app/filters/cell_composition.py index d9e443f0a..348e01e39 100644 --- a/app/filters/cell_composition.py +++ b/app/filters/cell_composition.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import CellComposition +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( EntityFilterMixin, diff --git a/app/filters/circuit.py b/app/filters/circuit.py index 44eb1afca..27156d350 100644 --- a/app/filters/circuit.py +++ b/app/filters/circuit.py @@ -2,9 +2,8 @@ from datetime import datetime from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import Circuit +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/common.py b/app/filters/common.py index 896d40b75..46538c431 100644 --- a/app/filters/common.py +++ b/app/filters/common.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import ( Agent, @@ -15,12 +15,17 @@ Strain, Subject, ) +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter class IdFilterMixin: id: uuid.UUID | None = None - id__in: list[uuid.UUID] | None = None + + # id__in needs to be a str for backwards compatibility when instead of a native list a comma + # separated string is provided, e.g. 'id1,id2' . With list[UUID] backwards compatibility would + # fail because of validation of the field which would be expected to be a UUID. + id__in: list[str] | None = None class NameFilterMixin: diff --git a/app/filters/density.py b/app/filters/density.py index 0d2a91efa..25fcb1953 100644 --- a/app/filters/density.py +++ b/app/filters/density.py @@ -1,12 +1,11 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import ( ExperimentalBoutonDensity, ExperimentalNeuronDensity, ExperimentalSynapsesPerConnection, ) +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilter, diff --git a/app/filters/electrical_cell_recording.py b/app/filters/electrical_cell_recording.py index e50694008..49405b007 100644 --- a/app/filters/electrical_cell_recording.py +++ b/app/filters/electrical_cell_recording.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import ElectricalCellRecording +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/emodel.py b/app/filters/emodel.py index edcd970e7..0e4422e24 100644 --- a/app/filters/emodel.py +++ b/app/filters/emodel.py @@ -1,8 +1,9 @@ from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import EModel +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/entity.py b/app/filters/entity.py index 066ac67aa..024864b2c 100644 --- a/app/filters/entity.py +++ b/app/filters/entity.py @@ -1,9 +1,8 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import Entity from app.db.types import EntityType +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter diff --git a/app/filters/ion_channel_model.py b/app/filters/ion_channel_model.py index 682072f80..310f166a7 100644 --- a/app/filters/ion_channel_model.py +++ b/app/filters/ion_channel_model.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import IonChannelModel +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/measurement_annotation.py b/app/filters/measurement_annotation.py index 97d4c4613..e55d46e6c 100644 --- a/app/filters/measurement_annotation.py +++ b/app/filters/measurement_annotation.py @@ -1,11 +1,12 @@ import uuid from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import MeasurementAnnotation, MeasurementItem, MeasurementKind from app.db.types import MeasurementStatistic, MeasurementUnit, StructuralDomain from app.db.utils import MeasurableEntityType +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import CreationFilterMixin diff --git a/app/filters/memodel.py b/app/filters/memodel.py index bdc0089d9..5594f1d40 100644 --- a/app/filters/memodel.py +++ b/app/filters/memodel.py @@ -1,8 +1,9 @@ from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import MEModel, ValidationStatus +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/memodel_calibration_result.py b/app/filters/memodel_calibration_result.py index 3fc629746..0358c0063 100644 --- a/app/filters/memodel_calibration_result.py +++ b/app/filters/memodel_calibration_result.py @@ -1,9 +1,8 @@ import uuid from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import MEModelCalibrationResult +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import EntityFilterMixin diff --git a/app/filters/morphology.py b/app/filters/morphology.py index 2853486e3..a4ffd947d 100644 --- a/app/filters/morphology.py +++ b/app/filters/morphology.py @@ -1,8 +1,9 @@ from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import ReconstructionMorphology +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/person.py b/app/filters/person.py index 04278d94b..7a82594fa 100644 --- a/app/filters/person.py +++ b/app/filters/person.py @@ -1,9 +1,8 @@ import uuid from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import Person +from app.dependencies.filter import FilterDepends from app.filters.common import AgentFilter, CreatorFilterMixin diff --git a/app/filters/simulation.py b/app/filters/simulation.py index e9fe2197b..805e6871f 100644 --- a/app/filters/simulation.py +++ b/app/filters/simulation.py @@ -1,9 +1,10 @@ import uuid from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import Simulation +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( ContributionFilterMixin, diff --git a/app/filters/simulation_campaign.py b/app/filters/simulation_campaign.py index c49298f8a..f8f93f1bc 100644 --- a/app/filters/simulation_campaign.py +++ b/app/filters/simulation_campaign.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import SimulationCampaign +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import EntityFilterMixin, NameFilterMixin from app.filters.simulation import NestedSimulationFilter, NestedSimulationFilterDep diff --git a/app/filters/simulation_execution.py b/app/filters/simulation_execution.py index aacfb39a7..9dc6f19a9 100644 --- a/app/filters/simulation_execution.py +++ b/app/filters/simulation_execution.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import SimulationExecution +from app.dependencies.filter import FilterDepends from app.filters.activity import ActivityFilterMixin from app.filters.base import CustomFilter diff --git a/app/filters/simulation_generation.py b/app/filters/simulation_generation.py index 1dddd40ba..bb6d3da72 100644 --- a/app/filters/simulation_generation.py +++ b/app/filters/simulation_generation.py @@ -1,8 +1,7 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import SimulationGeneration +from app.dependencies.filter import FilterDepends from app.filters.activity import ActivityFilterMixin from app.filters.base import CustomFilter diff --git a/app/filters/simulation_result.py b/app/filters/simulation_result.py index d9f478e24..66a1f563c 100644 --- a/app/filters/simulation_result.py +++ b/app/filters/simulation_result.py @@ -1,8 +1,9 @@ from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import SimulationResult +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( ContributionFilterMixin, diff --git a/app/filters/single_neuron_simulation.py b/app/filters/single_neuron_simulation.py index e441a3df1..d37ab3e14 100644 --- a/app/filters/single_neuron_simulation.py +++ b/app/filters/single_neuron_simulation.py @@ -1,9 +1,8 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import SingleNeuronSimulation from app.db.types import SingleNeuronSimulationStatus +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/single_neuron_synaptome.py b/app/filters/single_neuron_synaptome.py index e70c343bc..2335a01a7 100644 --- a/app/filters/single_neuron_synaptome.py +++ b/app/filters/single_neuron_synaptome.py @@ -1,8 +1,9 @@ from typing import Annotated -from fastapi_filter import FilterDepends, with_prefix +from fastapi_filter import with_prefix from app.db.model import SingleNeuronSynaptome +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/single_neuron_synaptome_simulation.py b/app/filters/single_neuron_synaptome_simulation.py index cf5e66694..a77cde781 100644 --- a/app/filters/single_neuron_synaptome_simulation.py +++ b/app/filters/single_neuron_synaptome_simulation.py @@ -1,9 +1,8 @@ from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import SingleNeuronSynaptomeSimulation from app.db.types import SingleNeuronSimulationStatus +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import ( BrainRegionFilterMixin, diff --git a/app/filters/validation_result.py b/app/filters/validation_result.py index 632fcd61e..86a5a33e2 100644 --- a/app/filters/validation_result.py +++ b/app/filters/validation_result.py @@ -1,9 +1,8 @@ import uuid from typing import Annotated -from fastapi_filter import FilterDepends - from app.db.model import ValidationResult +from app.dependencies.filter import FilterDepends from app.filters.base import CustomFilter from app.filters.common import EntityFilterMixin, NameFilterMixin diff --git a/tests/test_circuit.py b/tests/test_circuit.py index 7279efb96..5ee987ad9 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -181,7 +181,7 @@ def test_filtering(client, root_circuit, models): url=ROUTE, params={ "root_circuit_id": str(root_circuit.id), - "scale__in": "single,whole_brain", + "scale__in": ["single", "whole_brain"], }, ).json()["data"] assert len(data) == 4 diff --git a/tests/test_electrical_cell_recording.py b/tests/test_electrical_cell_recording.py index 8860bf46c..a7c32143c 100644 --- a/tests/test_electrical_cell_recording.py +++ b/tests/test_electrical_cell_recording.py @@ -211,7 +211,7 @@ def test_filtering(db, client, electrical_cell_recording_json_data, person_id): db, [ Subject( - name=f"my-subject-{i}", + name=f"subject-{i}", description="my-description", species_id=sp.id, strain_id=None, @@ -244,7 +244,6 @@ def test_filtering(db, client, electrical_cell_recording_json_data, person_id): for i, subject in enumerate(subjects) ], ) - data = assert_request(client.get, url=ROUTE).json()["data"] assert len(data) == len(models) @@ -257,3 +256,23 @@ def test_filtering(db, client, electrical_cell_recording_json_data, person_id): "data" ] assert len(data) == 2 + + data = assert_request(client.get, url=ROUTE, params="subject__species__name=species-2").json()[ + "data" + ] + assert len(data) == 2 + + data = assert_request( + client.get, + url=ROUTE, + params={"name__in": ["e-1", "e-2"]}, + ).json()["data"] + assert {d["name"] for d in data} == {"e-1", "e-2"} + + # backwards compat + data = assert_request( + client.get, + url=ROUTE, + params={"name__in": "e-1,e-2"}, + ).json()["data"] + assert {d["name"] for d in data} == {"e-1", "e-2"} diff --git a/tests/test_etype.py b/tests/test_etype.py index 0848e27ab..7801c57a4 100644 --- a/tests/test_etype.py +++ b/tests/test_etype.py @@ -45,12 +45,20 @@ def test_retrieve(db, client, person_id): assert data[0] == with_creation_fields(items[5]) # test filter (in) + + # backwards compat response = client.get(ROUTE, params={"pref_label__in": "pref_label_5,pref_label_6"}) assert response.status_code == 200 data = response.json()["data"] assert len(data) == 2 assert data == [with_creation_fields(items[5]), with_creation_fields(items[6])] + response = client.get(ROUTE, params={"pref_label__in": ["pref_label_5", "pref_label_6"]}) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + assert data == [with_creation_fields(items[5]), with_creation_fields(items[6])] + response = client.get(f"{ROUTE}/{etypes[0].id}") assert response.status_code == 200 data = response.json() diff --git a/tests/test_morphology.py b/tests/test_morphology.py index b96e5d5d7..48f84752a 100644 --- a/tests/test_morphology.py +++ b/tests/test_morphology.py @@ -582,6 +582,8 @@ def test_filter_by_id__in(db, client, brain_region_id, person_id): # filtering by multiple IDs selected_ids = [morphology_ids[1], morphology_ids[3]] + + # backwards compat response = client.get(ROUTE, params={"id__in": ",".join(selected_ids)}) assert response.status_code == 200 data = response.json()["data"] @@ -589,7 +591,16 @@ def test_filter_by_id__in(db, client, brain_region_id, person_id): returned_ids = [item["id"] for item in data] assert set(returned_ids) == set(selected_ids) + response = client.get(ROUTE, params={"id__in": selected_ids}) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + returned_ids = [item["id"] for item in data] + assert set(returned_ids) == set(selected_ids) + # filtering by all IDs + + # backwards compat response = client.get(ROUTE, params={"id__in": ",".join(morphology_ids)}) assert response.status_code == 200 data = response.json()["data"] @@ -597,6 +608,13 @@ def test_filter_by_id__in(db, client, brain_region_id, person_id): returned_ids = [item["id"] for item in data] assert set(returned_ids) == set(morphology_ids) + response = client.get(ROUTE, params={"id__in": morphology_ids}) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 5 + returned_ids = [item["id"] for item in data] + assert set(returned_ids) == set(morphology_ids) + # filtering by non-existent ID response = client.get(ROUTE, params={"id__in": MISSING_ID}) assert response.status_code == 200 @@ -604,6 +622,8 @@ def test_filter_by_id__in(db, client, brain_region_id, person_id): assert len(data) == 0 # combining id__in with other filters + + # backwards compat response = client.get( ROUTE, params={"id__in": ",".join(morphology_ids), "name__ilike": "%Filter Test Morphology 2%"}, @@ -613,6 +633,15 @@ def test_filter_by_id__in(db, client, brain_region_id, person_id): assert len(data) == 1 assert data[0]["id"] == morphology_ids[2] + response = client.get( + ROUTE, + params={"id__in": morphology_ids, "name__ilike": "%Filter Test Morphology 2%"}, + ) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 1 + assert data[0]["id"] == morphology_ids[2] + def test_brain_region_filter(db, client, brain_region_hierarchy_id, species_id, person_id): def create_model_function(_db, name, brain_region_id): diff --git a/tests/test_mtype.py b/tests/test_mtype.py index 01d3f8a1b..180262089 100644 --- a/tests/test_mtype.py +++ b/tests/test_mtype.py @@ -46,12 +46,19 @@ def test_retrieve(db, client, person_id): assert data[0] == with_creation_fields(items[5]) # test filter (in) + # backwards compat response = client.get(ROUTE, params={"pref_label__in": "pref_label_5,pref_label_6"}) assert response.status_code == 200 data = response.json()["data"] assert len(data) == 2 assert data == [with_creation_fields(items[5]), with_creation_fields(items[6])] + response = client.get(ROUTE, params={"pref_label__in": ["pref_label_5", "pref_label_6"]}) + assert response.status_code == 200 + data = response.json()["data"] + assert len(data) == 2 + assert data == [with_creation_fields(items[5]), with_creation_fields(items[6])] + response = client.get(f"{ROUTE}/{mtypes[0].id}") assert response.status_code == 200 data = response.json() diff --git a/tests/test_simulation_execution.py b/tests/test_simulation_execution.py index 174bb03d7..c73896c87 100644 --- a/tests/test_simulation_execution.py +++ b/tests/test_simulation_execution.py @@ -247,11 +247,20 @@ def test_filtering(client, models, root_circuit, simulation_result): ).json()["data"] assert len(data) == 2 + # backwards compat data = assert_request( client.get, url=ROUTE, params={"used__id__in": f"{root_circuit.id},{simulation_result.id}"} ).json()["data"] assert len(data) == 5 + data = assert_request( + client.get, + url=ROUTE, + params={"used__id__in": [str(root_circuit.id), str(simulation_result.id)]}, + ).json()["data"] + assert len(data) == 5 + + # backwards compat data = assert_request( client.get, url=ROUTE, @@ -259,6 +268,13 @@ def test_filtering(client, models, root_circuit, simulation_result): ).json()["data"] assert len(data) == 4 + data = assert_request( + client.get, + url=ROUTE, + params={"generated__id__in": [str(root_circuit.id), str(simulation_result.id)]}, + ).json()["data"] + assert len(data) == 4 + def test_delete_one(db, client, models): # sanity check diff --git a/tests/test_simulation_generation.py b/tests/test_simulation_generation.py index f0a40ef66..5fe8bcd81 100644 --- a/tests/test_simulation_generation.py +++ b/tests/test_simulation_generation.py @@ -226,11 +226,20 @@ def test_filtering(client, models, root_circuit, simulation_result): ).json()["data"] assert len(data) == 2 + # backwards compat data = assert_request( client.get, url=ROUTE, params={"used__id__in": f"{root_circuit.id},{simulation_result.id}"} ).json()["data"] assert len(data) == 5 + data = assert_request( + client.get, + url=ROUTE, + params={"used__id__in": [str(root_circuit.id), str(simulation_result.id)]}, + ).json()["data"] + assert len(data) == 5 + + # backwards compat data = assert_request( client.get, url=ROUTE, @@ -238,6 +247,13 @@ def test_filtering(client, models, root_circuit, simulation_result): ).json()["data"] assert len(data) == 4 + data = assert_request( + client.get, + url=ROUTE, + params={"generated__id__in": [str(root_circuit.id), str(simulation_result.id)]}, + ).json()["data"] + assert len(data) == 4 + def test_delete_one(db, client, models): # sanity check