Skip to content
This repository was archived by the owner on Jan 14, 2026. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 67 additions & 12 deletions satellitevu/apis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,24 @@ class AbstractApi(ABC):
base_url: str
api_path: str
scopes = []
max_auth_retries: int

def __init__(self, client: AbstractClient, base_url: str):
def __init__(
self, client: AbstractClient, base_url: str, max_auth_retries: int = 1
):
"""
Initialize API client.

Args:
client: HTTP client implementation
base_url: Base URL for API endpoints
max_auth_retries: Maximum number of times to retry a request after 401 error.
When a 401 is received, the cached token is invalidated and
the request is retried with a fresh token. Default is 1.
"""
self.client = client
self.base_url = base_url
self.max_auth_retries = max_auth_retries

def url(self, path: str) -> str:
api_base_url = urljoin(self.base_url, self.api_path.lstrip("/"))
Expand All @@ -23,21 +37,62 @@ def url(self, path: str) -> str:
return urljoin(api_base_url, path.lstrip("/"))

def make_request(self, *args, **kwargs):
"""
Make an HTTP request with automatic 401 retry.

When a 401 Unauthorized response is received, this method will:
1. Invalidate the cached authentication token
2. Retry the request (up to max_auth_retries times)
3. The retry will automatically fetch a fresh token via refresh or re-auth

Args:
*args: Positional arguments passed to client.request()
**kwargs: Keyword arguments passed to client.request()

Returns:
Response object from the HTTP client

Raises:
Api401Error: If request is unauthorized after all retries
Api403Error: If request is forbidden (no retry)
"""
if "scopes" not in kwargs:
kwargs["scopes"] = self.scopes
response = self.client.request(*args, **kwargs)

