diff --git a/src/cvec/cvec.py b/src/cvec/cvec.py index f5e9fd9..534e396 100644 --- a/src/cvec/cvec.py +++ b/src/cvec/cvec.py @@ -1,3 +1,4 @@ +import logging import os from datetime import datetime from typing import Any, Dict, List, Optional @@ -12,6 +13,8 @@ metric_data_points_to_arrow, ) +logger = logging.getLogger(__name__) + class CVec: """ @@ -91,6 +94,7 @@ def _get_headers(self) -> Dict[str, str]: return { "Authorization": f"Bearer {self._access_token}", "Content-Type": "application/json", + "Accept": "application/json", } def _make_request( @@ -117,7 +121,6 @@ def _make_request( data=data, ) - # If we get a 401 and we have Supabase tokens, try to refresh and retry if response.status_code == 401 and self._access_token and self._refresh_token: try: self._refresh_supabase_token() @@ -135,9 +138,14 @@ def _make_request( json=json, data=data, ) - except Exception: - print("Token refresh failed") - # If refresh fails, continue with the original error + except (requests.RequestException, ValueError, KeyError) as e: + logger.warning( + "Token refresh failed, continuing with original request: %s", + e, + exc_info=True, + ) + # If refresh fails, continue with the original error response + # which will be raised by raise_for_status() below pass response.raise_for_status() @@ -410,6 +418,7 @@ def _login_with_supabase(self, email: str, password: str) -> None: headers = { "Content-Type": "application/json", + "Accept": "application/json", "apikey": self._publishable_key, } @@ -434,6 +443,7 @@ def _refresh_supabase_token(self) -> None: headers = { "Content-Type": "application/json", + "Accept": "application/json", "apikey": self._publishable_key, } diff --git a/tests/test_token_refresh.py b/tests/test_token_refresh.py new file mode 100644 index 0000000..0c9d780 --- /dev/null +++ b/tests/test_token_refresh.py @@ -0,0 +1,136 @@ +"""Tests for token refresh functionality.""" + +import pytest +import requests +from typing import Any +from unittest.mock import Mock, patch + +from cvec import CVec + + +class TestTokenRefresh: + """Test cases for automatic token refresh functionality.""" + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch("cvec.cvec.requests.request") + def test_token_refresh_on_401( + self, + mock_request: Any, + mock_fetch_key: Any, + mock_login: Any, + ) -> None: + """Test that token refresh is triggered on 401 Unauthorized.""" + client = CVec( + host="https://test.example.com", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + client._access_token = "expired_token" + client._refresh_token = "valid_refresh_token" + + # Mock response sequence + mock_response_401 = Mock() + mock_response_401.status_code = 401 + + mock_response_success = Mock() + mock_response_success.status_code = 200 + mock_response_success.headers = {"content-type": "application/json"} + mock_response_success.json.return_value = [] + + mock_request.side_effect = [ + mock_response_401, + mock_response_success, + ] + + # Mock refresh method + refresh_called: list[bool] = [] + + def mock_refresh() -> None: + refresh_called.append(True) + client._access_token = "new_token" + + with patch.object(client, "_refresh_supabase_token", side_effect=mock_refresh): + # Execute request + result = client.get_metrics() + + # Verify refresh was called + assert len(refresh_called) == 1 + assert client._access_token == "new_token" + assert result == [] + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch("cvec.cvec.requests.request") + def test_token_refresh_handles_network_errors_gracefully( + self, + mock_request: Any, + mock_fetch_key: Any, + mock_login: Any, + ) -> None: + """Test that network errors during refresh don't crash, returns original error.""" + client = CVec( + host="https://test.example.com", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + client._access_token = "expired_token" + client._refresh_token = "valid_refresh_token" + + # Mock response: 401 triggers refresh + mock_response_401 = Mock() + mock_response_401.status_code = 401 + mock_response_401.raise_for_status.side_effect = requests.HTTPError( + "401 Client Error: Unauthorized" + ) + + mock_request.return_value = mock_response_401 + + # Mock refresh to raise network error + def mock_refresh_with_error() -> None: + raise requests.ConnectionError("Network unreachable") + + with patch.object( + client, "_refresh_supabase_token", side_effect=mock_refresh_with_error + ): + # Should not crash, should raise the original 401 error + with pytest.raises(requests.HTTPError, match="401"): + client.get_metrics() + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch("cvec.cvec.requests.request") + def test_token_refresh_handles_missing_refresh_token( + self, + mock_request: Any, + mock_fetch_key: Any, + mock_login: Any, + ) -> None: + """Test that missing refresh token is handled gracefully.""" + client = CVec( + host="https://test.example.com", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + client._access_token = "expired_token" + client._refresh_token = "valid_refresh_token" + + # Mock response: 401 triggers refresh + mock_response_401 = Mock() + mock_response_401.status_code = 401 + mock_response_401.raise_for_status.side_effect = requests.HTTPError( + "401 Client Error: Unauthorized" + ) + + mock_request.return_value = mock_response_401 + + # Mock refresh to raise ValueError (missing refresh token) + def mock_refresh_with_error() -> None: + raise ValueError("No refresh token available") + + with patch.object( + client, "_refresh_supabase_token", side_effect=mock_refresh_with_error + ): + # Should not crash, should raise the original 401 error + with pytest.raises(requests.HTTPError, match="401"): + client.get_metrics()