From 7b4d0edfffc6da374d3c431d0b0a9a06c7fe444e Mon Sep 17 00:00:00 2001 From: Christian Wygoda Date: Tue, 21 Oct 2025 10:35:02 +0200 Subject: [PATCH 1/2] fix: add grace period when checking access_token expiration --- satellitevu/auth/auth.py | 38 ++++++++++++++++++++++++++++++++++---- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/satellitevu/auth/auth.py b/satellitevu/auth/auth.py index 347fbaa..b7653af 100644 --- a/satellitevu/auth/auth.py +++ b/satellitevu/auth/auth.py @@ -1,5 +1,5 @@ from base64 import b64decode -from datetime import datetime +from datetime import datetime, timedelta from hashlib import sha1 from json import loads from logging import getLogger @@ -15,14 +15,27 @@ logger = getLogger(__file__) -def is_expired_token(token: str) -> bool: +def is_expired_token(token: str, grace_period_seconds: int = 0) -> bool: + """ + Check if a JWT token is expired. + + Args: + token: JWT token string + grace_period_seconds: Number of seconds before actual expiration to consider token expired. + This provides a safety buffer for network latency and clock skew. + Default is 0 (no grace period). + + Returns: + True if token is expired (or within grace period of expiration), False otherwise + """ json = b64decode(token.split(".")[1] + "==") claims = loads(json) if not claims or "exp" not in claims: return False exp = float(claims["exp"]) exp_dt = datetime.fromtimestamp(exp) - return exp_dt <= datetime.now() + grace_period = timedelta(seconds=grace_period_seconds) + return exp_dt <= datetime.now() + grace_period class Auth: @@ -34,6 +47,7 @@ class Auth: auth_url: str client: AbstractClient + grace_period_seconds: int def __init__( self, @@ -44,7 +58,22 @@ def __init__( cache: Optional[AbstractCache] = None, auth_url: Optional[str] = None, client: Optional[AbstractClient] = None, + grace_period_seconds: int = 60, ): + """ + Initialize Auth handler for OAuth2 client credentials flow. + + Args: + client_id: OAuth2 client ID + client_secret: OAuth2 client secret + audience: OAuth2 audience (defaults to API gateway URL) + cache: Cache implementation for storing tokens (defaults to AppDirCache) + auth_url: Authentication server base URL + client: HTTP client implementation (defaults to UrllibClient) + grace_period_seconds: Number of seconds before token expiration to consider it expired. + This provides a safety buffer for network latency and clock skew. + Default is 60 seconds. + """ self.client_id = client_id self.client_secret = client_secret self.audience = audience or AUDIENCE @@ -52,6 +81,7 @@ def __init__( self.cache = cache or AppDirCache() self.auth_url = auth_url or AUTH_URL self.client = client or UrllibClient() + self.grace_period_seconds = grace_period_seconds def token(self, scopes: Optional[List] = None) -> str: if not scopes: @@ -61,7 +91,7 @@ def token(self, scopes: Optional[List] = None) -> str: token = self.cache.load(cache_key.hexdigest()) - if not token or is_expired_token(token): + if not token or is_expired_token(token, self.grace_period_seconds): token = self._auth(scopes) self.cache.save(cache_key.hexdigest(), token) From be0f9ecea8d1d12ec460ddbdbefa4bdde10e966b Mon Sep 17 00:00:00 2001 From: Christian Wygoda Date: Tue, 21 Oct 2025 15:08:02 +0200 Subject: [PATCH 2/2] feat: support short lived access tokens, with opt-in refresh token flow --- satellitevu/apis/base.py | 79 +++++++++++++++++---- satellitevu/auth/auth.py | 143 +++++++++++++++++++++++++++++++++++--- satellitevu/auth/cache.py | 83 ++++++++++++++++++++-- 3 files changed, 278 insertions(+), 27 deletions(-) diff --git a/satellitevu/apis/base.py b/satellitevu/apis/base.py index 6228a5c..13ae733 100644 --- a/satellitevu/apis/base.py +++ b/satellitevu/apis/base.py @@ -11,10 +11,24 @@ class AbstractApi(ABC): base_url: str api_path: str scopes = [] + max_auth_retries: int - def __init__(self, client: AbstractClient, base_url: str): + def __init__( + self, client: AbstractClient, base_url: str, max_auth_retries: int = 1 + ): + """ + Initialize API client. + + Args: + client: HTTP client implementation + base_url: Base URL for API endpoints + max_auth_retries: Maximum number of times to retry a request after 401 error. + When a 401 is received, the cached token is invalidated and + the request is retried with a fresh token. Default is 1. + """ self.client = client self.base_url = base_url + self.max_auth_retries = max_auth_retries def url(self, path: str) -> str: api_base_url = urljoin(self.base_url, self.api_path.lstrip("/")) @@ -23,21 +37,62 @@ def url(self, path: str) -> str: return urljoin(api_base_url, path.lstrip("/")) def make_request(self, *args, **kwargs): + """ + Make an HTTP request with automatic 401 retry. + + When a 401 Unauthorized response is received, this method will: + 1. Invalidate the cached authentication token + 2. Retry the request (up to max_auth_retries times) + 3. The retry will automatically fetch a fresh token via refresh or re-auth + + Args: + *args: Positional arguments passed to client.request() + **kwargs: Keyword arguments passed to client.request() + + Returns: + Response object from the HTTP client + + Raises: + Api401Error: If request is unauthorized after all retries + Api403Error: If request is forbidden (no retry) + """ if "scopes" not in kwargs: kwargs["scopes"] = self.scopes - response = self.client.request(*args, **kwargs) - - if response.status == 401: - raise Api401Error("Unauthorized to make this request.") - elif response.status == 403: - raise Api403Error( - ( - "Not permitted to perform this action. " - "Please contact Satellite Vu for assistance." + + for attempt in range(self.max_auth_retries + 1): + response = self.client.request(*args, **kwargs) + + if response.status == 401: + if attempt < self.max_auth_retries: + self._invalidate_auth_token(kwargs.get("scopes")) + continue + else: + raise Api401Error("Unauthorized to make this request after retry.") + elif response.status == 403: + raise Api403Error( + ( + "Not permitted to perform this action. " + "Please contact Satellite Vu for assistance." + ) ) - ) - return response + return response + + def _invalidate_auth_token(self, scopes): + """ + Invalidate the cached authentication token. + + This forces the next request to fetch a fresh token. + + Args: + scopes: Scopes associated with the token to invalidate + """ + auth = next( + (v for k, v in self.client._auth.items() if self.base_url.startswith(k)), + None, + ) + if auth: + auth.invalidate_token(scopes) def deprecation_warning(self, new_cls): simplefilter("always", DeprecationWarning) diff --git a/satellitevu/auth/auth.py b/satellitevu/auth/auth.py index b7653af..fa40081 100644 --- a/satellitevu/auth/auth.py +++ b/satellitevu/auth/auth.py @@ -48,6 +48,7 @@ class Auth: auth_url: str client: AbstractClient grace_period_seconds: int + enable_refresh: bool def __init__( self, @@ -59,6 +60,7 @@ def __init__( auth_url: Optional[str] = None, client: Optional[AbstractClient] = None, grace_period_seconds: int = 60, + enable_refresh: bool = False, ): """ Initialize Auth handler for OAuth2 client credentials flow. @@ -73,6 +75,9 @@ def __init__( grace_period_seconds: Number of seconds before token expiration to consider it expired. This provides a safety buffer for network latency and clock skew. Default is 60 seconds. + enable_refresh: Enable refresh token flow. When enabled, the client will request + refresh tokens and use them to obtain new access tokens instead of + performing full re-authentication. Default is True. """ self.client_id = client_id self.client_secret = client_secret @@ -82,26 +87,101 @@ def __init__( self.auth_url = auth_url or AUTH_URL self.client = client or UrllibClient() self.grace_period_seconds = grace_period_seconds + self.enable_refresh = enable_refresh - def token(self, scopes: Optional[List] = None) -> str: + def invalidate_token(self, scopes: Optional[List] = None): + """ + Invalidate cached token for given scopes. + + This forces the next token() call to fetch a fresh token, either via + refresh token or full authentication. + + Args: + scopes: Optional list of scopes whose cached token should be invalidated + """ if not scopes: scopes = [] cache_key = sha1(self.client_id.encode("utf-8")) # nosec B324 cache_key.update("".join(scopes).encode("utf-8")) + cache_key_digest = cache_key.hexdigest() + + # Save None to invalidate the access_token cache entry + self.cache.save(cache_key_digest, "", None) + logger.info("Invalidated cached token") - token = self.cache.load(cache_key.hexdigest()) + def token(self, scopes: Optional[List] = None) -> str: + """ + Get a valid access token, using cache or refresh token if possible. + + Args: + scopes: Optional list of OAuth2 scopes to request + + Returns: + Valid access token + + Raises: + AuthError: If authentication or token refresh fails + """ + if not scopes: + scopes = [] + cache_key = sha1(self.client_id.encode("utf-8")) # nosec B324 + cache_key.update("".join(scopes).encode("utf-8")) + cache_key_digest = cache_key.hexdigest() + + # Try to load both access and refresh tokens from cache + access_token, refresh_token = self.cache.load_tokens(cache_key_digest) + + # Check if access token is valid + if access_token and not is_expired_token( + access_token, self.grace_period_seconds + ): + return access_token + + # Access token is expired or missing, try to refresh if we have a refresh token + if self.enable_refresh and refresh_token: + # Check if refresh token itself is expired (no grace period for refresh tokens) + if not is_expired_token(refresh_token, 0): + try: + logger.info("Access token expired, attempting refresh") + new_access_token = self._refresh(refresh_token, scopes) + # Keep the same refresh token (it doesn't change) + self.cache.save(cache_key_digest, new_access_token, refresh_token) + return new_access_token + except AuthError as e: + # Refresh failed, fall back to full authentication + logger.warning( + f"Token refresh failed: {e}. Falling back to full authentication" + ) + + # No refresh token, refresh failed, or refresh disabled - perform full authentication + access_token, refresh_token = self._auth(scopes) + self.cache.save(cache_key_digest, access_token, refresh_token) + return access_token + + def _auth(self, scopes: Optional[List] = None) -> tuple[str, Optional[str]]: + """ + Perform OAuth2 client credentials authentication. - if not token or is_expired_token(token, self.grace_period_seconds): - token = self._auth(scopes) - self.cache.save(cache_key.hexdigest(), token) + Args: + scopes: Optional list of scopes to request - return token + Returns: + Tuple of (access_token, refresh_token). refresh_token may be None if not supported + by the server or if refresh tokens are disabled. - def _auth(self, scopes: Optional[List] = None) -> str: + Raises: + AuthError: If authentication request fails + """ if not scopes: scopes = [] logger.info("Performing client_credential authentication") token_url = urljoin(self.auth_url, "oauth/token") + + # Add offline_access scope if refresh tokens are enabled + request_scopes = scopes.copy() + if self.enable_refresh and "offline_access" not in request_scopes: + request_scopes.append("offline_access") + response = self.client.post( token_url, headers={"content-type": "application/x-www-form-urlencoded"}, @@ -110,7 +190,7 @@ def _auth(self, scopes: Optional[List] = None) -> str: "client_id": self.client_id, "client_secret": self.client_secret, "audience": self.audience, - "scope": " ".join(scopes), + "scope": " ".join(request_scopes), }, ) @@ -121,8 +201,53 @@ def _auth(self, scopes: Optional[List] = None) -> str: ) try: payload = response.json() - return payload["access_token"] + access_token = payload["access_token"] + refresh_token = payload.get("refresh_token") + return (access_token, refresh_token) except Exception: raise AuthError( "Unexpected response body for client_credential flow: " + response.text ) + + def _refresh(self, refresh_token: str, scopes: Optional[List] = None) -> str: + """ + Refresh access token using a refresh token. + + Args: + refresh_token: Valid refresh token + scopes: Optional list of scopes to request + + Returns: + New access token + + Raises: + AuthError: If refresh request fails + """ + if not scopes: + scopes = [] + logger.info("Refreshing access token using refresh token") + token_url = urljoin(self.auth_url, "oauth/token") + response = self.client.post( + token_url, + headers={"content-type": "application/x-www-form-urlencoded"}, + data={ + "grant_type": "refresh_token", + "client_id": self.client_id, + "client_secret": self.client_secret, + "refresh_token": refresh_token, + "scope": " ".join(scopes), + }, + ) + + if response.status != 200: + raise AuthError( + "Unexpected error code for refresh_token flow: " + f"{response.status} - {response.text}" + ) + try: + payload = response.json() + return payload["access_token"] + except Exception: + raise AuthError( + "Unexpected response body for refresh_token flow: " + response.text + ) diff --git a/satellitevu/auth/cache.py b/satellitevu/auth/cache.py index b08be6b..fc05bc2 100644 --- a/satellitevu/auth/cache.py +++ b/satellitevu/auth/cache.py @@ -11,26 +11,81 @@ class AbstractCache(ABC): """ - Abstract cache interface, implemented by actual cache classes + Abstract cache interface, implemented by actual cache classes. + + Supports storing both access tokens and optional refresh tokens. """ @abstractmethod - def save(self, client_id: str, value: str): + def save(self, client_id: str, value: str, refresh_token: Optional[str] = None): + """ + Save access token and optional refresh token to cache. + + Args: + client_id: Cache key (typically hash of client_id + scopes) + value: Access token to cache + refresh_token: Optional refresh token to cache alongside access token + """ pass @abstractmethod def load(self, client_id: str) -> Optional[str]: + """ + Load access token from cache. + + Args: + client_id: Cache key (typically hash of client_id + scopes) + + Returns: + Access token if found, None otherwise + + Note: This method is kept for backward compatibility. Use load_tokens() + to retrieve both access and refresh tokens. + """ pass + def load_tokens(self, client_id: str) -> tuple[Optional[str], Optional[str]]: + """ + Load both access token and refresh token from cache. + + Args: + client_id: Cache key (typically hash of client_id + scopes) + + Returns: + Tuple of (access_token, refresh_token). Either or both may be None. + """ + # Default implementation for backward compatibility + # Subclasses should override if they support refresh tokens + access_token = self.load(client_id) + return (access_token, None) + class MemoryCache(AbstractCache): _items = {} - def save(self, client_id: str, value: str): - self._items[client_id] = value + def save(self, client_id: str, value: str, refresh_token: Optional[str] = None): + self._items[client_id] = { + "access_token": value, + "refresh_token": refresh_token, + } def load(self, client_id: str) -> Optional[str]: - return self._items.get(client_id) + item = self._items.get(client_id) + if item is None: + return None + # Handle both old format (string) and new format (dict) + if isinstance(item, str): + return item + return item.get("access_token") + + def load_tokens(self, client_id: str) -> tuple[Optional[str], Optional[str]]: + item = self._items.get(client_id) + if item is None: + return (None, None) + # Handle both old format (string) and new format (dict) + if isinstance(item, str): + return (item, None) + return (item.get("access_token"), item.get("refresh_token")) class AppDirCache(AbstractCache): @@ -48,7 +103,7 @@ def __init__(self, cache_dir: Optional[str] = None): if not os.path.exists(self.cache_dir): os.makedirs(self.cache_dir) - def save(self, client_id: str, value: str): + def save(self, client_id: str, value: str, refresh_token: Optional[str] = None): parser = ConfigParser() parser.read(self.cache_file) @@ -57,6 +112,11 @@ def save(self, client_id: str, value: str): except DuplicateSectionError: pass parser[client_id]["access_token"] = value + if refresh_token: + parser[client_id]["refresh_token"] = refresh_token + elif "refresh_token" in parser[client_id]: + # Remove refresh_token if None is explicitly passed + del parser[client_id]["refresh_token"] with NamedTemporaryFile("w", dir=str(self.cache_dir), delete=False) as handle: parser.write(handle) @@ -70,3 +130,14 @@ def load(self, client_id: str) -> Optional[str]: return parser[client_id]["access_token"] except (FileNotFoundError, KeyError): return None + + def load_tokens(self, client_id: str) -> tuple[Optional[str], Optional[str]]: + try: + parser = ConfigParser() + parser.read(self.cache_file) + + access_token = parser[client_id].get("access_token") + refresh_token = parser[client_id].get("refresh_token") + return (access_token, refresh_token) + except (FileNotFoundError, KeyError): + return (None, None)