if response.status == 401:
raise Api401Error("Unauthorized to make this request.")
elif response.status == 403:
raise Api403Error(
(
"Not permitted to perform this action. "
"Please contact Satellite Vu for assistance."

for attempt in range(self.max_auth_retries + 1):
response = self.client.request(*args, **kwargs)

if response.status == 401:
if attempt < self.max_auth_retries:
self._invalidate_auth_token(kwargs.get("scopes"))
continue
else:
raise Api401Error("Unauthorized to make this request after retry.")
elif response.status == 403:
raise Api403Error(
(
"Not permitted to perform this action. "
"Please contact Satellite Vu for assistance."
)
)
)

return response
return response

def _invalidate_auth_token(self, scopes):
"""
Invalidate the cached authentication token.

This forces the next request to fetch a fresh token.

Args:
scopes: Scopes associated with the token to invalidate
"""
auth = next(
(v for k, v in self.client._auth.items() if self.base_url.startswith(k)),
None,
)
if auth:
auth.invalidate_token(scopes)

def deprecation_warning(self, new_cls):
simplefilter("always", DeprecationWarning)
Expand Down
179 changes: 167 additions & 12 deletions satellitevu/auth/auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from base64 import b64decode
from datetime import datetime
from datetime import datetime, timedelta
from hashlib import sha1
from json import loads
from logging import getLogger
Expand All @@ -15,14 +15,27 @@
logger = getLogger(__file__)


def is_expired_token(token: str) -> bool:
def is_expired_token(token: str, grace_period_seconds: int = 0) -> bool:
"""
Check if a JWT token is expired.

Args:
token: JWT token string
grace_period_seconds: Number of seconds before actual expiration to consider token expired.
This provides a safety buffer for network latency and clock skew.
Default is 0 (no grace period).

Returns:
True if token is expired (or within grace period of expiration), False otherwise
"""
json = b64decode(token.split(".")[1] + "==")
claims = loads(json)
if not claims or "exp" not in claims:
return False
exp = float(claims["exp"])
exp_dt = datetime.fromtimestamp(exp)
return exp_dt <= datetime.now()
grace_period = timedelta(seconds=grace_period_seconds)
return exp_dt <= datetime.now() + grace_period


class Auth:
Expand All @@ -34,6 +47,8 @@ class Auth:

auth_url: str
client: AbstractClient
grace_period_seconds: int
enable_refresh: bool

def __init__(
self,
Expand All @@ -44,34 +59,129 @@ def __init__(
cache: Optional[AbstractCache] = None,
auth_url: Optional[str] = None,
client: Optional[AbstractClient] = None,
grace_period_seconds: int = 60,
enable_refresh: bool = False,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to migrate all clients first to support offline_access scope

):
"""
Initialize Auth handler for OAuth2 client credentials flow.

Args:
client_id: OAuth2 client ID
client_secret: OAuth2 client secret
audience: OAuth2 audience (defaults to API gateway URL)
cache: Cache implementation for storing tokens (defaults to AppDirCache)
auth_url: Authentication server base URL
client: HTTP client implementation (defaults to UrllibClient)
grace_period_seconds: Number of seconds before token expiration to consider it expired.
This provides a safety buffer for network latency and clock skew.
Default is 60 seconds.
enable_refresh: Enable refresh token flow. When enabled, the client will request
refresh tokens and use them to obtain new access tokens instead of
performing full re-authentication. Default is True.
"""
self.client_id = client_id
self.client_secret = client_secret
self.audience = audience or AUDIENCE

self.cache = cache or AppDirCache()
self.auth_url = auth_url or AUTH_URL
self.client = client or UrllibClient()
self.grace_period_seconds = grace_period_seconds
self.enable_refresh = enable_refresh

def token(self, scopes: Optional[List] = None) -> str:
def invalidate_token(self, scopes: Optional[List] = None):
"""
Invalidate cached token for given scopes.

This forces the next token() call to fetch a fresh token, either via
refresh token or full authentication.

Args:
scopes: Optional list of scopes whose cached token should be invalidated
"""
if not scopes:
scopes = []
cache_key = sha1(self.client_id.encode("utf-8")) # nosec B324
cache_key.update("".join(scopes).encode("utf-8"))
cache_key_digest = cache_key.hexdigest()

token = self.cache.load(cache_key.hexdigest())
# Save None to invalidate the access_token cache entry
self.cache.save(cache_key_digest, "", None)
logger.info("Invalidated cached token")

if not token or is_expired_token(token):
token = self._auth(scopes)
self.cache.save(cache_key.hexdigest(), token)
def token(self, scopes: Optional[List] = None) -> str:
"""
Get a valid access token, using cache or refresh token if possible.

return token
Args:
scopes: Optional list of OAuth2 scopes to request

def _auth(self, scopes: Optional[List] = None) -> str:
Returns:
Valid access token

Raises:
AuthError: If authentication or token refresh fails
"""
if not scopes:
scopes = []
cache_key = sha1(self.client_id.encode("utf-8")) # nosec B324
cache_key.update("".join(scopes).encode("utf-8"))
cache_key_digest = cache_key.hexdigest()

# Try to load both access and refresh tokens from cache
access_token, refresh_token = self.cache.load_tokens(cache_key_digest)

# Check if access token is valid
if access_token and not is_expired_token(
access_token, self.grace_period_seconds
):
return access_token

# Access token is expired or missing, try to refresh if we have a refresh token
if self.enable_refresh and refresh_token:
# Check if refresh token itself is expired (no grace period for refresh tokens)
if not is_expired_token(refresh_token, 0):
try:
logger.info("Access token expired, attempting refresh")
new_access_token = self._refresh(refresh_token, scopes)
# Keep the same refresh token (it doesn't change)
self.cache.save(cache_key_digest, new_access_token, refresh_token)
return new_access_token
except AuthError as e:
# Refresh failed, fall back to full authentication
logger.warning(
f"Token refresh failed: {e}. Falling back to full authentication"
)

# No refresh token, refresh failed, or refresh disabled - perform full authentication
access_token, refresh_token = self._auth(scopes)
self.cache.save(cache_key_digest, access_token, refresh_token)
return access_token

def _auth(self, scopes: Optional[List] = None) -> tuple[str, Optional[str]]:
"""
Perform OAuth2 client credentials authentication.

Args:
scopes: Optional list of scopes to request

Returns:
Tuple of (access_token, refresh_token). refresh_token may be None if not supported
by the server or if refresh tokens are disabled.

Raises:
AuthError: If authentication request fails
"""
if not scopes:
scopes = []
logger.info("Performing client_credential authentication")
token_url = urljoin(self.auth_url, "oauth/token")

# Add offline_access scope if refresh tokens are enabled
request_scopes = scopes.copy()
if self.enable_refresh and "offline_access" not in request_scopes:
request_scopes.append("offline_access")

response = self.client.post(
token_url,
headers={"content-type": "application/x-www-form-urlencoded"},
Expand All @@ -80,7 +190,7 @@ def _auth(self, scopes: Optional[List] = None) -> str:
"client_id": self.client_id,
"client_secret": self.client_secret,
"audience": self.audience,
"scope": " ".join(scopes),
"scope": " ".join(request_scopes),
},
)

Expand All @@ -91,8 +201,53 @@ def _auth(self, scopes: Optional[List] = None) -> str:
)
try:
payload = response.json()
return payload["access_token"]
access_token = payload["access_token"]
refresh_token = payload.get("refresh_token")
return (access_token, refresh_token)
except Exception:
raise AuthError(
"Unexpected response body for client_credential flow: " + response.text
)

def _refresh(self, refresh_token: str, scopes: Optional[List] = None) -> str:
"""
Refresh access token using a refresh token.

Args:
refresh_token: Valid refresh token
scopes: Optional list of scopes to request

Returns:
New access token

Raises:
AuthError: If refresh request fails
"""
if not scopes:
scopes = []
logger.info("Refreshing access token using refresh token")
token_url = urljoin(self.auth_url, "oauth/token")
response = self.client.post(
token_url,
headers={"content-type": "application/x-www-form-urlencoded"},
data={
"grant_type": "refresh_token",
"client_id": self.client_id,
"client_secret": self.client_secret,
"refresh_token": refresh_token,
"scope": " ".join(scopes),
},
)

if response.status != 200:
raise AuthError(
"Unexpected error code for refresh_token flow: "
f"{response.status} - {response.text}"
)
try:
payload = response.json()
return payload["access_token"]
except Exception:
raise AuthError(
"Unexpected response body for refresh_token flow: " + response.text
)
Loading
Loading