Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 87 additions & 58 deletions imednet/endpoints/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,16 @@ def _execute_sync_list(
self._update_local_cache(result, study, has_filters, cache)
return result

def _list_impl(
def _list_sync(
self,
client: Client | AsyncClient,
paginator_cls: type[Paginator] | type[AsyncPaginator],
client: Client,
paginator_cls: type[Paginator],
*,
study_key: Optional[str] = None,
refresh: bool = False,
extra_params: Optional[Dict[str, Any]] = None,
**filters: Any,
) -> List[T] | Awaitable[List[T]]:
) -> List[T]:

study, cache, params, other_filters = self._prepare_list_params(
study_key, refresh, extra_params, filters
Expand All @@ -206,91 +206,120 @@ def _list_impl(
paginator = paginator_cls(client, path, params=params, page_size=self.PAGE_SIZE)
parse_func = self._resolve_parse_func()

if hasattr(paginator, "__aiter__"):
return self._execute_async_list(
cast(AsyncPaginator, paginator), parse_func, study, bool(other_filters), cache
)
return self._execute_sync_list(paginator, parse_func, study, bool(other_filters), cache)

return self._execute_sync_list(
cast(Paginator, paginator), parse_func, study, bool(other_filters), cache
async def _list_async(
self,
client: AsyncClient,
paginator_cls: type[AsyncPaginator],
*,
study_key: Optional[str] = None,
refresh: bool = False,
extra_params: Optional[Dict[str, Any]] = None,
**filters: Any,
) -> List[T]:

study, cache, params, other_filters = self._prepare_list_params(
study_key, refresh, extra_params, filters
)

def _get_impl(
# Cache Hit Check
if self.requires_study_key:
if not study:
# Should have been caught in _prepare_list_params but strict typing requires check
raise ValueError("Study key must be provided or set in the context")
if cache is not None and not other_filters and not refresh and study in cache:
return cast(List[T], cache[study])
else:
if cache is not None and not other_filters and not refresh:
return cast(List[T], cache)

path = self._get_path(study)
paginator = paginator_cls(client, path, params=params, page_size=self.PAGE_SIZE)
parse_func = self._resolve_parse_func()

return await self._execute_async_list(
paginator, parse_func, study, bool(other_filters), cache
)

def _get_sync(
self,
client: Client | AsyncClient,
paginator_cls: type[Paginator] | type[AsyncPaginator],
client: Client,
paginator_cls: type[Paginator],
*,
study_key: Optional[str],
item_id: Any,
) -> T | Awaitable[T]:
) -> T:
filters = {self._id_param: item_id}
result = self._list_impl(
result = self._list_sync(
client,
paginator_cls,
study_key=study_key,
refresh=True,
**filters,
)

if inspect.isawaitable(result):

async def _await() -> T:
items = await result
if not items:
if self.requires_study_key:
raise ValueError(
f"{self.MODEL.__name__} {item_id} not found in study {study_key}"
)
raise ValueError(f"{self.MODEL.__name__} {item_id} not found")
return items[0]
if not result:
if self.requires_study_key:
raise ValueError(f"{self.MODEL.__name__} {item_id} not found in study {study_key}")
raise ValueError(f"{self.MODEL.__name__} {item_id} not found")
return result[0]

return _await()
async def _get_async(
self,
client: AsyncClient,
paginator_cls: type[AsyncPaginator],
*,
study_key: Optional[str],
item_id: Any,
) -> T:
filters = {self._id_param: item_id}
result = await self._list_async(
client,
paginator_cls,
study_key=study_key,
refresh=True,
**filters,
)

# Sync path
items = cast(List[T], result)
if not items:
if not result:
if self.requires_study_key:
raise ValueError(f"{self.MODEL.__name__} {item_id} not found in study {study_key}")
raise ValueError(f"{self.MODEL.__name__} {item_id} not found")
return items[0]
return result[0]


class ListGetEndpoint(BaseEndpoint, ListGetEndpointMixin[T]):
"""Endpoint base class implementing ``list`` and ``get`` helpers."""

def _get_context(
self, is_async: bool
) -> tuple[Client | AsyncClient, type[Paginator] | type[AsyncPaginator]]:
if is_async:
return self._require_async_client(), AsyncPaginator
return self._client, Paginator

def _list_common(self, is_async: bool, **kwargs: Any) -> List[T] | Awaitable[List[T]]:
client, paginator = self._get_context(is_async)
return self._list_impl(client, paginator, **kwargs)

def _get_common(
self,
is_async: bool,
*,
study_key: Optional[str],
item_id: Any,
) -> T | Awaitable[T]:
client, paginator = self._get_context(is_async)
return self._get_impl(client, paginator, study_key=study_key, item_id=item_id)

def list(self, study_key: Optional[str] = None, **filters: Any) -> List[T]:
return cast(List[T], self._list_common(False, study_key=study_key, **filters))
return self._list_sync(
self._client,
Paginator,
study_key=study_key,
**filters,
)

async def async_list(self, study_key: Optional[str] = None, **filters: Any) -> List[T]:
return await cast(
Awaitable[List[T]], self._list_common(True, study_key=study_key, **filters)
return await self._list_async(
self._require_async_client(),
AsyncPaginator,
study_key=study_key,
**filters,
)

def get(self, study_key: Optional[str], item_id: Any) -> T:
return cast(T, self._get_common(False, study_key=study_key, item_id=item_id))
return self._get_sync(
self._client,
Paginator,
study_key=study_key,
item_id=item_id,
)

async def async_get(self, study_key: Optional[str], item_id: Any) -> T:
return await cast(
Awaitable[T], self._get_common(True, study_key=study_key, item_id=item_id)
return await self._get_async(
self._require_async_client(),
AsyncPaginator,
study_key=study_key,
item_id=item_id,
)
29 changes: 13 additions & 16 deletions imednet/endpoints/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,17 @@ async def async_create(
response = await client.post(path, json=records_data, headers=headers)
return Job.from_json(response.json())

def _list_impl(
def _prepare_list_params(
self,
client: Any,
paginator_cls: type[Any],
*,
study_key: Optional[str] = None,
record_data_filter: Optional[str] = None,
**filters: Any,
) -> Any:
extra = {"recordDataFilter": record_data_filter} if record_data_filter else None
return super()._list_impl(
client,
paginator_cls,
study_key=study_key,
extra_params=extra,
**filters,
)
study_key: Optional[str],
refresh: bool,
extra_params: Optional[Dict[str, Any]],
filters: Dict[str, Any],
) -> tuple[Optional[str], Any, Dict[str, Any], Dict[str, Any]]:
record_data_filter = filters.pop("record_data_filter", None)

if record_data_filter:
extra_params = extra_params or {}
extra_params["recordDataFilter"] = record_data_filter

return super()._prepare_list_params(study_key, refresh, extra_params, filters)
38 changes: 13 additions & 25 deletions imednet/endpoints/users.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
"""Endpoint for managing users in a study."""

from typing import Any, Awaitable, Dict, List, Optional, Union
from typing import Any, Dict, Optional

from imednet.core.async_client import AsyncClient
from imednet.core.client import Client
from imednet.core.paginator import AsyncPaginator, Paginator
from imednet.endpoints._mixins import ListGetEndpoint
from imednet.models.users import User

Expand All @@ -21,25 +18,16 @@ class UsersEndpoint(ListGetEndpoint[User]):
_id_param = "userId"
_pop_study_filter = True

def _list_impl(
def _prepare_list_params(
self,
client: Client | AsyncClient,
paginator_cls: Union[type[Paginator], type[AsyncPaginator]],
*,
study_key: Optional[str] = None,
refresh: bool = False,
extra_params: Optional[Dict[str, Any]] = None,
include_inactive: bool = False,
**filters: Any,
) -> List[User] | Awaitable[List[User]]:
params = extra_params or {}
params["includeInactive"] = str(include_inactive).lower()

return super()._list_impl(
client,
paginator_cls,
study_key=study_key,
refresh=refresh,
extra_params=params,
**filters,
)
study_key: Optional[str],
refresh: bool,
extra_params: Optional[Dict[str, Any]],
filters: Dict[str, Any],
) -> tuple[Optional[str], Any, Dict[str, Any], Dict[str, Any]]:
include_inactive = filters.pop("include_inactive", False)

extra_params = extra_params or {}
extra_params["includeInactive"] = str(include_inactive).lower()

return super()._prepare_list_params(study_key, refresh, extra_params, filters)
2 changes: 1 addition & 1 deletion tests/unit/endpoints/test_codings_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_get_not_found(monkeypatch, dummy_client, context):
def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(codings.CodingsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(codings.CodingsEndpoint, "_list_sync", fake_impl)

with pytest.raises(ValueError):
ep.get("S1", "x")
4 changes: 2 additions & 2 deletions tests/unit/endpoints/test_endpoints_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ async def fake_impl(self, client, paginator, *, study_key=None, **filters):
called["filters"] = filters
return [Record(record_id=1)]

monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(records.RecordsEndpoint, "_list_async", fake_impl)

rec = await ep.async_get("S1", 1)

Expand All @@ -181,7 +181,7 @@ async def test_async_get_record_not_found(monkeypatch, dummy_client, context, re
async def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(records.RecordsEndpoint, "_list_async", fake_impl)

with pytest.raises(ValueError):
await ep.async_get("S1", 1)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/endpoints/test_forms_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filte
called["filters"] = filters
return [Form(form_id=1)]

monkeypatch.setattr(forms.FormsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(forms.FormsEndpoint, "_list_sync", fake_impl)

res = ep.get("S1", 1)

Expand All @@ -48,7 +48,7 @@ def test_get_not_found(monkeypatch, dummy_client, context):
def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(forms.FormsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(forms.FormsEndpoint, "_list_sync", fake_impl)

with pytest.raises(ValueError):
ep.get("S1", 1)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/endpoints/test_intervals_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_get_not_found(monkeypatch, dummy_client, context):
def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(intervals.IntervalsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(intervals.IntervalsEndpoint, "_list_sync", fake_impl)

with pytest.raises(ValueError):
ep.get("S1", 1)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/endpoints/test_queries_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_get_not_found(monkeypatch, dummy_client, context):
def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(queries.QueriesEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(queries.QueriesEndpoint, "_list_sync", fake_impl)

with pytest.raises(ValueError):
ep.get("S1", 1)
2 changes: 1 addition & 1 deletion tests/unit/endpoints/test_record_revisions_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_get_not_found(monkeypatch, dummy_client, context):
def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(record_revisions.RecordRevisionsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(record_revisions.RecordRevisionsEndpoint, "_list_sync", fake_impl)

with pytest.raises(ValueError):
ep.get("S1", 1)
4 changes: 2 additions & 2 deletions tests/unit/endpoints/test_records_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def fake_impl(self, client, paginator, *, study_key=None, **filters):
called["filters"] = filters
return [Record(record_id=1)]

monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(records.RecordsEndpoint, "_list_sync", fake_impl)

res = ep.get("S1", 1)

Expand All @@ -49,7 +49,7 @@ def test_get_not_found(monkeypatch, dummy_client, context):
def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(records.RecordsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(records.RecordsEndpoint, "_list_sync", fake_impl)

with pytest.raises(ValueError):
ep.get("S1", 1)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/endpoints/test_sites_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_get_not_found(monkeypatch, dummy_client, context):
def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(sites.SitesEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(sites.SitesEndpoint, "_list_sync", fake_impl)

with pytest.raises(ValueError):
ep.get("S1", 1)
2 changes: 1 addition & 1 deletion tests/unit/endpoints/test_subjects_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_get_not_found(monkeypatch, dummy_client, context):
def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(subjects.SubjectsEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(subjects.SubjectsEndpoint, "_list_sync", fake_impl)

with pytest.raises(ValueError):
ep.get("S1", "X")
2 changes: 1 addition & 1 deletion tests/unit/endpoints/test_users_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def test_get_not_found(monkeypatch, dummy_client, context):
def fake_impl(self, client, paginator, *, study_key=None, refresh=False, **filters):
return []

monkeypatch.setattr(users.UsersEndpoint, "_list_impl", fake_impl)
monkeypatch.setattr(users.UsersEndpoint, "_list_sync", fake_impl)

with pytest.raises(ValueError):
ep.get("S1", 1)
Loading
Loading