Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api/exceptions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ OAuth Error Handling
client_id=client_id,
client_secret=client_secret
)
token = await credential_helper.get_token()
token = credential_helper.get_token()
except OAuthError as e:
logger.error(f"OAuth authentication failed: {e}")
# Handle OAuth failure - check credentials
Expand Down
12 changes: 2 additions & 10 deletions docs/authentication.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,6 @@ Set these environment variables for OAuth authentication:
audience=os.getenv("OAUTH_AUDIENCE", "crisp-athena-live"),
)

# Test token acquisition
try:
token = await credential_helper.get_token()
print(f"Successfully acquired token (length: {len(token)})")
except Exception as e:
print(f"Failed to acquire OAuth token: {e}")
return

# Create authenticated channel
channel = await create_channel_with_credentials(
host=os.getenv("ATHENA_HOST"),
Expand Down Expand Up @@ -261,7 +253,7 @@ Handle OAuth-specific errors gracefully:
from resolver_athena_client.client.exceptions import AuthenticationError

try:
token = await credential_helper.get_token()
token = credential_helper.get_token()
except AuthenticationError as e:
logger.error(f"OAuth authentication failed: {e}")
# Handle authentication failure
Expand Down Expand Up @@ -356,7 +348,7 @@ Test your authentication setup:
client_secret=os.getenv("OAUTH_CLIENT_SECRET"),
)

token = await credential_helper.get_token()
token = credential_helper.get_token()
print(f"✓ Authentication successful (token length: {len(token)})")
return True

Expand Down
9 changes: 0 additions & 9 deletions examples/classify_single_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,6 @@ async def main() -> int:
audience=audience,
)

# Test token acquisition
try:
logger.info("Acquiring OAuth token...")
token = await credential_helper.get_token()
logger.info("Successfully acquired token (length: %d)", len(token))
except Exception:
logger.exception("Failed to acquire OAuth token")
return 1

# Configure client options
options = AthenaOptions(
host=host,
Expand Down
9 changes: 0 additions & 9 deletions examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,15 +163,6 @@ async def main() -> int:
audience=audience,
)

# Test token acquisition
try:
logger.info("Acquiring OAuth token...")
token = await credential_helper.get_token()
logger.info("Successfully acquired token (length: %d)", len(token))
except Exception:
logger.exception("Failed to acquire OAuth token")
return 1

# Get available deployment
channel = await create_channel_with_credentials(host, credential_helper)
async with DeploymentSelector(channel) as deployment_selector:
Expand Down
99 changes: 35 additions & 64 deletions src/resolver_athena_client/client/channel.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
"""Channel creation utilities for the Athena client."""

import asyncio
import json
import threading
import time
from typing import override

import grpc
import httpx
Expand All @@ -16,39 +15,6 @@
)


class TokenMetadataPlugin(grpc.AuthMetadataPlugin):
"""Plugin that adds authorization token to gRPC metadata."""

def __init__(self, token: str) -> None:
"""Initialize the plugin with the auth token.

Args:
----
token: The authorization token to add to requests

"""
self._token: str = token

@override
def __call__(
self,
_: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback,
) -> None:
"""Pass authentication metadata to the provided callback.

This method will be invoked asynchronously in a separate thread.

Args:
----
callback: An AuthMetadataPluginCallback to be invoked either
synchronously or asynchronously.

"""
metadata = (("authorization", f"Token {self._token}"),)
callback(metadata, None)


class CredentialHelper:
"""OAuth credential helper for managing authentication tokens."""

