Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions app/api/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,12 @@ def get_bearer_token(
return credentials.credentials


def require_authentication(
async def require_authentication(
token: str = Depends(get_bearer_token),
client: AuthServiceClient = Depends(get_auth_service_client),
) -> AuthContext:
try:
claims = client.validate_token(token)
claims = await client.validate_token(token)
except AuthServiceError as exc:
headers = (
{"WWW-Authenticate": "Bearer"}
Expand All @@ -50,12 +50,12 @@ def require_authentication(


def require_authorization(permission: str):
def dependency(
async def dependency(
context: AuthContext = Depends(require_authentication),
client: AuthServiceClient = Depends(get_auth_service_client),
) -> AuthContext:
try:
client.authorize(context.token, permission)
await client.authorize(context.token, permission)
except AuthServiceError as exc:
headers = (
{"WWW-Authenticate": "Bearer"}
Expand Down
25 changes: 15 additions & 10 deletions app/api/routes/gateway.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from anyio import to_thread
from fastapi import APIRouter, Depends, status

from app.api.deps import require_authorization
Expand All @@ -13,35 +14,39 @@


@router.get("/routes", response_model=list[RouteConfig])
def list_routes(store: RouteStore = Depends(get_route_store)) -> list[RouteConfig]:
return store.list_routes()
async def list_routes(
store: RouteStore = Depends(get_route_store),
) -> list[RouteConfig]:
return await to_thread.run_sync(store.list_routes)


@router.get("/routes/{route_id}", response_model=RouteConfig)
def get_route(
async def get_route(
route_id: str, store: RouteStore = Depends(get_route_store)
) -> RouteConfig:
return store.get_route(route_id)
return await to_thread.run_sync(store.get_route, route_id)


@router.post(
"/routes",
response_model=RouteConfig,
status_code=status.HTTP_201_CREATED,
)
def create_route(
async def create_route(
payload: RouteCreate, store: RouteStore = Depends(get_route_store)
) -> RouteConfig:
return store.create_route(payload)
return await to_thread.run_sync(store.create_route, payload)


@router.put("/routes/{route_id}", response_model=RouteConfig)
def update_route(
async def update_route(
route_id: str, payload: RouteUpdate, store: RouteStore = Depends(get_route_store)
) -> RouteConfig:
return store.update_route(route_id, payload)
return await to_thread.run_sync(store.update_route, route_id, payload)


@router.delete("/routes/{route_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_route(route_id: str, store: RouteStore = Depends(get_route_store)) -> None:
store.delete_route(route_id)
async def delete_route(
route_id: str, store: RouteStore = Depends(get_route_store)
) -> None:
await to_thread.run_sync(store.delete_route, route_id)
2 changes: 1 addition & 1 deletion app/api/routes/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@


@router.get("/health")
def health_check() -> dict:
async def health_check() -> dict:
return {"status": "ok"}
11 changes: 6 additions & 5 deletions app/api/routes/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from urllib.parse import urlsplit, urlunsplit

from anyio import to_thread
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import Response
Expand Down Expand Up @@ -42,13 +43,13 @@ async def proxy_request(
auth_client: AuthServiceClient = Depends(get_auth_service_client),
) -> Response:
raw_path = f"/{path}" if path else "/"
route = store.match_route(request.method, raw_path)
route = await to_thread.run_sync(store.match_route, request.method, raw_path)
if not route:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="route_not_found"
)

_maybe_authorize(route, request, auth_client)
await _maybe_authorize(route, request, auth_client)

upstream_path = _normalize_path(route.rewrite_path(raw_path))
upstream_url = _build_upstream_url(route.upstream_base_url, upstream_path)
Expand Down Expand Up @@ -81,7 +82,7 @@ async def proxy_request(
)


def _maybe_authorize(
async def _maybe_authorize(
route: RouteConfig, request: Request, auth_client: AuthServiceClient
) -> dict:
auth_config = route.auth or RouteAuth()
Expand All @@ -91,9 +92,9 @@ def _maybe_authorize(

token = _extract_bearer_token(request)
try:
claims = auth_client.validate_token(token)
claims = await auth_client.validate_token(token)
if auth_config.permission:
auth_client.authorize(token, auth_config.permission)
await auth_client.authorize(token, auth_config.permission)
if auth_config.roles:
_ensure_roles(claims, auth_config.roles)
except AuthServiceError as exc:
Expand Down
16 changes: 8 additions & 8 deletions app/services/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,16 @@ def __init__(
timeout: float,
validate_path: str,
authorize_path: str,
transport: httpx.BaseTransport | None = None,
transport: httpx.AsyncBaseTransport | None = None,
) -> None:
self._base_url = base_url.rstrip("/")
self._timeout = timeout
self._validate_path = _normalize_path(validate_path)
self._authorize_path = _normalize_path(authorize_path)
self._transport = transport

def validate_token(self, token: str) -> dict[str, Any]:
response = self._request("POST", self._validate_path, token, payload=None)
async def validate_token(self, token: str) -> dict[str, Any]:
response = await self._request("POST", self._validate_path, token, payload=None)
if response.status_code == status.HTTP_200_OK:
return self._parse_claims(response)
if response.status_code in (
Expand All @@ -49,8 +49,8 @@ def validate_token(self, token: str) -> dict[str, Any]:
raise AuthServiceError(response.status_code, "invalid_token")
raise AuthServiceError(status.HTTP_502_BAD_GATEWAY, "auth_service_error")

def authorize(self, token: str, permission: str) -> None:
response = self._request(
async def authorize(self, token: str, permission: str) -> None:
response = await self._request(
"POST",
self._authorize_path,
token,
Expand All @@ -67,7 +67,7 @@ def authorize(self, token: str, permission: str) -> None:
raise AuthServiceError(status.HTTP_403_FORBIDDEN, "not_authorized")
raise AuthServiceError(status.HTTP_502_BAD_GATEWAY, "auth_service_error")

def _request(
async def _request(
self,
method: str,
path: str,
Expand All @@ -76,12 +76,12 @@ def _request(
) -> httpx.Response:
headers = {"Authorization": f"Bearer {token}"}
try:
with httpx.Client(
async with httpx.AsyncClient(
base_url=self._base_url,
timeout=self._timeout,
transport=self._transport,
) as client:
return client.request(method, path, headers=headers, json=payload)
return await client.request(method, path, headers=headers, json=payload)
except httpx.RequestError as exc:
raise AuthServiceError(
status.HTTP_503_SERVICE_UNAVAILABLE,
Expand Down
10 changes: 5 additions & 5 deletions app/tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ def _build_app(client: AuthServiceClient) -> FastAPI:
app.dependency_overrides[get_auth_service_client] = lambda: client

@app.get("/protected")
def protected(
async def protected(
context: AuthContext = Depends(require_authentication),
) -> dict[str, Any]:
return {"subject": context.claims.get("sub")}

@app.get("/admin")
def admin(
async def admin(
context: AuthContext = Depends(require_authorization("admin:read")),
) -> dict[str, Any]:
return {"ok": True}
Expand All @@ -39,7 +39,7 @@ def _json_payload(request: httpx.Request) -> dict[str, Any]:


def _mock_transport() -> httpx.MockTransport:
def handler(request: httpx.Request) -> httpx.Response:
async def handler(request: httpx.Request) -> httpx.Response:
auth_header = request.headers.get("Authorization", "")
token = auth_header.replace("Bearer ", "", 1)

Expand All @@ -64,7 +64,7 @@ def handler(request: httpx.Request) -> httpx.Response:
return httpx.MockTransport(handler)


def _auth_client(transport: httpx.BaseTransport) -> AuthServiceClient:
def _auth_client(transport: httpx.AsyncBaseTransport) -> AuthServiceClient:
return AuthServiceClient(
base_url="http://auth-service.local",
timeout=1.0,
Expand Down Expand Up @@ -122,7 +122,7 @@ def test_authorization_allows_admin() -> None:


def test_auth_service_unavailable_returns_503() -> None:
def handler(request: httpx.Request) -> httpx.Response:
async def handler(request: httpx.Request) -> httpx.Response:
raise httpx.ConnectError("boom", request=request)

app = _build_app(_auth_client(httpx.MockTransport(handler)))
Expand Down
4 changes: 2 additions & 2 deletions app/tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


def _auth_client() -> AuthServiceClient:
def handler(request: httpx.Request) -> httpx.Response:
async def handler(request: httpx.Request) -> httpx.Response:
auth_header = request.headers.get("Authorization", "")
token = auth_header.replace("Bearer ", "", 1)

Expand Down Expand Up @@ -42,7 +42,7 @@ def handler(request: httpx.Request) -> httpx.Response:


def _proxy_client(record: dict) -> ProxyClient:
def handler(request: httpx.Request) -> httpx.Response:
async def handler(request: httpx.Request) -> httpx.Response:
record["path"] = request.url.path
record["query"] = request.url.query
return httpx.Response(200, json={"proxied": True})
Expand Down
2 changes: 1 addition & 1 deletion app/tests/test_routes_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def _auth_client() -> AuthServiceClient:
def handler(request: httpx.Request) -> httpx.Response:
async def handler(request: httpx.Request) -> httpx.Response:
auth_header = request.headers.get("Authorization", "")
token = auth_header.replace("Bearer ", "", 1)

Expand Down