-
Notifications
You must be signed in to change notification settings - Fork 1
fix: use the refresh code for oauth token #88
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,9 +1,8 @@ | ||
| """Channel creation utilities for the Athena client.""" | ||
|
|
||
| import asyncio | ||
| import json | ||
| import threading | ||
| import time | ||
| from typing import override | ||
|
|
||
| import grpc | ||
| import httpx | ||
|
|
@@ -16,39 +15,6 @@ | |
| ) | ||
|
|
||
|
|
||
| class TokenMetadataPlugin(grpc.AuthMetadataPlugin): | ||
| """Plugin that adds authorization token to gRPC metadata.""" | ||
|
|
||
| def __init__(self, token: str) -> None: | ||
| """Initialize the plugin with the auth token. | ||
|
|
||
| Args: | ||
| ---- | ||
| token: The authorization token to add to requests | ||
|
|
||
| """ | ||
| self._token: str = token | ||
|
|
||
| @override | ||
| def __call__( | ||
| self, | ||
| _: grpc.AuthMetadataContext, | ||
| callback: grpc.AuthMetadataPluginCallback, | ||
| ) -> None: | ||
| """Pass authentication metadata to the provided callback. | ||
|
|
||
| This method will be invoked asynchronously in a separate thread. | ||
|
|
||
| Args: | ||
| ---- | ||
| callback: An AuthMetadataPluginCallback to be invoked either | ||
| synchronously or asynchronously. | ||
|
|
||
| """ | ||
| metadata = (("authorization", f"Token {self._token}"),) | ||
| callback(metadata, None) | ||
|
|
||
|
|
||
| class CredentialHelper: | ||
| """OAuth credential helper for managing authentication tokens.""" | ||
|
|
||
|
|
@@ -82,9 +48,9 @@ def __init__( | |
| self._audience: str = audience | ||
| self._token: str | None = None | ||
| self._token_expires_at: float | None = None | ||
| self._lock: asyncio.Lock = asyncio.Lock() | ||
| self._lock = threading.Lock() | ||
|
|
||
| async def get_token(self) -> str: | ||
| def get_token(self) -> str: | ||
| """Get a valid authentication token. | ||
|
|
||
| This method will return a cached token if it's still valid, | ||
|
|
@@ -97,20 +63,14 @@ async def get_token(self) -> str: | |
| Raises | ||
| ------ | ||
| OAuthError: If token acquisition fails | ||
| TokenExpiredError: If token has expired and refresh fails | ||
|
|
||
| """ | ||
| async with self._lock: | ||
| if self._is_token_valid(): | ||
| if self._token is None: | ||
| msg = "Token should be valid but is None" | ||
| raise RuntimeError(msg) | ||
| return self._token | ||
|
|
||
| await self._refresh_token() | ||
| if self._token is None: | ||
| msg = "Token refresh failed" | ||
| raise RuntimeError(msg) | ||
| if self._is_token_valid(): | ||
| return self._token | ||
|
|
||
| with self._lock: | ||
| if not self._is_token_valid(): | ||
| self._refresh_token() | ||
| return self._token | ||
|
Comment on lines
+68
to
74
|
||
|
|
||
| def _is_token_valid(self) -> bool: | ||
|
|
@@ -127,7 +87,7 @@ def _is_token_valid(self) -> bool: | |
| # Add 30 second buffer before expiration | ||
| return time.time() < (self._token_expires_at - 30) | ||
|
|
||
| async def _refresh_token(self) -> None: | ||
| def _refresh_token(self) -> None: | ||
| """Refresh the authentication token by making an OAuth request. | ||
|
|
||
| Raises | ||
|
|
@@ -145,21 +105,19 @@ async def _refresh_token(self) -> None: | |
| headers = {"content-type": "application/json"} | ||
|
|
||
| try: | ||
| async with httpx.AsyncClient() as client: | ||
| response = await client.post( | ||
| with httpx.Client() as client: | ||
| response = client.post( | ||
| self._auth_url, | ||
| json=payload, | ||
| headers=headers, | ||
| timeout=30.0, | ||
| ) | ||
| _ = response.raise_for_status() | ||
|
|
||
| token_data = response.json() | ||
| self._token = token_data["access_token"] | ||
| expires_in = token_data.get( | ||
| "expires_in", 3600 | ||
| ) # Default 1 hour | ||
| self._token_expires_at = time.time() + expires_in | ||
| token_data = response.json() | ||
| self._token = token_data["access_token"] | ||
| expires_in = token_data.get("expires_in", 3600) # Default 1 hour | ||
| self._token_expires_at = time.time() + expires_in | ||
|
|
||
| except httpx.HTTPStatusError as e: | ||
| error_detail = "" | ||
|
|
@@ -190,13 +148,29 @@ async def _refresh_token(self) -> None: | |
| msg = f"Unexpected error during OAuth: {e}" | ||
| raise OAuthError(msg) from e | ||
|
|
||
| async def invalidate_token(self) -> None: | ||
| def invalidate_token(self) -> None: | ||
| """Invalidate the current token to force a refresh on next use.""" | ||
| async with self._lock: | ||
| with self._lock: | ||
| self._token = None | ||
| self._token_expires_at = None | ||
|
Comment on lines
+153
to
155
|
||
|
|
||
|
|
||
| class _AutoRefreshTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): | ||
|
||
| def __init__(self, credential_helper: CredentialHelper) -> None: | ||
| self._credential_helper = credential_helper | ||
|
|
||
| def __call__( | ||
| self, | ||
| _: grpc.AuthMetadataContext, | ||
| callback: grpc.AuthMetadataPluginCallback, | ||
| ) -> None: | ||
| try: | ||
| token = self._credential_helper.get_token() | ||
| metadata = (("authorization", f"Bearer {token}"),) | ||
| callback(metadata, None) | ||
| except OAuthError as err: | ||
| callback(None, err) | ||
snus-kin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| async def create_channel_with_credentials( | ||
|
||
| host: str, | ||
| credential_helper: CredentialHelper, | ||
|
|
@@ -221,13 +195,10 @@ async def create_channel_with_credentials( | |
| if not host: | ||
| raise InvalidHostError(InvalidHostError.default_message) | ||
|
|
||
| # Get a valid token from the credential helper | ||
| token = await credential_helper.get_token() | ||
|
|
||
| # Create credentials with token authentication | ||
| credentials = grpc.composite_channel_credentials( | ||
| grpc.ssl_channel_credentials(), | ||
| grpc.access_token_call_credentials(token), | ||
| grpc.metadata_call_credentials(_AutoRefreshTokenAuthMetadataPlugin(credential_helper)), | ||
| ) | ||
snus-kin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| # Configure gRPC options for persistent connections | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
get_token()is annotated to returnstrbut can returnNone(e.g., if_is_token_valid()is true while_tokenisNone, or if_tokenis cleared between the validity check and the return). Make the return value guaranteed by (a) performing the validity check + read under the same lock, and (b) raising a clear error if the token is unexpectedly missing after a 'valid' check / refresh.