Expand Down Expand Up @@ -82,9 +48,9 @@ def __init__(
self._audience: str = audience
self._token: str | None = None
self._token_expires_at: float | None = None
self._lock: asyncio.Lock = asyncio.Lock()
self._lock = threading.Lock()

async def get_token(self) -> str:
def get_token(self) -> str:
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_token() is annotated to return str but can return None (e.g., if _is_token_valid() is true while _token is None, or if _token is cleared between the validity check and the return). Make the return value guaranteed by (a) performing the validity check + read under the same lock, and (b) raising a clear error if the token is unexpectedly missing after a 'valid' check / refresh.

Copilot uses AI. Check for mistakes.
"""Get a valid authentication token.

This method will return a cached token if it's still valid,
Expand All @@ -97,20 +63,14 @@ async def get_token(self) -> str:
Raises
------
OAuthError: If token acquisition fails
TokenExpiredError: If token has expired and refresh fails

"""
async with self._lock:
if self._is_token_valid():
if self._token is None:
msg = "Token should be valid but is None"
raise RuntimeError(msg)
return self._token

await self._refresh_token()
if self._token is None:
msg = "Token refresh failed"
raise RuntimeError(msg)
if self._is_token_valid():
return self._token

with self._lock:
if not self._is_token_valid():
self._refresh_token()
return self._token
Comment on lines +68 to 74
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_token() is annotated to return str but can return None (e.g., if _is_token_valid() is true while _token is None, or if _token is cleared between the validity check and the return). Make the return value guaranteed by (a) performing the validity check + read under the same lock, and (b) raising a clear error if the token is unexpectedly missing after a 'valid' check / refresh.

Copilot uses AI. Check for mistakes.
Comment on lines +68 to 74
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There’s a race between the unlocked fast-path in get_token() and invalidate_token(): another thread can invalidate between the _is_token_valid() check and return self._token, causing a None return. Remove the unlocked fast-path and do the check/refresh/return entirely inside the lock (or copy the token to a local variable while holding the lock).

Copilot uses AI. Check for mistakes.

def _is_token_valid(self) -> bool:
Expand All @@ -127,7 +87,7 @@ def _is_token_valid(self) -> bool:
# Add 30 second buffer before expiration
return time.time() < (self._token_expires_at - 30)

async def _refresh_token(self) -> None:
def _refresh_token(self) -> None:
"""Refresh the authentication token by making an OAuth request.

Raises
Expand All @@ -145,21 +105,19 @@ async def _refresh_token(self) -> None:
headers = {"content-type": "application/json"}

try:
async with httpx.AsyncClient() as client:
response = await client.post(
with httpx.Client() as client:
response = client.post(
self._auth_url,
json=payload,
headers=headers,
timeout=30.0,
)
_ = response.raise_for_status()

token_data = response.json()
self._token = token_data["access_token"]
expires_in = token_data.get(
"expires_in", 3600
) # Default 1 hour
self._token_expires_at = time.time() + expires_in
token_data = response.json()
self._token = token_data["access_token"]
expires_in = token_data.get("expires_in", 3600) # Default 1 hour
self._token_expires_at = time.time() + expires_in

except httpx.HTTPStatusError as e:
error_detail = ""
Expand Down Expand Up @@ -190,13 +148,29 @@ async def _refresh_token(self) -> None:
msg = f"Unexpected error during OAuth: {e}"
raise OAuthError(msg) from e

async def invalidate_token(self) -> None:
def invalidate_token(self) -> None:
"""Invalidate the current token to force a refresh on next use."""
async with self._lock:
with self._lock:
self._token = None
self._token_expires_at = None
Comment on lines +153 to 155
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There’s a race between the unlocked fast-path in get_token() and invalidate_token(): another thread can invalidate between the _is_token_valid() check and return self._token, causing a None return. Remove the unlocked fast-path and do the check/refresh/return entirely inside the lock (or copy the token to a local variable while holding the lock).

Copilot uses AI. Check for mistakes.


class _AutoRefreshTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin):
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New per-request auth behavior is introduced via _AutoRefreshTokenAuthMetadataPlugin, but the PR removes the previous unit test coverage around auth metadata injection. Add focused tests that (1) the plugin calls credential_helper.get_token() and passes the expected authorization metadata to the callback, and (2) an OAuthError results in callback(None, err).

Copilot uses AI. Check for mistakes.
def __init__(self, credential_helper: CredentialHelper) -> None:
self._credential_helper = credential_helper

def __call__(
self,
_: grpc.AuthMetadataContext,
callback: grpc.AuthMetadataPluginCallback,
) -> None:
try:
token = self._credential_helper.get_token()
metadata = (("authorization", f"Bearer {token}"),)
callback(metadata, None)
except OAuthError as err:
callback(None, err)

async def create_channel_with_credentials(
Copy link

Copilot AI Feb 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create_channel_with_credentials() no longer acquires a token during channel creation, so it likely won’t raise OAuthError at creation time anymore (OAuth failures will occur during RPC auth metadata injection). Update the function docstring/raises section to reflect when/where OAuth errors are expected now.

Copilot uses AI. Check for mistakes.
host: str,
credential_helper: CredentialHelper,
Expand All @@ -221,13 +195,10 @@ async def create_channel_with_credentials(
if not host:
raise InvalidHostError(InvalidHostError.default_message)

# Get a valid token from the credential helper
token = await credential_helper.get_token()

# Create credentials with token authentication
credentials = grpc.composite_channel_credentials(
grpc.ssl_channel_credentials(),
grpc.access_token_call_credentials(token),
grpc.metadata_call_credentials(_AutoRefreshTokenAuthMetadataPlugin(credential_helper)),
)

# Configure gRPC options for persistent connections
Expand Down
Loading
Loading