From 13fc3874cff70a21bd2ea914fd0ef56e10f6c811 Mon Sep 17 00:00:00 2001 From: Vincent Maurin Date: Mon, 2 Feb 2026 11:22:41 +0100 Subject: [PATCH] fix: use the refresh code for oauth token In order for a channel to keep going, it needs to get a fresh token when needed at individual request level. The token fetched only once at channel creation where it should be fetched for each request. The CredentialHelper has to be moved to non asyncio as it seems to be full sync code in underlying grpc. --- docs/api/exceptions.rst | 2 +- docs/authentication.rst | 12 +- examples/classify_single_example.py | 9 -- examples/example.py | 9 -- src/resolver_athena_client/client/channel.py | 99 ++++++--------- tests/client/test_channel.py | 121 +++---------------- tests/functional/conftest.py | 11 +- 7 files changed, 59 insertions(+), 204 deletions(-) diff --git a/docs/api/exceptions.rst b/docs/api/exceptions.rst index 4c5720a..75419ee 100644 --- a/docs/api/exceptions.rst +++ b/docs/api/exceptions.rst @@ -129,7 +129,7 @@ OAuth Error Handling client_id=client_id, client_secret=client_secret ) - token = await credential_helper.get_token() + token = credential_helper.get_token() except OAuthError as e: logger.error(f"OAuth authentication failed: {e}") # Handle OAuth failure - check credentials diff --git a/docs/authentication.rst b/docs/authentication.rst index 066c587..49d147f 100644 --- a/docs/authentication.rst +++ b/docs/authentication.rst @@ -76,14 +76,6 @@ Set these environment variables for OAuth authentication: audience=os.getenv("OAUTH_AUDIENCE", "crisp-athena-live"), ) - # Test token acquisition - try: - token = await credential_helper.get_token() - print(f"Successfully acquired token (length: {len(token)})") - except Exception as e: - print(f"Failed to acquire OAuth token: {e}") - return - # Create authenticated channel channel = await create_channel_with_credentials( host=os.getenv("ATHENA_HOST"), @@ -261,7 +253,7 @@ Handle OAuth-specific errors gracefully: from resolver_athena_client.client.exceptions import AuthenticationError try: - token = await credential_helper.get_token() + token = credential_helper.get_token() except AuthenticationError as e: logger.error(f"OAuth authentication failed: {e}") # Handle authentication failure @@ -356,7 +348,7 @@ Test your authentication setup: client_secret=os.getenv("OAUTH_CLIENT_SECRET"), ) - token = await credential_helper.get_token() + token = credential_helper.get_token() print(f"✓ Authentication successful (token length: {len(token)})") return True diff --git a/examples/classify_single_example.py b/examples/classify_single_example.py index f543bed..d11d0b9 100755 --- a/examples/classify_single_example.py +++ b/examples/classify_single_example.py @@ -214,15 +214,6 @@ async def main() -> int: audience=audience, ) - # Test token acquisition - try: - logger.info("Acquiring OAuth token...") - token = await credential_helper.get_token() - logger.info("Successfully acquired token (length: %d)", len(token)) - except Exception: - logger.exception("Failed to acquire OAuth token") - return 1 - # Configure client options options = AthenaOptions( host=host, diff --git a/examples/example.py b/examples/example.py index e4e8220..4b43d7e 100755 --- a/examples/example.py +++ b/examples/example.py @@ -163,15 +163,6 @@ async def main() -> int: audience=audience, ) - # Test token acquisition - try: - logger.info("Acquiring OAuth token...") - token = await credential_helper.get_token() - logger.info("Successfully acquired token (length: %d)", len(token)) - except Exception: - logger.exception("Failed to acquire OAuth token") - return 1 - # Get available deployment channel = await create_channel_with_credentials(host, credential_helper) async with DeploymentSelector(channel) as deployment_selector: diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index e9d1510..d8f192c 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -1,9 +1,8 @@ """Channel creation utilities for the Athena client.""" -import asyncio import json +import threading import time -from typing import override import grpc import httpx @@ -16,39 +15,6 @@ ) -class TokenMetadataPlugin(grpc.AuthMetadataPlugin): - """Plugin that adds authorization token to gRPC metadata.""" - - def __init__(self, token: str) -> None: - """Initialize the plugin with the auth token. - - Args: - ---- - token: The authorization token to add to requests - - """ - self._token: str = token - - @override - def __call__( - self, - _: grpc.AuthMetadataContext, - callback: grpc.AuthMetadataPluginCallback, - ) -> None: - """Pass authentication metadata to the provided callback. - - This method will be invoked asynchronously in a separate thread. - - Args: - ---- - callback: An AuthMetadataPluginCallback to be invoked either - synchronously or asynchronously. - - """ - metadata = (("authorization", f"Token {self._token}"),) - callback(metadata, None) - - class CredentialHelper: """OAuth credential helper for managing authentication tokens.""" @@ -82,9 +48,9 @@ def __init__( self._audience: str = audience self._token: str | None = None self._token_expires_at: float | None = None - self._lock: asyncio.Lock = asyncio.Lock() + self._lock = threading.Lock() - async def get_token(self) -> str: + def get_token(self) -> str: """Get a valid authentication token. This method will return a cached token if it's still valid, @@ -97,20 +63,14 @@ async def get_token(self) -> str: Raises ------ OAuthError: If token acquisition fails - TokenExpiredError: If token has expired and refresh fails """ - async with self._lock: - if self._is_token_valid(): - if self._token is None: - msg = "Token should be valid but is None" - raise RuntimeError(msg) - return self._token - - await self._refresh_token() - if self._token is None: - msg = "Token refresh failed" - raise RuntimeError(msg) + if self._is_token_valid(): + return self._token + + with self._lock: + if not self._is_token_valid(): + self._refresh_token() return self._token def _is_token_valid(self) -> bool: @@ -127,7 +87,7 @@ def _is_token_valid(self) -> bool: # Add 30 second buffer before expiration return time.time() < (self._token_expires_at - 30) - async def _refresh_token(self) -> None: + def _refresh_token(self) -> None: """Refresh the authentication token by making an OAuth request. Raises @@ -145,8 +105,8 @@ async def _refresh_token(self) -> None: headers = {"content-type": "application/json"} try: - async with httpx.AsyncClient() as client: - response = await client.post( + with httpx.Client() as client: + response = client.post( self._auth_url, json=payload, headers=headers, @@ -154,12 +114,10 @@ async def _refresh_token(self) -> None: ) _ = response.raise_for_status() - token_data = response.json() - self._token = token_data["access_token"] - expires_in = token_data.get( - "expires_in", 3600 - ) # Default 1 hour - self._token_expires_at = time.time() + expires_in + token_data = response.json() + self._token = token_data["access_token"] + expires_in = token_data.get("expires_in", 3600) # Default 1 hour + self._token_expires_at = time.time() + expires_in except httpx.HTTPStatusError as e: error_detail = "" @@ -190,13 +148,29 @@ async def _refresh_token(self) -> None: msg = f"Unexpected error during OAuth: {e}" raise OAuthError(msg) from e - async def invalidate_token(self) -> None: + def invalidate_token(self) -> None: """Invalidate the current token to force a refresh on next use.""" - async with self._lock: + with self._lock: self._token = None self._token_expires_at = None +class _AutoRefreshTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): + def __init__(self, credential_helper: CredentialHelper) -> None: + self._credential_helper = credential_helper + + def __call__( + self, + _: grpc.AuthMetadataContext, + callback: grpc.AuthMetadataPluginCallback, + ) -> None: + try: + token = self._credential_helper.get_token() + metadata = (("authorization", f"Bearer {token}"),) + callback(metadata, None) + except OAuthError as err: + callback(None, err) + async def create_channel_with_credentials( host: str, credential_helper: CredentialHelper, @@ -221,13 +195,10 @@ async def create_channel_with_credentials( if not host: raise InvalidHostError(InvalidHostError.default_message) - # Get a valid token from the credential helper - token = await credential_helper.get_token() - # Create credentials with token authentication credentials = grpc.composite_channel_credentials( grpc.ssl_channel_credentials(), - grpc.access_token_call_credentials(token), + grpc.metadata_call_credentials(_AutoRefreshTokenAuthMetadataPlugin(credential_helper)), ) # Configure gRPC options for persistent connections diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index 701bd00..c0b2dac 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -7,11 +7,9 @@ import httpx import pytest -from grpc.aio import Channel from resolver_athena_client.client.channel import ( CredentialHelper, - TokenMetadataPlugin, create_channel_with_credentials, ) from resolver_athena_client.client.exceptions import ( @@ -21,23 +19,6 @@ ) -def test_token_metadata_plugin() -> None: - """Test TokenMetadataPlugin functionality.""" - test_token = "test-token" - plugin = TokenMetadataPlugin(test_token) - - # Mock callback - mock_callback = mock.Mock() - mock_context = mock.Mock() - - # Call the plugin - plugin(mock_context, mock_callback) - - # Verify the callback was called with correct metadata - expected_metadata = (("authorization", f"Token {test_token}"),) - mock_callback.assert_called_once_with(expected_metadata, None) - - @pytest.mark.asyncio async def test_create_channel_with_credentials_validation() -> None: """Test channel creation with credentials validates input properly.""" @@ -49,18 +30,6 @@ async def test_create_channel_with_credentials_validation() -> None: _ = await create_channel_with_credentials(test_host, mock_helper) -@pytest.mark.asyncio -async def test_create_channel_with_credentials_oauth_failure() -> None: - """Test channel creation when OAuth token acquisition fails.""" - test_host = "test-host:50051" - - mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.side_effect = OAuthError("Token acquisition failed") - - with pytest.raises(OAuthError, match="Token acquisition failed"): - _ = await create_channel_with_credentials(test_host, mock_helper) - - class TestCredentialHelper: """Test cases for CredentialHelper OAuth functionality.""" @@ -153,8 +122,7 @@ def test_is_token_valid_with_soon_expiring_token(self) -> None: assert not helper._is_token_valid() - @pytest.mark.asyncio - async def test_get_token_success(self) -> None: + def test_get_token_success(self) -> None: """Test successful token acquisition.""" helper = CredentialHelper( client_id="test_client_id", @@ -168,18 +136,17 @@ async def test_get_token_success(self) -> None: } mock_response.raise_for_status.return_value = None - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.return_value = mock_response - token = await helper.get_token() + token = helper.get_token() assert token == "new_access_token" assert helper._token == "new_access_token" assert helper._token_expires_at is not None - @pytest.mark.asyncio - async def test_get_token_cached(self) -> None: + def test_get_token_cached(self) -> None: """Test that cached token is returned when valid.""" helper = CredentialHelper( client_id="test_client_id", @@ -190,12 +157,11 @@ async def test_get_token_cached(self) -> None: helper._token = "cached_token" helper._token_expires_at = time.time() + 3600 - token = await helper.get_token() + token = helper.get_token() assert token == "cached_token" - @pytest.mark.asyncio - async def test_refresh_token_http_error(self) -> None: + def test_refresh_token_http_error(self) -> None: """Test token refresh with HTTP error.""" helper = CredentialHelper( client_id="test_client_id", @@ -215,17 +181,16 @@ async def test_refresh_token_http_error(self) -> None: response=mock_response, ) - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.side_effect = http_error with pytest.raises( OAuthError, match="OAuth request failed with status 401" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_refresh_token_request_error(self) -> None: + def test_refresh_token_request_error(self) -> None: """Test token refresh with request error.""" helper = CredentialHelper( client_id="test_client_id", @@ -234,17 +199,16 @@ async def test_refresh_token_request_error(self) -> None: request_error = httpx.RequestError("Connection failed") - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.side_effect = request_error with pytest.raises( OAuthError, match="Failed to connect to OAuth server" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_refresh_token_invalid_response(self) -> None: + def test_refresh_token_invalid_response(self) -> None: """Test token refresh with invalid response format.""" helper = CredentialHelper( client_id="test_client_id", @@ -257,17 +221,16 @@ async def test_refresh_token_invalid_response(self) -> None: } mock_response.raise_for_status.return_value = None - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.return_value = mock_response with pytest.raises( OAuthError, match="Invalid OAuth response format" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_invalidate_token(self) -> None: + def test_invalidate_token(self) -> None: """Test token invalidation.""" helper = CredentialHelper( client_id="test_client_id", @@ -278,45 +241,12 @@ async def test_invalidate_token(self) -> None: helper._token = "valid_token" helper._token_expires_at = time.time() + 3600 - await helper.invalidate_token() + helper.invalidate_token() assert helper._token is None assert helper._token_expires_at is None -@pytest.mark.asyncio -async def test_create_channel_with_credentials_success() -> None: - """Test successful channel creation with credential helper.""" - test_host = "test-host:50051" - - mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.return_value = "test_token" - - mock_credentials = mock.Mock() - mock_channel = mock.Mock(spec=Channel) - - with ( - mock.patch("grpc.ssl_channel_credentials") as mock_ssl_creds, - mock.patch("grpc.access_token_call_credentials") as mock_token_creds, - mock.patch( - "grpc.composite_channel_credentials" - ) as mock_composite_creds, - mock.patch("grpc.aio.secure_channel") as mock_secure_channel, - ): - # Set up mocks - mock_ssl_creds.return_value = mock.Mock() - mock_token_creds.return_value = mock.Mock() - mock_composite_creds.return_value = mock_credentials - mock_secure_channel.return_value = mock_channel - - # Create channel - channel = await create_channel_with_credentials(test_host, mock_helper) - - # Verify channel creation - assert channel == mock_channel - mock_helper.get_token.assert_called_once() - mock_token_creds.assert_called_once_with("test_token") - @pytest.mark.asyncio async def test_create_channel_with_credentials_invalid_host() -> None: @@ -328,14 +258,3 @@ async def test_create_channel_with_credentials_invalid_host() -> None: with pytest.raises(InvalidHostError, match="host cannot be empty"): _ = await create_channel_with_credentials(test_host, mock_helper) - -@pytest.mark.asyncio -async def test_create_channel_with_credentials_oauth_error() -> None: - """Test channel creation with credentials when OAuth fails.""" - test_host = "test-host:50051" - - mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.side_effect = OAuthError("OAuth failed") - - with pytest.raises(OAuthError, match="OAuth failed"): - _ = await create_channel_with_credentials(test_host, mock_helper) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index e7cd324..6690b40 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -46,22 +46,13 @@ async def credential_helper() -> CredentialHelper: audience = os.getenv("OAUTH_AUDIENCE", "crisp-athena-live") # Create credential helper - credential_helper = CredentialHelper( + return CredentialHelper( client_id=client_id, client_secret=client_secret, auth_url=auth_url, audience=audience, ) - # Test token acquisition - try: - _ = await credential_helper.get_token() - except Exception as e: - msg = "Failed to acquire OAuth token" - raise AssertionError(msg) from e - - return credential_helper - @pytest.fixture def athena_options() -> AthenaOptions: