Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/fides/api/cryptography/schemas/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
JWE_PAYLOAD_SYSTEMS = "systems"
JWE_PAYLOAD_CONNECTIONS = "connections"
JWE_PAYLOAD_MONITORS = "monitors"
JWE_PAYLOAD_PASSWORD_RESET_AT = "password-reset-at"
14 changes: 14 additions & 0 deletions src/fides/api/models/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
JWE_PAYLOAD_CLIENT_ID,
JWE_PAYLOAD_CONNECTIONS,
JWE_PAYLOAD_MONITORS,
JWE_PAYLOAD_PASSWORD_RESET_AT,
JWE_PAYLOAD_ROLES,
JWE_PAYLOAD_SCOPES,
JWE_PAYLOAD_SYSTEMS,
Expand Down Expand Up @@ -148,8 +149,17 @@ def create_access_code_jwe(

Includes iat (issued-at) and exp (expires-at, Unix timestamp) for
server-side expiration enforcement and client-facing expiry info.

If the client is associated with a user who has a password_reset_at timestamp,
it is included in the payload for stateless token invalidation checks.
"""
now = datetime.now()

# Get password_reset_at if user is associated
password_reset_at = None
if self.user is not None and self.user.password_reset_at is not None:
password_reset_at = self.user.password_reset_at.isoformat()

payload = {
JWE_PAYLOAD_CLIENT_ID: self.id,
JWE_PAYLOAD_SCOPES: self.scopes,
Expand All @@ -160,6 +170,10 @@ def create_access_code_jwe(
JWE_PAYLOAD_CONNECTIONS: self.connections,
JWE_PAYLOAD_MONITORS: self.monitors,
}

if password_reset_at is not None:
payload[JWE_PAYLOAD_PASSWORD_RESET_AT] = password_reset_at

return generate_jwe(json.dumps(payload), encryption_key)

def credentials_valid(self, provided_secret: str, encoding: str = "UTF-8") -> bool:
Expand Down
49 changes: 45 additions & 4 deletions src/fides/api/oauth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
JWE_EXPIRES_AT,
JWE_ISSUED_AT,
JWE_PAYLOAD_CLIENT_ID,
JWE_PAYLOAD_PASSWORD_RESET_AT,
JWE_PAYLOAD_ROLES,
JWE_PAYLOAD_SCOPES,
)
Expand Down Expand Up @@ -243,6 +244,34 @@ def is_token_invalidated(issued_at: datetime, client: ClientDetail) -> bool:
return False


def is_token_invalidated_offline(
issued_at: datetime,
password_reset_at_str: Optional[str],
) -> bool:
"""
Check token invalidation using payload data only (no DB lookup).

This function enables stateless token validation by checking if a token
was issued before a password reset, using the password_reset_at value
embedded in the token payload itself rather than querying the database.

Args:
issued_at: When the token was issued
password_reset_at_str: ISO format string of password reset timestamp from token payload

Returns:
True if the token should be considered invalid (issued before password reset),
False otherwise (including when password_reset_at_str is None or invalid).
"""
if password_reset_at_str is None:
return False
try:
password_reset_at = datetime.fromisoformat(password_reset_at_str)
return issued_at < password_reset_at
except (TypeError, ValueError):
return False # Treat parse errors as non-invalidating


def _get_webhook_jwe_or_error(
security_scopes: SecurityScopes, authorization: str = Security(oauth2_scheme)
) -> WebhookJWE:
Expand Down Expand Up @@ -592,8 +621,14 @@ def extract_token_and_load_client(

# Invalidate tokens issued prior to the user's most recent password reset.
# This ensures any existing sessions are expired immediately after a password change.
if is_token_invalidated(issued_at_dt, client):
logger.debug("Auth token issued before latest password reset.")
# First try stateless check using payload data (for newer tokens with password-reset-at)
password_reset_at_str = token_data.get(JWE_PAYLOAD_PASSWORD_RESET_AT)
if is_token_invalidated_offline(issued_at_dt, password_reset_at_str):
logger.debug("Auth token issued before password reset (offline check).")
raise AuthorizationError(detail="Not Authorized for this action")
# Fall back to DB check for older tokens without password-reset-at in payload
if password_reset_at_str is None and is_token_invalidated(issued_at_dt, client):
logger.debug("Auth token issued before password reset (DB fallback).")
raise AuthorizationError(detail="Not Authorized for this action")

# Populate request-scoped context with the authenticated user identifier.
Expand Down Expand Up @@ -674,8 +709,14 @@ async def extract_token_and_load_client_async(

# Invalidate tokens issued prior to the user's most recent password reset.
# This ensures any existing sessions are expired immediately after a password change.
if is_token_invalidated(issued_at_dt, client):
logger.debug("Auth token issued before latest password reset.")
# First try stateless check using payload data (for newer tokens with password-reset-at)
password_reset_at_str = token_data.get(JWE_PAYLOAD_PASSWORD_RESET_AT)
if is_token_invalidated_offline(issued_at_dt, password_reset_at_str):
logger.debug("Auth token issued before password reset (offline check).")
raise AuthorizationError(detail="Not Authorized for this action")
# Fall back to DB check for older tokens without password-reset-at in payload
if password_reset_at_str is None and is_token_invalidated(issued_at_dt, client):
logger.debug("Auth token issued before password reset (DB fallback).")
raise AuthorizationError(detail="Not Authorized for this action")

# Populate request-scoped context with the authenticated user identifier.
Expand Down
44 changes: 43 additions & 1 deletion tests/ops/util/test_oauth_util.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from datetime import datetime

import pytest

from fides.api.common_exceptions import AuthorizationError
from fides.api.oauth.roles import OWNER, VIEWER
from fides.api.oauth.utils import get_root_client, roles_have_scopes
from fides.api.oauth.utils import (
get_root_client,
is_token_invalidated_offline,
roles_have_scopes,
)
from fides.common.scope_registry import POLICY_READ, USER_CREATE, USER_READ


Expand Down Expand Up @@ -54,3 +60,39 @@ class TestRolesHaveScopes:
def test_roles_have_scopes(self, roles, required_scopes, expected):
"""Test that roles_have_scopes correctly checks if roles have required scopes."""
assert roles_have_scopes(roles, required_scopes) == expected


class TestIsTokenInvalidatedOffline:
"""Tests for stateless token invalidation checks using payload data only."""

def test_is_token_invalidated_offline_no_reset(self):
"""Token valid when no password_reset_at in payload."""
assert is_token_invalidated_offline(datetime.now(), None) is False

def test_is_token_invalidated_offline_issued_after(self):
"""Token valid when issued after password reset."""
reset_time = datetime(2024, 1, 1, 12, 0, 0)
issued_time = datetime(2024, 1, 2, 12, 0, 0) # After reset
assert (
is_token_invalidated_offline(issued_time, reset_time.isoformat()) is False
)

def test_is_token_invalidated_offline_issued_before(self):
"""Token invalid when issued before password reset."""
reset_time = datetime(2024, 1, 2, 12, 0, 0)
issued_time = datetime(2024, 1, 1, 12, 0, 0) # Before reset
assert is_token_invalidated_offline(issued_time, reset_time.isoformat()) is True

def test_is_token_invalidated_offline_same_time(self):
"""Token valid when issued at exact same time as password reset."""
same_time = datetime(2024, 1, 1, 12, 0, 0)
# Not strictly before, so should be valid
assert is_token_invalidated_offline(same_time, same_time.isoformat()) is False

def test_is_token_invalidated_offline_invalid_format(self):
"""Graceful handling of invalid ISO format."""
assert is_token_invalidated_offline(datetime.now(), "not-a-date") is False

def test_is_token_invalidated_offline_empty_string(self):
"""Empty string treated as invalid format."""
assert is_token_invalidated_offline(datetime.now(), "") is False
Loading