Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions stapi-fastapi/src/stapi_fastapi/backends/root_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
OrderStatus,
)

T = TypeVar("T", bound=OrderStatus)

GetOrders = Callable[
[str | None, int, Request],
Coroutine[Any, Any, ResultE[tuple[list[Order[OrderStatus]], Maybe[str], Maybe[int]]]],
Coroutine[Any, Any, ResultE[tuple[list[Order[T]], Maybe[str], Maybe[int]]]],
]
"""
Type alias for an async function that returns a list of existing Orders.
Expand Down Expand Up @@ -48,9 +50,6 @@
"""


T = TypeVar("T", bound=OrderStatus)


GetOrderStatuses = Callable[
[str, str | None, int, Request],
Coroutine[Any, Any, ResultE[Maybe[tuple[list[T], Maybe[str]]]]],
Expand Down
9 changes: 6 additions & 3 deletions stapi-fastapi/src/stapi_fastapi/routers/product_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import traceback
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeVar

from fastapi import (
Depends,
Expand Down Expand Up @@ -68,7 +68,10 @@ def get_prefer(prefer: str | None = Header(None)) -> str | None:
return Prefer(prefer)


def build_conformances(product: Product, root_router: RootRouter) -> list[str]:
T = TypeVar("T", bound=OrderStatus)


def build_conformances(product: Product, root_router: RootRouter[T]) -> list[str]:
# FIXME we can make this check more robust
if not any(conformance.startswith("https://geojson.org/schema/") for conformance in product.conformsTo):
raise ValueError("product conformance does not contain at least one geojson conformance")
Expand All @@ -90,7 +93,7 @@ class ProductRouter(StapiFastapiBaseRouter):
def __init__( # noqa
self,
product: Product,
root_router: RootRouter,
root_router: RootRouter[T],
*args: Any,
**kwargs: Any,
) -> None:
Expand Down
20 changes: 11 additions & 9 deletions stapi-fastapi/src/stapi_fastapi/routers/root_router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
import traceback
from typing import Any
from typing import Any, Generic, TypeVar

from fastapi import HTTPException, Request, status
from fastapi.datastructures import URL
Expand Down Expand Up @@ -50,11 +50,13 @@

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=OrderStatus)

class RootRouter(StapiFastapiBaseRouter):

class RootRouter(StapiFastapiBaseRouter, Generic[T]):
def __init__(
self,
get_orders: GetOrders,
get_orders: GetOrders[T],
get_order: GetOrder,
get_order_statuses: GetOrderStatuses | None = None, # type: ignore
get_opportunity_search_records: GetOpportunitySearchRecords | None = None,
Expand Down Expand Up @@ -240,7 +242,7 @@ def get_products(self, request: Request, next: str | None = None, limit: int = 1

async def get_orders( # noqa: C901
self, request: Request, next: str | None = None, limit: int = 10
) -> OrderCollection[OrderStatus]:
) -> OrderCollection[T]:
links: list[Link] = []
orders_count: int | None = None
match await self._get_orders(next, limit, request):
Expand Down Expand Up @@ -271,13 +273,13 @@ async def get_orders( # noqa: C901
case _:
raise AssertionError("Expected code to be unreachable")

return OrderCollection(
return OrderCollection[T](
features=orders,
links=links,
number_matched=orders_count,
)

async def get_order(self, order_id: str, request: Request) -> Order[OrderStatus]:
async def get_order(self, order_id: str, request: Request) -> Order[T]:
"""
Get details for order with `order_id`.
"""
Expand Down Expand Up @@ -306,7 +308,7 @@ async def get_order_statuses(
request: Request,
next: str | None = None,
limit: int = 10,
) -> OrderStatuses: # type: ignore
) -> OrderStatuses[T]:
links: list[Link] = []
match await self._get_order_statuses(order_id, next, limit, request):
case Success(Some((statuses, maybe_pagination_token))):
Expand Down Expand Up @@ -350,7 +352,7 @@ def generate_order_href(self, request: Request, order_id: str) -> URL:
def generate_order_statuses_href(self, request: Request, order_id: str) -> URL:
return self.url_for(request, f"{self.name}:{LIST_ORDER_STATUSES}", order_id=order_id)

def order_links(self, order: Order[OrderStatus], request: Request) -> list[Link]:
def order_links(self, order: Order[T], request: Request) -> list[Link]:
return [
Link(
href=self.generate_order_href(request, order.id),
Expand Down Expand Up @@ -464,7 +466,7 @@ def opportunity_search_record_self_link(
return json_link("self", self.generate_opportunity_search_record_href(request, opportunity_search_record.id))

@property
def _get_order_statuses(self) -> GetOrderStatuses: # type: ignore
def _get_order_statuses(self) -> GetOrderStatuses[T]:
if not self.__get_order_statuses:
raise AttributeError("Root router does not support order status history")
return self.__get_order_statuses
Expand Down
Loading