From fcb170164a64f0e2c151bbc7668d0c88dd61aa28 Mon Sep 17 00:00:00 2001 From: alok27a Date: Thu, 23 Oct 2025 12:32:30 -0400 Subject: [PATCH 1/4] Modifying token refresh logic --- src/cvec/cvec.py | 28 ++++- tests/test_token_refresh.py | 201 ++++++++++++++++++++++++++++++++++++ 2 files changed, 225 insertions(+), 4 deletions(-) create mode 100644 tests/test_token_refresh.py diff --git a/src/cvec/cvec.py b/src/cvec/cvec.py index f5e9fd9..f1fbd91 100644 --- a/src/cvec/cvec.py +++ b/src/cvec/cvec.py @@ -115,10 +115,18 @@ def _make_request( params=params, json=json, data=data, + allow_redirects=False, # Disable auto-redirect to catch auth redirects ) - # If we get a 401 and we have Supabase tokens, try to refresh and retry - if response.status_code == 401 and self._access_token and self._refresh_token: + needs_refresh = False + if response.status_code == 401: + needs_refresh = True + elif response.status_code in (301, 302, 303, 307, 308): + location = response.headers.get("Location", "") + if "login" in location.lower() or "token" in location.lower(): + needs_refresh = True + + if needs_refresh and self._access_token and self._refresh_token: try: self._refresh_supabase_token() # Update headers with new token @@ -134,12 +142,24 @@ def _make_request( params=params, json=json, data=data, + allow_redirects=False, ) - except Exception: - print("Token refresh failed") + except Exception as e: + print(f"Token refresh failed: {e}") # If refresh fails, continue with the original error pass + if response.status_code in (301, 302, 303, 307, 308): + response = requests.request( + method="GET" if response.status_code == 303 else method, + url=urljoin(url, response.headers.get("Location", "")), + headers=request_headers, + params=params if method == "GET" else None, + json=json if method != "GET" and response.status_code != 303 else None, + data=data if method != "GET" and response.status_code != 303 else None, + allow_redirects=True, + ) + response.raise_for_status() if ( diff --git a/tests/test_token_refresh.py b/tests/test_token_refresh.py new file mode 100644 index 0000000..d3dc2f0 --- /dev/null +++ b/tests/test_token_refresh.py @@ -0,0 +1,201 @@ +"""Tests for token refresh functionality.""" +import pytest +from unittest.mock import Mock, patch +from cvec import CVec +from typing import Any + + +class TestTokenRefresh: + """Test cases for automatic token refresh functionality.""" + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch("cvec.cvec.requests.request") + def test_token_refresh_on_401( + self, + mock_request: Any, + mock_fetch_key: Any, + mock_login: Any, + ) -> None: + """Test that token refresh is triggered on 401 Unauthorized.""" + client = CVec( + host="https://test.example.com", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + client._access_token = "expired_token" + client._refresh_token = "valid_refresh_token" + + # Mock response sequence + mock_response_401 = Mock() + mock_response_401.status_code = 401 + + mock_response_success = Mock() + mock_response_success.status_code = 200 + mock_response_success.headers = {"content-type": "application/json"} + mock_response_success.json.return_value = [] + + mock_request.side_effect = [ + mock_response_401, + mock_response_success, + ] + + # Mock refresh method + refresh_called = [] + + def mock_refresh() -> None: + refresh_called.append(True) + client._access_token = "new_token" + + client._refresh_supabase_token = mock_refresh + + # Execute request + result = client.get_metrics() + + # Verify refresh was called + assert len(refresh_called) == 1 + assert client._access_token == "new_token" + assert result == [] + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch("cvec.cvec.requests.request") + def test_token_refresh_on_redirect_to_login( + self, + mock_request: Any, + mock_fetch_key: Any, + mock_login: Any, + ) -> None: + """Test that token refresh is triggered on redirect to login page.""" + client = CVec( + host="https://test.example.com", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + client._access_token = "expired_token" + client._refresh_token = "valid_refresh_token" + + # Mock response sequence: 307 redirect to login + mock_response_redirect = Mock() + mock_response_redirect.status_code = 307 + mock_response_redirect.headers = { + "Location": "/login?error=Token%20has%20expired" + } + + mock_response_success = Mock() + mock_response_success.status_code = 200 + mock_response_success.headers = {"content-type": "application/json"} + mock_response_success.json.return_value = [] + + mock_request.side_effect = [ + mock_response_redirect, + mock_response_success, + ] + + # Mock refresh method + refresh_called = [] + + def mock_refresh() -> None: + refresh_called.append(True) + client._access_token = "new_token" + + client._refresh_supabase_token = mock_refresh + + # Execute request + result = client.get_metrics() + + # Verify refresh was called + assert len(refresh_called) == 1 + assert client._access_token == "new_token" + assert result == [] + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch("cvec.cvec.requests.request") + def test_no_refresh_on_normal_redirect( + self, + mock_request: Any, + mock_fetch_key: Any, + mock_login: Any, + ) -> None: + """Test that token refresh is NOT triggered on normal redirects.""" + client = CVec( + host="https://test.example.com", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + client._access_token = "valid_token" + client._refresh_token = "valid_refresh_token" + + mock_response_redirect = Mock() + mock_response_redirect.status_code = 302 + mock_response_redirect.headers = {"Location": "/api/v2/metrics"} + + mock_response_success = Mock() + mock_response_success.status_code = 200 + mock_response_success.headers = {"content-type": "application/json"} + mock_response_success.json.return_value = [] + + mock_request.side_effect = [ + mock_response_redirect, + mock_response_success, + ] + + refresh_called = [] + + def mock_refresh() -> None: + refresh_called.append(True) + + client._refresh_supabase_token = mock_refresh + + result = client.get_metrics() + + assert len(refresh_called) == 0 + assert client._access_token == "valid_token" + assert result == [] + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch("cvec.cvec.requests.request") + def test_refresh_on_302_redirect_with_token_keyword( + self, + mock_request: Any, + mock_fetch_key: Any, + mock_login: Any, + ) -> None: + """Test that token refresh works on 302 redirect with 'token' in URL.""" + client = CVec( + host="https://test.example.com", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + client._access_token = "expired_token" + client._refresh_token = "valid_refresh_token" + + mock_response_redirect = Mock() + mock_response_redirect.status_code = 302 + mock_response_redirect.headers = {"Location": "/auth?error=invalid_token"} + + mock_response_success = Mock() + mock_response_success.status_code = 200 + mock_response_success.headers = {"content-type": "application/json"} + mock_response_success.json.return_value = [] + + mock_request.side_effect = [ + mock_response_redirect, + mock_response_success, + ] + + refresh_called = [] + + def mock_refresh() -> None: + refresh_called.append(True) + client._access_token = "new_token" + + client._refresh_supabase_token = mock_refresh + + result = client.get_metrics() + + assert len(refresh_called) == 1 + assert client._access_token == "new_token" + assert result == [] From acda3a2b9482af425bf820722eb9e50e3852ee08 Mon Sep 17 00:00:00 2001 From: alok27a Date: Thu, 23 Oct 2025 12:39:06 -0400 Subject: [PATCH 2/4] Fixing CI --- tests/test_token_refresh.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_token_refresh.py b/tests/test_token_refresh.py index d3dc2f0..48ec9ed 100644 --- a/tests/test_token_refresh.py +++ b/tests/test_token_refresh.py @@ -1,5 +1,4 @@ """Tests for token refresh functionality.""" -import pytest from unittest.mock import Mock, patch from cvec import CVec from typing import Any From 016cecbdc5b170066a3d3cb82d934f692686c74b Mon Sep 17 00:00:00 2001 From: alok27a Date: Thu, 23 Oct 2025 12:52:19 -0400 Subject: [PATCH 3/4] Fixing CI --- src/cvec/cvec.py | 14 ++++- tests/test_token_refresh.py | 108 +++++++++++++++++++++++++++++------- 2 files changed, 100 insertions(+), 22 deletions(-) diff --git a/src/cvec/cvec.py b/src/cvec/cvec.py index f1fbd91..5582de6 100644 --- a/src/cvec/cvec.py +++ b/src/cvec/cvec.py @@ -1,3 +1,4 @@ +import logging import os from datetime import datetime from typing import Any, Dict, List, Optional @@ -12,6 +13,8 @@ metric_data_points_to_arrow, ) +logger = logging.getLogger(__name__) + class CVec: """ @@ -144,9 +147,14 @@ def _make_request( data=data, allow_redirects=False, ) - except Exception as e: - print(f"Token refresh failed: {e}") - # If refresh fails, continue with the original error + except (requests.RequestException, ValueError, KeyError) as e: + logger.warning( + "Token refresh failed, continuing with original request: %s", + e, + exc_info=True, + ) + # If refresh fails, continue with the original error response + # which will be raised by raise_for_status() below pass if response.status_code in (301, 302, 303, 307, 308): diff --git a/tests/test_token_refresh.py b/tests/test_token_refresh.py index 48ec9ed..953d3f1 100644 --- a/tests/test_token_refresh.py +++ b/tests/test_token_refresh.py @@ -1,7 +1,11 @@ """Tests for token refresh functionality.""" + +import pytest +import requests +from typing import Any from unittest.mock import Mock, patch + from cvec import CVec -from typing import Any class TestTokenRefresh: @@ -40,16 +44,15 @@ def test_token_refresh_on_401( ] # Mock refresh method - refresh_called = [] + refresh_called: list[bool] = [] def mock_refresh() -> None: refresh_called.append(True) client._access_token = "new_token" - client._refresh_supabase_token = mock_refresh - - # Execute request - result = client.get_metrics() + with patch.object(client, "_refresh_supabase_token", side_effect=mock_refresh): + # Execute request + result = client.get_metrics() # Verify refresh was called assert len(refresh_called) == 1 @@ -92,16 +95,15 @@ def test_token_refresh_on_redirect_to_login( ] # Mock refresh method - refresh_called = [] + refresh_called: list[bool] = [] def mock_refresh() -> None: refresh_called.append(True) client._access_token = "new_token" - client._refresh_supabase_token = mock_refresh - - # Execute request - result = client.get_metrics() + with patch.object(client, "_refresh_supabase_token", side_effect=mock_refresh): + # Execute request + result = client.get_metrics() # Verify refresh was called assert len(refresh_called) == 1 @@ -140,14 +142,13 @@ def test_no_refresh_on_normal_redirect( mock_response_success, ] - refresh_called = [] + refresh_called: list[bool] = [] def mock_refresh() -> None: refresh_called.append(True) - client._refresh_supabase_token = mock_refresh - - result = client.get_metrics() + with patch.object(client, "_refresh_supabase_token", side_effect=mock_refresh): + result = client.get_metrics() assert len(refresh_called) == 0 assert client._access_token == "valid_token" @@ -185,16 +186,85 @@ def test_refresh_on_302_redirect_with_token_keyword( mock_response_success, ] - refresh_called = [] + refresh_called: list[bool] = [] def mock_refresh() -> None: refresh_called.append(True) client._access_token = "new_token" - client._refresh_supabase_token = mock_refresh - - result = client.get_metrics() + with patch.object(client, "_refresh_supabase_token", side_effect=mock_refresh): + result = client.get_metrics() assert len(refresh_called) == 1 assert client._access_token == "new_token" assert result == [] + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch("cvec.cvec.requests.request") + def test_token_refresh_handles_network_errors_gracefully( + self, + mock_request: Any, + mock_fetch_key: Any, + mock_login: Any, + ) -> None: + """Test that network errors during refresh don't crash, returns original error.""" + client = CVec( + host="https://test.example.com", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + client._access_token = "expired_token" + client._refresh_token = "valid_refresh_token" + + mock_response_401 = Mock() + mock_response_401.status_code = 401 + mock_response_401.raise_for_status.side_effect = requests.HTTPError( + "401 Client Error: Unauthorized" + ) + + mock_request.return_value = mock_response_401 + + def mock_refresh_with_error() -> None: + raise requests.ConnectionError("Network unreachable") + + with patch.object( + client, "_refresh_supabase_token", side_effect=mock_refresh_with_error + ): + with pytest.raises(requests.HTTPError, match="401"): + client.get_metrics() + + @patch.object(CVec, "_login_with_supabase", return_value=None) + @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") + @patch("cvec.cvec.requests.request") + def test_token_refresh_handles_missing_refresh_token( + self, + mock_request: Any, + mock_fetch_key: Any, + mock_login: Any, + ) -> None: + """Test that missing refresh token is handled gracefully.""" + client = CVec( + host="https://test.example.com", + api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", + ) + + client._access_token = "expired_token" + client._refresh_token = "valid_refresh_token" + + mock_response_401 = Mock() + mock_response_401.status_code = 401 + mock_response_401.raise_for_status.side_effect = requests.HTTPError( + "401 Client Error: Unauthorized" + ) + + mock_request.return_value = mock_response_401 + + def mock_refresh_with_error() -> None: + raise ValueError("No refresh token available") + + with patch.object( + client, "_refresh_supabase_token", side_effect=mock_refresh_with_error + ): + with pytest.raises(requests.HTTPError, match="401"): + client.get_metrics() From 4aaac21b0b813e0cbe11c21c28ee7577503a703f Mon Sep 17 00:00:00 2001 From: alok27a Date: Thu, 23 Oct 2025 14:01:13 -0400 Subject: [PATCH 4/4] Updated code based on comment --- src/cvec/cvec.py | 26 +------ tests/test_token_refresh.py | 146 ++---------------------------------- 2 files changed, 10 insertions(+), 162 deletions(-) diff --git a/src/cvec/cvec.py b/src/cvec/cvec.py index 5582de6..534e396 100644 --- a/src/cvec/cvec.py +++ b/src/cvec/cvec.py @@ -94,6 +94,7 @@ def _get_headers(self) -> Dict[str, str]: return { "Authorization": f"Bearer {self._access_token}", "Content-Type": "application/json", + "Accept": "application/json", } def _make_request( @@ -118,18 +119,9 @@ def _make_request( params=params, json=json, data=data, - allow_redirects=False, # Disable auto-redirect to catch auth redirects ) - needs_refresh = False - if response.status_code == 401: - needs_refresh = True - elif response.status_code in (301, 302, 303, 307, 308): - location = response.headers.get("Location", "") - if "login" in location.lower() or "token" in location.lower(): - needs_refresh = True - - if needs_refresh and self._access_token and self._refresh_token: + if response.status_code == 401 and self._access_token and self._refresh_token: try: self._refresh_supabase_token() # Update headers with new token @@ -145,7 +137,6 @@ def _make_request( params=params, json=json, data=data, - allow_redirects=False, ) except (requests.RequestException, ValueError, KeyError) as e: logger.warning( @@ -157,17 +148,6 @@ def _make_request( # which will be raised by raise_for_status() below pass - if response.status_code in (301, 302, 303, 307, 308): - response = requests.request( - method="GET" if response.status_code == 303 else method, - url=urljoin(url, response.headers.get("Location", "")), - headers=request_headers, - params=params if method == "GET" else None, - json=json if method != "GET" and response.status_code != 303 else None, - data=data if method != "GET" and response.status_code != 303 else None, - allow_redirects=True, - ) - response.raise_for_status() if ( @@ -438,6 +418,7 @@ def _login_with_supabase(self, email: str, password: str) -> None: headers = { "Content-Type": "application/json", + "Accept": "application/json", "apikey": self._publishable_key, } @@ -462,6 +443,7 @@ def _refresh_supabase_token(self) -> None: headers = { "Content-Type": "application/json", + "Accept": "application/json", "apikey": self._publishable_key, } diff --git a/tests/test_token_refresh.py b/tests/test_token_refresh.py index 953d3f1..0c9d780 100644 --- a/tests/test_token_refresh.py +++ b/tests/test_token_refresh.py @@ -59,146 +59,6 @@ def mock_refresh() -> None: assert client._access_token == "new_token" assert result == [] - @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") - @patch("cvec.cvec.requests.request") - def test_token_refresh_on_redirect_to_login( - self, - mock_request: Any, - mock_fetch_key: Any, - mock_login: Any, - ) -> None: - """Test that token refresh is triggered on redirect to login page.""" - client = CVec( - host="https://test.example.com", - api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", - ) - - client._access_token = "expired_token" - client._refresh_token = "valid_refresh_token" - - # Mock response sequence: 307 redirect to login - mock_response_redirect = Mock() - mock_response_redirect.status_code = 307 - mock_response_redirect.headers = { - "Location": "/login?error=Token%20has%20expired" - } - - mock_response_success = Mock() - mock_response_success.status_code = 200 - mock_response_success.headers = {"content-type": "application/json"} - mock_response_success.json.return_value = [] - - mock_request.side_effect = [ - mock_response_redirect, - mock_response_success, - ] - - # Mock refresh method - refresh_called: list[bool] = [] - - def mock_refresh() -> None: - refresh_called.append(True) - client._access_token = "new_token" - - with patch.object(client, "_refresh_supabase_token", side_effect=mock_refresh): - # Execute request - result = client.get_metrics() - - # Verify refresh was called - assert len(refresh_called) == 1 - assert client._access_token == "new_token" - assert result == [] - - @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") - @patch("cvec.cvec.requests.request") - def test_no_refresh_on_normal_redirect( - self, - mock_request: Any, - mock_fetch_key: Any, - mock_login: Any, - ) -> None: - """Test that token refresh is NOT triggered on normal redirects.""" - client = CVec( - host="https://test.example.com", - api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", - ) - - client._access_token = "valid_token" - client._refresh_token = "valid_refresh_token" - - mock_response_redirect = Mock() - mock_response_redirect.status_code = 302 - mock_response_redirect.headers = {"Location": "/api/v2/metrics"} - - mock_response_success = Mock() - mock_response_success.status_code = 200 - mock_response_success.headers = {"content-type": "application/json"} - mock_response_success.json.return_value = [] - - mock_request.side_effect = [ - mock_response_redirect, - mock_response_success, - ] - - refresh_called: list[bool] = [] - - def mock_refresh() -> None: - refresh_called.append(True) - - with patch.object(client, "_refresh_supabase_token", side_effect=mock_refresh): - result = client.get_metrics() - - assert len(refresh_called) == 0 - assert client._access_token == "valid_token" - assert result == [] - - @patch.object(CVec, "_login_with_supabase", return_value=None) - @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") - @patch("cvec.cvec.requests.request") - def test_refresh_on_302_redirect_with_token_keyword( - self, - mock_request: Any, - mock_fetch_key: Any, - mock_login: Any, - ) -> None: - """Test that token refresh works on 302 redirect with 'token' in URL.""" - client = CVec( - host="https://test.example.com", - api_key="cva_hHs0CbkKALxMnxUdI9hanF0TBPvvvr1HjG6O", - ) - - client._access_token = "expired_token" - client._refresh_token = "valid_refresh_token" - - mock_response_redirect = Mock() - mock_response_redirect.status_code = 302 - mock_response_redirect.headers = {"Location": "/auth?error=invalid_token"} - - mock_response_success = Mock() - mock_response_success.status_code = 200 - mock_response_success.headers = {"content-type": "application/json"} - mock_response_success.json.return_value = [] - - mock_request.side_effect = [ - mock_response_redirect, - mock_response_success, - ] - - refresh_called: list[bool] = [] - - def mock_refresh() -> None: - refresh_called.append(True) - client._access_token = "new_token" - - with patch.object(client, "_refresh_supabase_token", side_effect=mock_refresh): - result = client.get_metrics() - - assert len(refresh_called) == 1 - assert client._access_token == "new_token" - assert result == [] - @patch.object(CVec, "_login_with_supabase", return_value=None) @patch.object(CVec, "_fetch_publishable_key", return_value="test_publishable_key") @patch("cvec.cvec.requests.request") @@ -217,6 +77,7 @@ def test_token_refresh_handles_network_errors_gracefully( client._access_token = "expired_token" client._refresh_token = "valid_refresh_token" + # Mock response: 401 triggers refresh mock_response_401 = Mock() mock_response_401.status_code = 401 mock_response_401.raise_for_status.side_effect = requests.HTTPError( @@ -225,12 +86,14 @@ def test_token_refresh_handles_network_errors_gracefully( mock_request.return_value = mock_response_401 + # Mock refresh to raise network error def mock_refresh_with_error() -> None: raise requests.ConnectionError("Network unreachable") with patch.object( client, "_refresh_supabase_token", side_effect=mock_refresh_with_error ): + # Should not crash, should raise the original 401 error with pytest.raises(requests.HTTPError, match="401"): client.get_metrics() @@ -252,6 +115,7 @@ def test_token_refresh_handles_missing_refresh_token( client._access_token = "expired_token" client._refresh_token = "valid_refresh_token" + # Mock response: 401 triggers refresh mock_response_401 = Mock() mock_response_401.status_code = 401 mock_response_401.raise_for_status.side_effect = requests.HTTPError( @@ -260,11 +124,13 @@ def test_token_refresh_handles_missing_refresh_token( mock_request.return_value = mock_response_401 + # Mock refresh to raise ValueError (missing refresh token) def mock_refresh_with_error() -> None: raise ValueError("No refresh token available") with patch.object( client, "_refresh_supabase_token", side_effect=mock_refresh_with_error ): + # Should not crash, should raise the original 401 error with pytest.raises(requests.HTTPError, match="401"): client.get_metrics()