From 7fe870611ec03a092d42f4e0370b6b449961d1de Mon Sep 17 00:00:00 2001 From: fderuiter <127706008+fderuiter@users.noreply.github.com> Date: Wed, 28 Jan 2026 18:27:13 +0000 Subject: [PATCH] Refactor ListGetEndpointMixin to separate sync/async logic Splits `_list_impl` and `_get_impl` into distinct `_list_sync`/`_list_async` and `_get_sync`/`_get_async` methods to enforce strict type safety and remove runtime `inspect` checks. Updates `UsersEndpoint` and `RecordsEndpoint` to use the new `_prepare_list_params` hook instead of overriding execution logic. Refactors unit tests to mock the specific sync/async methods instead of `_list_impl`. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> --- imednet/endpoints/_mixins.py | 145 +++++++++++------- imednet/endpoints/records.py | 29 ++-- imednet/endpoints/users.py | 38 ++--- tests/unit/endpoints/test_codings_endpoint.py | 2 +- tests/unit/endpoints/test_endpoints_async.py | 4 +- tests/unit/endpoints/test_forms_endpoint.py | 4 +- .../unit/endpoints/test_intervals_endpoint.py | 2 +- tests/unit/endpoints/test_queries_endpoint.py | 2 +- .../test_record_revisions_endpoint.py | 2 +- tests/unit/endpoints/test_records_endpoint.py | 4 +- tests/unit/endpoints/test_sites_endpoint.py | 2 +- .../unit/endpoints/test_subjects_endpoint.py | 2 +- tests/unit/endpoints/test_users_endpoint.py | 2 +- .../unit/endpoints/test_variables_endpoint.py | 2 +- tests/unit/endpoints/test_visits_endpoint.py | 2 +- 15 files changed, 128 insertions(+), 114 deletions(-) diff --git a/imednet/endpoints/_mixins.py b/imednet/endpoints/_mixins.py index db21b655..026db4d6 100644 --- a/imednet/endpoints/_mixins.py +++ b/imednet/endpoints/_mixins.py @@ -176,16 +176,16 @@ def _execute_sync_list( self._update_local_cache(result, study, has_filters, cache) return result - def _list_impl( + def _list_sync( self, - client: Client | AsyncClient, - paginator_cls: type[Paginator] | type[AsyncPaginator], + client: Client, + paginator_cls: type[Paginator], *, study_key: Optional[str] = None, refresh: bool = False, extra_params: Optional[Dict[str, Any]] = None, **filters: Any, - ) -> List[T] | Awaitable[List[T]]: + ) -> List[T]: study, cache, params, other_filters = self._prepare_list_params( study_key, refresh, extra_params, filters @@ -206,25 +206,52 @@ def _list_impl( paginator = paginator_cls(client, path, params=params, page_size=self.PAGE_SIZE) parse_func = self._resolve_parse_func() - if hasattr(paginator, "__aiter__"): - return self._execute_async_list( - cast(AsyncPaginator, paginator), parse_func, study, bool(other_filters), cache - ) + return self._execute_sync_list(paginator, parse_func, study, bool(other_filters), cache) - return self._execute_sync_list( - cast(Paginator, paginator), parse_func, study, bool(other_filters), cache + async def _list_async( + self, + client: AsyncClient, + paginator_cls: type[AsyncPaginator], + *, + study_key: Optional[str] = None, + refresh: bool = False, + extra_params: Optional[Dict[str, Any]] = None, + **filters: Any, + ) -> List[T]: + + study, cache, params, other_filters = self._prepare_list_params( + study_key, refresh, extra_params, filters ) - def _get_impl( + # Cache Hit Check + if self.requires_study_key: + if not study: + # Should have been caught in _prepare_list_params but strict typing requires check + raise ValueError("Study key must be provided or set in the context") + if cache is not None and not other_filters and not refresh and study in cache: + return cast(List[T], cache[study]) + else: + if cache is not None and not other_filters and not refresh: + return cast(List[T], cache) + + path = self._get_path(study) + paginator = paginator_cls(client, path, params=params, page_size=self.PAGE_SIZE) + parse_func = self._resolve_parse_func() + + return await self._execute_async_list( + paginator, parse_func, study, bool(other_filters), cache + ) + + def _get_sync( self, - client: Client | AsyncClient, - paginator_cls: type[Paginator] | type[AsyncPaginator], + client: Client, + paginator_cls: type[Paginator], *, study_key: Optional[str], item_id: Any, - ) -> T | Awaitable[T]: + ) -> T: filters = {self._id_param: item_id} - result = self._list_impl( + result = self._list_sync( client, paginator_cls, study_key=study_key, @@ -232,65 +259,67 @@ def _get_impl( **filters, ) - if inspect.isawaitable(result): - - async def _await() -> T: - items = await result - if not items: - if self.requires_study_key: - raise ValueError( - f"{self.MODEL.__name__} {item_id} not found in study {study_key}" - ) - raise ValueError(f"{self.MODEL.__name__} {item_id} not found") - return items[0] + if not result: + if self.requires_study_key: + raise ValueError(f"{self.MODEL.__name__} {item_id} not found in study {study_key}") + raise ValueError(f"{self.MODEL.__name__} {item_id} not found") + return result[0] - return _await() + async def _get_async( + self, + client: AsyncClient, + paginator_cls: type[AsyncPaginator], + *, + study_key: Optional[str], + item_id: Any, + ) -> T: + filters = {self._id_param: item_id} + result = await self._list_async( + client, + paginator_cls, + study_key=study_key, + refresh=True, + **filters, + ) - # Sync path - items = cast(List[T], result) - if not items: + if not result: if self.requires_study_key: raise ValueError(f"{self.MODEL.__name__} {item_id} not found in study {study_key}") raise ValueError(f"{self.MODEL.__name__} {item_id} not found") - return items[0] + return result[0] class ListGetEndpoint(BaseEndpoint, ListGetEndpointMixin[T]): """Endpoint base class implementing ``list`` and ``get`` helpers.""" - def _get_context( - self, is_async: bool - ) -> tuple[Client | AsyncClient, type[Paginator] | type[AsyncPaginator]]: - if is_async: - return self._require_async_client(), AsyncPaginator - return self._client, Paginator - - def _list_common(self, is_async: bool, **kwargs: Any) -> List[T] | Awaitable[List[T]]: - client, paginator = self._get_context(is_async) - return self._list_impl(client, paginator, **kwargs) - - def _get_common( - self, - is_async: bool, - *, - study_key: Optional[str], - item_id: Any, - ) -> T | Awaitable[T]: - client, paginator = self._get_context(is_async) - return self._get_impl(client, paginator, study_key=study_key, item_id=item_id) - def list(self, study_key: Optional[str] = None, **filters: Any) -> List[T]: - return cast(List[T], self._list_common(False, study_key=study_key, **filters)) + return self._list_sync( + self._client, + Paginator, + study_key=study_key, + **filters, + ) async def async_list(self, study_key: Optional[str] = None, **filters: Any) -> List[T]: - return await cast( - Awaitable[List[T]], self._list_common(True, study_key=study_key, **filters) + return await self._list_async( + self._require_async_client(), + AsyncPaginator, + study_key=study_key, + **filters, ) def get(self, study_key: Optional[str], item_id: Any) -> T: - return cast(T, self._get_common(False, study_key=study_key, item_id=item_id)) + return self._get_sync( + self._client, + Paginator, + study_key=study_key, + item_id=item_id, + ) async def async_get(self, study_key: Optional[str], item_id: Any) -> T: - return await cast( - Awaitable[T], self._get_common(True, study_key=study_key, item_id=item_id) + return await self._get_async( + self._require_async_client(), + AsyncPaginator, + study_key=study_key, + item_id=item_id, ) diff --git a/imednet/endpoints/records.py b/imednet/endpoints/records.py index 8c445023..f74ae689 100644 --- a/imednet/endpoints/records.py +++ b/imednet/endpoints/records.py @@ -131,20 +131,17 @@ async def async_create( response = await client.post(path, json=records_data, headers=headers) return Job.from_json(response.json()) - def _list_impl( + def _prepare_list_params( self, - client: Any, - paginator_cls: type[Any], - *, - study_key: Optional[str] = None, - record_data_filter: Optional[str] = None, - **filters: Any, - ) -> Any: - extra = {"recordDataFilter": record_data_filter} if record_data_filter else None - return super()._list_impl( - client, - paginator_cls, - study_key=study_key, - extra_params=extra, - **filters, - ) + study_key: Optional[str], + refresh: bool, + extra_params: Optional[Dict[str, Any]], + filters: Dict[str, Any], + ) -> tuple[Optional[str], Any, Dict[str, Any], Dict[str, Any]]: + record_data_filter = filters.pop("record_data_filter", None) + + if record_data_filter: + extra_params = extra_params or {} + extra_params["recordDataFilter"] = record_data_filter + + return super()._prepare_list_params(study_key, refresh, extra_params, filters) diff --git a/imednet/endpoints/users.py b/imednet/endpoints/users.py index 58d35783..f7f44de1 100644 --- a/imednet/endpoints/users.py +++ b/imednet/endpoints/users.py @@ -1,10 +1,7 @@ """Endpoint for managing users in a study.""" -from typing import Any, Awaitable, Dict, List, Optional, Union +from typing import Any, Dict, Optional -from imednet.core.async_client import AsyncClient -from imednet.core.client import Client -from imednet.core.paginator import AsyncPaginator, Paginator from imednet.endpoints._mixins import ListGetEndpoint from imednet.models.users import User @@ -21,25 +18,16 @@ class UsersEndpoint(ListGetEndpoint[User]): _id_param = "userId" _pop_study_filter = True - def _list_impl( + def _prepare_list_params( self, - client: Client | AsyncClient, - paginator_cls: Union[type[Paginator], type[AsyncPaginator]], - *, - study_key: Optional[str] = None, - refresh: bool = False, - extra_params: Optional[Dict[str, Any]] = None, - include_inactive: bool = False, - **filters: Any, - ) -> List[User] | Awaitable[List[User]]: - params = extra_params or {} - params["includeInactive"] = str(include_inactive).lower() - - return super()._list_impl( - client, - paginator_cls, - study_key=study_key, - refresh=refresh, - extra_params=params, - **filters, - ) + study_key: Optional[str], + refresh: bool, + extra_params: Optional[Dict[str, Any]], + filters: Dict[str, Any], + ) -> tuple[Optional[str], Any, Dict[str, Any], Dict[str, Any]]: + include_inactive = filters.pop("include_inactive", False) + + extra_params = extra_params or {} + extra_params["includeInactive"] = str(include_inactive).lower() + + return super()._prepare_list_params(study_key, refresh, extra_params, filters) diff --git a/tests/unit/endpoints/test_codings_endpoint.py b/tests/unit/endpoints/test_codings_endpoint.py index 4b4d32fc..d65368c6 100644 --- a/tests/unit/endpoints/test_codings_endpoint.py +++ b/tests/unit/endpoints/test_codings_endpoint.py @@ -26,7 +26,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(codings.CodingsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(codings.CodingsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", "x") diff --git a/tests/unit/endpoints/test_endpoints_async.py b/tests/unit/endpoints/test_endpoints_async.py index 54cfd5eb..134a7ed4 100644 --- a/tests/unit/endpoints/test_endpoints_async.py +++ b/tests/unit/endpoints/test_endpoints_async.py @@ -166,7 +166,7 @@ async def fake_impl(self, client, paginator, *, study_key=None, **filters): called["filters"] = filters return [Record(record_id=1)] - monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(records.RecordsEndpoint, "_list_async", fake_impl) rec = await ep.async_get("S1", 1) @@ -181,7 +181,7 @@ async def test_async_get_record_not_found(monkeypatch, dummy_client, context, re async def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(records.RecordsEndpoint, "_list_async", fake_impl) with pytest.raises(ValueError): await ep.async_get("S1", 1) diff --git a/tests/unit/endpoints/test_forms_endpoint.py b/tests/unit/endpoints/test_forms_endpoint.py index d0b55156..b1da238f 100644 --- a/tests/unit/endpoints/test_forms_endpoint.py +++ b/tests/unit/endpoints/test_forms_endpoint.py @@ -34,7 +34,7 @@ def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filte called["filters"] = filters return [Form(form_id=1)] - monkeypatch.setattr(forms.FormsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(forms.FormsEndpoint, "_list_sync", fake_impl) res = ep.get("S1", 1) @@ -48,7 +48,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(forms.FormsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(forms.FormsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_intervals_endpoint.py b/tests/unit/endpoints/test_intervals_endpoint.py index 00b60906..02187e90 100644 --- a/tests/unit/endpoints/test_intervals_endpoint.py +++ b/tests/unit/endpoints/test_intervals_endpoint.py @@ -27,7 +27,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(intervals.IntervalsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(intervals.IntervalsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_queries_endpoint.py b/tests/unit/endpoints/test_queries_endpoint.py index 379a74e6..8873b991 100644 --- a/tests/unit/endpoints/test_queries_endpoint.py +++ b/tests/unit/endpoints/test_queries_endpoint.py @@ -24,7 +24,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(queries.QueriesEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(queries.QueriesEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_record_revisions_endpoint.py b/tests/unit/endpoints/test_record_revisions_endpoint.py index a5be9196..bc4a2eee 100644 --- a/tests/unit/endpoints/test_record_revisions_endpoint.py +++ b/tests/unit/endpoints/test_record_revisions_endpoint.py @@ -24,7 +24,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(record_revisions.RecordRevisionsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(record_revisions.RecordRevisionsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_records_endpoint.py b/tests/unit/endpoints/test_records_endpoint.py index fb1d7f8b..7f742d4b 100644 --- a/tests/unit/endpoints/test_records_endpoint.py +++ b/tests/unit/endpoints/test_records_endpoint.py @@ -35,7 +35,7 @@ def fake_impl(self, client, paginator, *, study_key=None, **filters): called["filters"] = filters return [Record(record_id=1)] - monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(records.RecordsEndpoint, "_list_sync", fake_impl) res = ep.get("S1", 1) @@ -49,7 +49,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(records.RecordsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_sites_endpoint.py b/tests/unit/endpoints/test_sites_endpoint.py index 386f3930..653a43b5 100644 --- a/tests/unit/endpoints/test_sites_endpoint.py +++ b/tests/unit/endpoints/test_sites_endpoint.py @@ -26,7 +26,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(sites.SitesEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(sites.SitesEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_subjects_endpoint.py b/tests/unit/endpoints/test_subjects_endpoint.py index b2513807..a45c122b 100644 --- a/tests/unit/endpoints/test_subjects_endpoint.py +++ b/tests/unit/endpoints/test_subjects_endpoint.py @@ -26,7 +26,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(subjects.SubjectsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(subjects.SubjectsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", "X") diff --git a/tests/unit/endpoints/test_users_endpoint.py b/tests/unit/endpoints/test_users_endpoint.py index 4c8e76e9..2dc00e22 100644 --- a/tests/unit/endpoints/test_users_endpoint.py +++ b/tests/unit/endpoints/test_users_endpoint.py @@ -24,7 +24,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(users.UsersEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(users.UsersEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_variables_endpoint.py b/tests/unit/endpoints/test_variables_endpoint.py index 7b65045f..0627737f 100644 --- a/tests/unit/endpoints/test_variables_endpoint.py +++ b/tests/unit/endpoints/test_variables_endpoint.py @@ -29,7 +29,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(variables.VariablesEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(variables.VariablesEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1) diff --git a/tests/unit/endpoints/test_visits_endpoint.py b/tests/unit/endpoints/test_visits_endpoint.py index bf3f8ccf..9caf8703 100644 --- a/tests/unit/endpoints/test_visits_endpoint.py +++ b/tests/unit/endpoints/test_visits_endpoint.py @@ -24,7 +24,7 @@ def test_get_not_found(monkeypatch, dummy_client, context): def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters): return [] - monkeypatch.setattr(visits.VisitsEndpoint, "_list_impl", fake_impl) + monkeypatch.setattr(visits.VisitsEndpoint, "_list_sync", fake_impl) with pytest.raises(ValueError): ep.get("S1", 1)