diff --git a/stapi-fastapi/src/stapi_fastapi/backends/root_backend.py b/stapi-fastapi/src/stapi_fastapi/backends/root_backend.py index f13e5cc..95ee48f 100644 --- a/stapi-fastapi/src/stapi_fastapi/backends/root_backend.py +++ b/stapi-fastapi/src/stapi_fastapi/backends/root_backend.py @@ -11,9 +11,11 @@ OrderStatus, ) +T = TypeVar("T", bound=OrderStatus) + GetOrders = Callable[ [str | None, int, Request], - Coroutine[Any, Any, ResultE[tuple[list[Order[OrderStatus]], Maybe[str], Maybe[int]]]], + Coroutine[Any, Any, ResultE[tuple[list[Order[T]], Maybe[str], Maybe[int]]]], ] """ Type alias for an async function that returns a list of existing Orders. @@ -48,9 +50,6 @@ """ -T = TypeVar("T", bound=OrderStatus) - - GetOrderStatuses = Callable[ [str, str | None, int, Request], Coroutine[Any, Any, ResultE[Maybe[tuple[list[T], Maybe[str]]]]], diff --git a/stapi-fastapi/src/stapi_fastapi/routers/product_router.py b/stapi-fastapi/src/stapi_fastapi/routers/product_router.py index 430ae00..8c47c3d 100644 --- a/stapi-fastapi/src/stapi_fastapi/routers/product_router.py +++ b/stapi-fastapi/src/stapi_fastapi/routers/product_router.py @@ -2,7 +2,7 @@ import logging import traceback -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar from fastapi import ( Depends, @@ -68,7 +68,10 @@ def get_prefer(prefer: str | None = Header(None)) -> str | None: return Prefer(prefer) -def build_conformances(product: Product, root_router: RootRouter) -> list[str]: +T = TypeVar("T", bound=OrderStatus) + + +def build_conformances(product: Product, root_router: RootRouter[T]) -> list[str]: # FIXME we can make this check more robust if not any(conformance.startswith("https://geojson.org/schema/") for conformance in product.conformsTo): raise ValueError("product conformance does not contain at least one geojson conformance") @@ -90,7 +93,7 @@ class ProductRouter(StapiFastapiBaseRouter): def __init__( # noqa self, product: Product, - root_router: RootRouter, + root_router: RootRouter[T], *args: Any, **kwargs: Any, ) -> None: diff --git a/stapi-fastapi/src/stapi_fastapi/routers/root_router.py b/stapi-fastapi/src/stapi_fastapi/routers/root_router.py index c33abc1..10bf3c2 100644 --- a/stapi-fastapi/src/stapi_fastapi/routers/root_router.py +++ b/stapi-fastapi/src/stapi_fastapi/routers/root_router.py @@ -1,6 +1,6 @@ import logging import traceback -from typing import Any +from typing import Any, Generic, TypeVar from fastapi import HTTPException, Request, status from fastapi.datastructures import URL @@ -50,11 +50,13 @@ logger = logging.getLogger(__name__) +T = TypeVar("T", bound=OrderStatus) -class RootRouter(StapiFastapiBaseRouter): + +class RootRouter(StapiFastapiBaseRouter, Generic[T]): def __init__( self, - get_orders: GetOrders, + get_orders: GetOrders[T], get_order: GetOrder, get_order_statuses: GetOrderStatuses | None = None, # type: ignore get_opportunity_search_records: GetOpportunitySearchRecords | None = None, @@ -240,7 +242,7 @@ def get_products(self, request: Request, next: str | None = None, limit: int = 1 async def get_orders( # noqa: C901 self, request: Request, next: str | None = None, limit: int = 10 - ) -> OrderCollection[OrderStatus]: + ) -> OrderCollection[T]: links: list[Link] = [] orders_count: int | None = None match await self._get_orders(next, limit, request): @@ -271,13 +273,13 @@ async def get_orders( # noqa: C901 case _: raise AssertionError("Expected code to be unreachable") - return OrderCollection( + return OrderCollection[T]( features=orders, links=links, number_matched=orders_count, ) - async def get_order(self, order_id: str, request: Request) -> Order[OrderStatus]: + async def get_order(self, order_id: str, request: Request) -> Order[T]: """ Get details for order with `order_id`. """ @@ -306,7 +308,7 @@ async def get_order_statuses( request: Request, next: str | None = None, limit: int = 10, - ) -> OrderStatuses: # type: ignore + ) -> OrderStatuses[T]: links: list[Link] = [] match await self._get_order_statuses(order_id, next, limit, request): case Success(Some((statuses, maybe_pagination_token))): @@ -350,7 +352,7 @@ def generate_order_href(self, request: Request, order_id: str) -> URL: def generate_order_statuses_href(self, request: Request, order_id: str) -> URL: return self.url_for(request, f"{self.name}:{LIST_ORDER_STATUSES}", order_id=order_id) - def order_links(self, order: Order[OrderStatus], request: Request) -> list[Link]: + def order_links(self, order: Order[T], request: Request) -> list[Link]: return [ Link( href=self.generate_order_href(request, order.id), @@ -464,7 +466,7 @@ def opportunity_search_record_self_link( return json_link("self", self.generate_opportunity_search_record_href(request, opportunity_search_record.id)) @property - def _get_order_statuses(self) -> GetOrderStatuses: # type: ignore + def _get_order_statuses(self) -> GetOrderStatuses[T]: if not self.__get_order_statuses: raise AttributeError("Root router does not support order status history") return self.__get_order_statuses