diff --git a/docs/api/exceptions.rst b/docs/api/exceptions.rst index 4c5720a..75419ee 100644 --- a/docs/api/exceptions.rst +++ b/docs/api/exceptions.rst @@ -129,7 +129,7 @@ OAuth Error Handling client_id=client_id, client_secret=client_secret ) - token = await credential_helper.get_token() + token = credential_helper.get_token() except OAuthError as e: logger.error(f"OAuth authentication failed: {e}") # Handle OAuth failure - check credentials diff --git a/docs/authentication.rst b/docs/authentication.rst index 066c587..49d147f 100644 --- a/docs/authentication.rst +++ b/docs/authentication.rst @@ -76,14 +76,6 @@ Set these environment variables for OAuth authentication: audience=os.getenv("OAUTH_AUDIENCE", "crisp-athena-live"), ) - # Test token acquisition - try: - token = await credential_helper.get_token() - print(f"Successfully acquired token (length: {len(token)})") - except Exception as e: - print(f"Failed to acquire OAuth token: {e}") - return - # Create authenticated channel channel = await create_channel_with_credentials( host=os.getenv("ATHENA_HOST"), @@ -261,7 +253,7 @@ Handle OAuth-specific errors gracefully: from resolver_athena_client.client.exceptions import AuthenticationError try: - token = await credential_helper.get_token() + token = credential_helper.get_token() except AuthenticationError as e: logger.error(f"OAuth authentication failed: {e}") # Handle authentication failure @@ -356,7 +348,7 @@ Test your authentication setup: client_secret=os.getenv("OAUTH_CLIENT_SECRET"), ) - token = await credential_helper.get_token() + token = credential_helper.get_token() print(f"✓ Authentication successful (token length: {len(token)})") return True diff --git a/examples/classify_single_example.py b/examples/classify_single_example.py index f543bed..d11d0b9 100755 --- a/examples/classify_single_example.py +++ b/examples/classify_single_example.py @@ -214,15 +214,6 @@ async def main() -> int: audience=audience, ) - # Test token acquisition - try: - logger.info("Acquiring OAuth token...") - token = await credential_helper.get_token() - logger.info("Successfully acquired token (length: %d)", len(token)) - except Exception: - logger.exception("Failed to acquire OAuth token") - return 1 - # Configure client options options = AthenaOptions( host=host, diff --git a/examples/example.py b/examples/example.py index e4e8220..4b43d7e 100755 --- a/examples/example.py +++ b/examples/example.py @@ -163,15 +163,6 @@ async def main() -> int: audience=audience, ) - # Test token acquisition - try: - logger.info("Acquiring OAuth token...") - token = await credential_helper.get_token() - logger.info("Successfully acquired token (length: %d)", len(token)) - except Exception: - logger.exception("Failed to acquire OAuth token") - return 1 - # Get available deployment channel = await create_channel_with_credentials(host, credential_helper) async with DeploymentSelector(channel) as deployment_selector: diff --git a/src/resolver_athena_client/client/channel.py b/src/resolver_athena_client/client/channel.py index e9d1510..d8f192c 100644 --- a/src/resolver_athena_client/client/channel.py +++ b/src/resolver_athena_client/client/channel.py @@ -1,9 +1,8 @@ """Channel creation utilities for the Athena client.""" -import asyncio import json +import threading import time -from typing import override import grpc import httpx @@ -16,39 +15,6 @@ ) -class TokenMetadataPlugin(grpc.AuthMetadataPlugin): - """Plugin that adds authorization token to gRPC metadata.""" - - def __init__(self, token: str) -> None: - """Initialize the plugin with the auth token. - - Args: - ---- - token: The authorization token to add to requests - - """ - self._token: str = token - - @override - def __call__( - self, - _: grpc.AuthMetadataContext, - callback: grpc.AuthMetadataPluginCallback, - ) -> None: - """Pass authentication metadata to the provided callback. - - This method will be invoked asynchronously in a separate thread. - - Args: - ---- - callback: An AuthMetadataPluginCallback to be invoked either - synchronously or asynchronously. - - """ - metadata = (("authorization", f"Token {self._token}"),) - callback(metadata, None) - - class CredentialHelper: """OAuth credential helper for managing authentication tokens.""" @@ -82,9 +48,9 @@ def __init__( self._audience: str = audience self._token: str | None = None self._token_expires_at: float | None = None - self._lock: asyncio.Lock = asyncio.Lock() + self._lock = threading.Lock() - async def get_token(self) -> str: + def get_token(self) -> str: """Get a valid authentication token. This method will return a cached token if it's still valid, @@ -97,20 +63,14 @@ async def get_token(self) -> str: Raises ------ OAuthError: If token acquisition fails - TokenExpiredError: If token has expired and refresh fails """ - async with self._lock: - if self._is_token_valid(): - if self._token is None: - msg = "Token should be valid but is None" - raise RuntimeError(msg) - return self._token - - await self._refresh_token() - if self._token is None: - msg = "Token refresh failed" - raise RuntimeError(msg) + if self._is_token_valid(): + return self._token + + with self._lock: + if not self._is_token_valid(): + self._refresh_token() return self._token def _is_token_valid(self) -> bool: @@ -127,7 +87,7 @@ def _is_token_valid(self) -> bool: # Add 30 second buffer before expiration return time.time() < (self._token_expires_at - 30) - async def _refresh_token(self) -> None: + def _refresh_token(self) -> None: """Refresh the authentication token by making an OAuth request. Raises @@ -145,8 +105,8 @@ async def _refresh_token(self) -> None: headers = {"content-type": "application/json"} try: - async with httpx.AsyncClient() as client: - response = await client.post( + with httpx.Client() as client: + response = client.post( self._auth_url, json=payload, headers=headers, @@ -154,12 +114,10 @@ async def _refresh_token(self) -> None: ) _ = response.raise_for_status() - token_data = response.json() - self._token = token_data["access_token"] - expires_in = token_data.get( - "expires_in", 3600 - ) # Default 1 hour - self._token_expires_at = time.time() + expires_in + token_data = response.json() + self._token = token_data["access_token"] + expires_in = token_data.get("expires_in", 3600) # Default 1 hour + self._token_expires_at = time.time() + expires_in except httpx.HTTPStatusError as e: error_detail = "" @@ -190,13 +148,29 @@ async def _refresh_token(self) -> None: msg = f"Unexpected error during OAuth: {e}" raise OAuthError(msg) from e - async def invalidate_token(self) -> None: + def invalidate_token(self) -> None: """Invalidate the current token to force a refresh on next use.""" - async with self._lock: + with self._lock: self._token = None self._token_expires_at = None +class _AutoRefreshTokenAuthMetadataPlugin(grpc.AuthMetadataPlugin): + def __init__(self, credential_helper: CredentialHelper) -> None: + self._credential_helper = credential_helper + + def __call__( + self, + _: grpc.AuthMetadataContext, + callback: grpc.AuthMetadataPluginCallback, + ) -> None: + try: + token = self._credential_helper.get_token() + metadata = (("authorization", f"Bearer {token}"),) + callback(metadata, None) + except OAuthError as err: + callback(None, err) + async def create_channel_with_credentials( host: str, credential_helper: CredentialHelper, @@ -221,13 +195,10 @@ async def create_channel_with_credentials( if not host: raise InvalidHostError(InvalidHostError.default_message) - # Get a valid token from the credential helper - token = await credential_helper.get_token() - # Create credentials with token authentication credentials = grpc.composite_channel_credentials( grpc.ssl_channel_credentials(), - grpc.access_token_call_credentials(token), + grpc.metadata_call_credentials(_AutoRefreshTokenAuthMetadataPlugin(credential_helper)), ) # Configure gRPC options for persistent connections diff --git a/tests/client/test_channel.py b/tests/client/test_channel.py index 701bd00..c0b2dac 100644 --- a/tests/client/test_channel.py +++ b/tests/client/test_channel.py @@ -7,11 +7,9 @@ import httpx import pytest -from grpc.aio import Channel from resolver_athena_client.client.channel import ( CredentialHelper, - TokenMetadataPlugin, create_channel_with_credentials, ) from resolver_athena_client.client.exceptions import ( @@ -21,23 +19,6 @@ ) -def test_token_metadata_plugin() -> None: - """Test TokenMetadataPlugin functionality.""" - test_token = "test-token" - plugin = TokenMetadataPlugin(test_token) - - # Mock callback - mock_callback = mock.Mock() - mock_context = mock.Mock() - - # Call the plugin - plugin(mock_context, mock_callback) - - # Verify the callback was called with correct metadata - expected_metadata = (("authorization", f"Token {test_token}"),) - mock_callback.assert_called_once_with(expected_metadata, None) - - @pytest.mark.asyncio async def test_create_channel_with_credentials_validation() -> None: """Test channel creation with credentials validates input properly.""" @@ -49,18 +30,6 @@ async def test_create_channel_with_credentials_validation() -> None: _ = await create_channel_with_credentials(test_host, mock_helper) -@pytest.mark.asyncio -async def test_create_channel_with_credentials_oauth_failure() -> None: - """Test channel creation when OAuth token acquisition fails.""" - test_host = "test-host:50051" - - mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.side_effect = OAuthError("Token acquisition failed") - - with pytest.raises(OAuthError, match="Token acquisition failed"): - _ = await create_channel_with_credentials(test_host, mock_helper) - - class TestCredentialHelper: """Test cases for CredentialHelper OAuth functionality.""" @@ -153,8 +122,7 @@ def test_is_token_valid_with_soon_expiring_token(self) -> None: assert not helper._is_token_valid() - @pytest.mark.asyncio - async def test_get_token_success(self) -> None: + def test_get_token_success(self) -> None: """Test successful token acquisition.""" helper = CredentialHelper( client_id="test_client_id", @@ -168,18 +136,17 @@ async def test_get_token_success(self) -> None: } mock_response.raise_for_status.return_value = None - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.return_value = mock_response - token = await helper.get_token() + token = helper.get_token() assert token == "new_access_token" assert helper._token == "new_access_token" assert helper._token_expires_at is not None - @pytest.mark.asyncio - async def test_get_token_cached(self) -> None: + def test_get_token_cached(self) -> None: """Test that cached token is returned when valid.""" helper = CredentialHelper( client_id="test_client_id", @@ -190,12 +157,11 @@ async def test_get_token_cached(self) -> None: helper._token = "cached_token" helper._token_expires_at = time.time() + 3600 - token = await helper.get_token() + token = helper.get_token() assert token == "cached_token" - @pytest.mark.asyncio - async def test_refresh_token_http_error(self) -> None: + def test_refresh_token_http_error(self) -> None: """Test token refresh with HTTP error.""" helper = CredentialHelper( client_id="test_client_id", @@ -215,17 +181,16 @@ async def test_refresh_token_http_error(self) -> None: response=mock_response, ) - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.side_effect = http_error with pytest.raises( OAuthError, match="OAuth request failed with status 401" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_refresh_token_request_error(self) -> None: + def test_refresh_token_request_error(self) -> None: """Test token refresh with request error.""" helper = CredentialHelper( client_id="test_client_id", @@ -234,17 +199,16 @@ async def test_refresh_token_request_error(self) -> None: request_error = httpx.RequestError("Connection failed") - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.side_effect = request_error with pytest.raises( OAuthError, match="Failed to connect to OAuth server" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_refresh_token_invalid_response(self) -> None: + def test_refresh_token_invalid_response(self) -> None: """Test token refresh with invalid response format.""" helper = CredentialHelper( client_id="test_client_id", @@ -257,17 +221,16 @@ async def test_refresh_token_invalid_response(self) -> None: } mock_response.raise_for_status.return_value = None - with mock.patch("httpx.AsyncClient") as mock_client: - mock_response_obj = mock_client.return_value.__aenter__.return_value + with mock.patch("httpx.Client") as mock_client: + mock_response_obj = mock_client.return_value.__enter__.return_value mock_response_obj.post.return_value = mock_response with pytest.raises( OAuthError, match="Invalid OAuth response format" ): - _ = await helper.get_token() + _ = helper.get_token() - @pytest.mark.asyncio - async def test_invalidate_token(self) -> None: + def test_invalidate_token(self) -> None: """Test token invalidation.""" helper = CredentialHelper( client_id="test_client_id", @@ -278,45 +241,12 @@ async def test_invalidate_token(self) -> None: helper._token = "valid_token" helper._token_expires_at = time.time() + 3600 - await helper.invalidate_token() + helper.invalidate_token() assert helper._token is None assert helper._token_expires_at is None -@pytest.mark.asyncio -async def test_create_channel_with_credentials_success() -> None: - """Test successful channel creation with credential helper.""" - test_host = "test-host:50051" - - mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.return_value = "test_token" - - mock_credentials = mock.Mock() - mock_channel = mock.Mock(spec=Channel) - - with ( - mock.patch("grpc.ssl_channel_credentials") as mock_ssl_creds, - mock.patch("grpc.access_token_call_credentials") as mock_token_creds, - mock.patch( - "grpc.composite_channel_credentials" - ) as mock_composite_creds, - mock.patch("grpc.aio.secure_channel") as mock_secure_channel, - ): - # Set up mocks - mock_ssl_creds.return_value = mock.Mock() - mock_token_creds.return_value = mock.Mock() - mock_composite_creds.return_value = mock_credentials - mock_secure_channel.return_value = mock_channel - - # Create channel - channel = await create_channel_with_credentials(test_host, mock_helper) - - # Verify channel creation - assert channel == mock_channel - mock_helper.get_token.assert_called_once() - mock_token_creds.assert_called_once_with("test_token") - @pytest.mark.asyncio async def test_create_channel_with_credentials_invalid_host() -> None: @@ -328,14 +258,3 @@ async def test_create_channel_with_credentials_invalid_host() -> None: with pytest.raises(InvalidHostError, match="host cannot be empty"): _ = await create_channel_with_credentials(test_host, mock_helper) - -@pytest.mark.asyncio -async def test_create_channel_with_credentials_oauth_error() -> None: - """Test channel creation with credentials when OAuth fails.""" - test_host = "test-host:50051" - - mock_helper = mock.Mock(spec=CredentialHelper) - mock_helper.get_token.side_effect = OAuthError("OAuth failed") - - with pytest.raises(OAuthError, match="OAuth failed"): - _ = await create_channel_with_credentials(test_host, mock_helper) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index e7cd324..6690b40 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -46,22 +46,13 @@ async def credential_helper() -> CredentialHelper: audience = os.getenv("OAUTH_AUDIENCE", "crisp-athena-live") # Create credential helper - credential_helper = CredentialHelper( + return CredentialHelper( client_id=client_id, client_secret=client_secret, auth_url=auth_url, audience=audience, ) - # Test token acquisition - try: - _ = await credential_helper.get_token() - except Exception as e: - msg = "Failed to acquire OAuth token" - raise AssertionError(msg) from e - - return credential_helper - @pytest.fixture def athena_options() -> AthenaOptions: