diff --git a/changelog/7708-adopt-structured-cache-for-dsrs.yaml b/changelog/7708-adopt-structured-cache-for-dsrs.yaml new file mode 100644 index 00000000000..b1dad88aa26 --- /dev/null +++ b/changelog/7708-adopt-structured-cache-for-dsrs.yaml @@ -0,0 +1,4 @@ +type: Changed +description: Migrated DSR workflows to use structured caching mechanism with secondary index, ensuring backward compatibility with legacy cache keys for in-flight requests during deployment. +pr: 7708 +labels: [] diff --git a/src/fides/api/models/privacy_request/consent.py b/src/fides/api/models/privacy_request/consent.py index 3bec2a35a28..95c3db7becd 100644 --- a/src/fides/api/models/privacy_request/consent.py +++ b/src/fides/api/models/privacy_request/consent.py @@ -20,7 +20,6 @@ CustomPrivacyRequestField as CustomPrivacyRequestFieldSchema, ) from fides.api.schemas.redis_cache import IdentityBase -from fides.api.util.cache import FidesopsRedis, get_cache from fides.api.util.identity_verification import IdentityVerificationMixin from fides.config import CONFIG @@ -95,12 +94,6 @@ class ConsentRequest(IdentityVerificationMixin, Base): privacy_request_id = Column(String, ForeignKey("privacyrequest.id"), nullable=True) privacy_request = relationship("PrivacyRequest") - def get_cached_identity_data(self) -> Dict[str, Any]: - """Retrieves any identity data pertaining to this request from the cache.""" - cache: FidesopsRedis = get_cache() - keys = cache.get_keys_by_prefix(f"id-{self.id}-identity-") - return {key.split("-")[-1]: cache.get(key) for key in keys} - def verify_identity( self, db: Session, diff --git a/src/fides/api/models/privacy_request/privacy_request.py b/src/fides/api/models/privacy_request/privacy_request.py index a52a4b7bf97..1c25dd969cf 100644 --- a/src/fides/api/models/privacy_request/privacy_request.py +++ b/src/fides/api/models/privacy_request/privacy_request.py @@ -106,13 +106,8 @@ from fides.api.tasks import celery_app from fides.api.util.cache import ( FidesopsRedis, - get_all_cache_keys_for_privacy_request, - get_async_task_tracking_cache_key, get_cache, - get_custom_privacy_request_field_cache_key, - get_drp_request_body_cache_key, - get_encryption_cache_key, - get_identity_cache_key, + get_dsr_cache_store, ) from fides.api.util.collection_util import Row from fides.api.util.constants import API_DATE_FORMAT @@ -471,10 +466,8 @@ def clear_cached_values(self) -> None: Clears all cached values associated with this privacy request from Redis. """ logger.info(f"Clearing cached values for privacy request {self.id}") - cache: FidesopsRedis = get_cache() - all_keys = get_all_cache_keys_for_privacy_request(privacy_request_id=self.id) - for key in all_keys: - cache.delete(key) + store = get_dsr_cache_store(self.id) + store.clear() def delete(self, db: Session) -> None: """ @@ -506,19 +499,22 @@ def cache_identity( self, identity: Union[Identity, Dict[str, LabeledIdentity]] ) -> None: """Sets the identity's values at their specific locations in the Fides app cache""" - cache: FidesopsRedis = get_cache() - if isinstance(identity, dict): identity = Identity(**identity) identity_dict: Dict[str, Any] = identity.labeled_dict() - for key, value in identity_dict.items(): - if value is not None: - cache.set_with_autoexpire( - get_identity_cache_key(self.id, key), - FidesopsRedis.encode_obj(value), - ) + store = get_dsr_cache_store(self.id) + # Encode values for Redis storage + encoded_dict = { + key: FidesopsRedis.encode_obj(value) + for key, value in identity_dict.items() + if value is not None + } + store.cache_identity_data( + encoded_dict, + expire_seconds=CONFIG.redis.default_ttl_seconds, + ) def cache_custom_privacy_request_fields( self, @@ -534,13 +530,17 @@ def cache_custom_privacy_request_fields( return if CONFIG.execution.allow_custom_privacy_request_fields_in_request_execution: - cache: FidesopsRedis = get_cache() - for key, item in custom_privacy_request_fields.items(): - if item is not None: - cache.set_with_autoexpire( - get_custom_privacy_request_field_cache_key(self.id, key), - json.dumps(item.value, cls=CustomJSONEncoder), - ) + store = get_dsr_cache_store(self.id) + # Encode values for Redis storage + encoded_fields = { + key: json.dumps(item.value, cls=CustomJSONEncoder) + for key, item in custom_privacy_request_fields.items() + if item is not None + } + store.cache_custom_fields( + encoded_fields, + expire_seconds=CONFIG.redis.default_ttl_seconds, + ) else: logger.info( "Custom fields from privacy request {}, but config setting 'CONFIG.execution.allow_custom_privacy_request_fields_in_request_execution' is set to false and prevents their usage.", @@ -681,14 +681,20 @@ def verify_identity(self, db: Session, provided_code: str) -> "PrivacyRequest": def get_cached_encryption_key(self) -> Optional[str]: """Gets the cached encryption key for this privacy request.""" - cache: FidesopsRedis = get_cache() - encryption_key = cache.get(get_encryption_cache_key(self.id, "key")) - return encryption_key + store = get_dsr_cache_store(self.id) + raw = store.get_encryption("key") + if raw is None: + return None + if isinstance(raw, bytes): + return raw.decode(CONFIG.security.encoding) + return str(raw) def get_cached_task_id(self) -> Optional[str]: """Gets the cached task ID for this privacy request.""" - cache: FidesopsRedis = get_cache() - task_id = cache.get(get_async_task_tracking_cache_key(self.id)) + store = get_dsr_cache_store(self.id) + task_id = store.get_async_execution() + if isinstance(task_id, bytes): + return task_id.decode(CONFIG.security.encoding) return task_id def get_async_execution_task(self) -> Optional[AsyncResult]: @@ -698,32 +704,35 @@ def get_async_execution_task(self) -> Optional[AsyncResult]: return res def cache_drp_request_body(self, drp_request_body: DrpPrivacyRequestCreate) -> None: - """Sets the identity's values at their specific locations in the Fides app cache""" - cache: FidesopsRedis = get_cache() + """Sets the DRP request body values at their specific locations in the Fides app cache""" drp_request_body_dict: Dict[str, Any] = dict(drp_request_body) + + # Serialize complex objects to repr format for storage + serialized_body: Dict[str, Any] = {} for key, value in drp_request_body_dict.items(): if value is not None: - # handle nested dict/objects + # Handle nested dict/objects if not isinstance(value, (bytes, str, int, float)): - cache.set_with_autoexpire( - get_drp_request_body_cache_key(self.id, key), - repr(value), - ) + serialized_body[key] = repr(value) else: - cache.set_with_autoexpire( - get_drp_request_body_cache_key(self.id, key), - value, - ) + serialized_body[key] = value + + store = get_dsr_cache_store(self.id) + store.cache_drp_request_body( + serialized_body, + expire_seconds=CONFIG.redis.default_ttl_seconds, + ) def cache_encryption(self, encryption_key: Optional[str] = None) -> None: """Sets the encryption key in the Fides app cache if provided""" if not encryption_key: return - cache: FidesopsRedis = get_cache() - cache.set_with_autoexpire( - get_encryption_cache_key(self.id, "key"), + store = get_dsr_cache_store(self.id) + store.write_encryption( + "key", encryption_key, + expire_seconds=CONFIG.redis.default_ttl_seconds, ) def persist_masking_secrets( @@ -745,53 +754,43 @@ def persist_masking_secrets( }, ) - def identity_prefix_cache_and_keys(self) -> Tuple[str, FidesopsRedis, List[str]]: - """Returns the prefix and cache keys for the identity data for this request""" - prefix = f"id-{self.id}-identity-*" - cache: FidesopsRedis = get_cache() - keys = cache.get_keys_by_prefix(f"id-{self.id}-identity-") - return prefix, cache, keys - def verify_cache_for_identity_data(self) -> bool: """Verifies if the identity data is cached for this request""" - _, _, keys = self.identity_prefix_cache_and_keys() - return len(keys) > 0 + store = get_dsr_cache_store(self.id) + return store.has_cached_identity_data() def get_cached_identity_data(self) -> Dict[str, Any]: """Retrieves any identity data pertaining to this request from the cache""" - result: Dict[str, Any] = {} - prefix, cache, keys = self.identity_prefix_cache_and_keys() + store = get_dsr_cache_store(self.id) + result = store.get_cached_identity_data() - if not keys: + if not result: logger.debug(f"Cache miss for request {self.id}, falling back to DB") identity = self.get_persisted_identity() self.cache_identity(identity) - keys = cache.get_keys_by_prefix(f"id-{self.id}-identity-") - - for key in keys: - value = cache.get(key) - if value: - try: - # try parsing the value as JSON - parsed_value = json.loads(value) - except json.JSONDecodeError: - # if parsing as JSON fails, assume it's a string. - # this is purely for backward compatibility: to ensure - # that identity data stored pre-2.34.0 in the "old" format - # can still be correctly retrieved from the cache. - parsed_value = value - result[key.split("-")[-1]] = parsed_value - return result + result = store.get_cached_identity_data() + + # Parse JSON values for backward compatibility + parsed_result: Dict[str, Any] = {} + for key, value in result.items(): + try: + # try parsing the value as JSON + parsed_result[key] = json.loads(value) + except json.JSONDecodeError: + # if parsing as JSON fails, assume it's a string. + # this is purely for backward compatibility: to ensure + # that identity data stored pre-2.34.0 in the "old" format + # can still be correctly retrieved from the cache. + parsed_result[key] = value + + return parsed_result def get_cached_custom_privacy_request_fields(self) -> Dict[str, Any]: """Retrieves any custom fields pertaining to this request from the cache""" - result: Dict[str, Any] = {} - prefix = f"id-{self.id}-custom-privacy-request-field-" + store = get_dsr_cache_store(self.id) + result = store.get_cached_custom_fields() - cache: FidesopsRedis = get_cache() - keys = cache.get_keys_by_prefix(prefix) - - if not keys: + if not result: logger.debug(f"Cache miss for request {self.id}, falling back to DB") custom_privacy_request_fields = ( self.get_persisted_custom_privacy_request_fields() @@ -802,13 +801,17 @@ def get_cached_custom_privacy_request_fields(self) -> Dict[str, Any]: for key, value in custom_privacy_request_fields.items() } ) - keys = cache.get_keys_by_prefix(prefix) + result = store.get_cached_custom_fields() + + # Parse JSON values + parsed_result: Dict[str, Any] = {} + for key, value in result.items(): + try: + parsed_result[key] = json.loads(value) + except json.JSONDecodeError: + parsed_result[key] = value - for key in keys: - value = cache.get(key) - if value: - result[key.split("-")[-1]] = json.loads(value) - return result + return parsed_result def cache_email_connector_template_contents( self, diff --git a/src/fides/api/models/privacy_request/request_task.py b/src/fides/api/models/privacy_request/request_task.py index 4cb4b95bbaa..10131993022 100644 --- a/src/fides/api/models/privacy_request/request_task.py +++ b/src/fides/api/models/privacy_request/request_task.py @@ -28,10 +28,8 @@ from fides.api.schemas.base_class import FidesSchema from fides.api.schemas.policy import ActionType from fides.api.util.cache import ( - FidesopsRedis, celery_tasks_in_flight, - get_async_task_tracking_cache_key, - get_cache, + get_dsr_cache_store, ) from fides.api.util.collection_util import Row from fides.config import CONFIG @@ -247,8 +245,10 @@ def allowed_action_types(cls) -> List[str]: def get_cached_task_id(self) -> Optional[str]: """Gets the cached celery task ID for this request task.""" - cache: FidesopsRedis = get_cache() - task_id = cache.get(get_async_task_tracking_cache_key(self.id)) + store = get_dsr_cache_store(self.id) + task_id = store.get_async_execution() + if isinstance(task_id, bytes): + return task_id.decode(CONFIG.security.encoding) return task_id def cleanup_external_storage(self) -> None: diff --git a/src/fides/api/service/privacy_request/request_service.py b/src/fides/api/service/privacy_request/request_service.py index b0a488fa708..6d4cc8aa571 100644 --- a/src/fides/api/service/privacy_request/request_service.py +++ b/src/fides/api/service/privacy_request/request_service.py @@ -33,8 +33,8 @@ from fides.api.util.cache import ( FidesopsRedis, celery_tasks_in_flight, - get_async_task_tracking_cache_key, get_cache, + get_dsr_cache_store, get_privacy_request_retry_count, increment_privacy_request_retry_count, reset_privacy_request_retry_count, @@ -336,9 +336,11 @@ def get_cached_task_id(entity_id: str) -> Optional[str]: Raises Exception if cache operations fail, allowing callers to handle cache failures appropriately. """ - cache: FidesopsRedis = get_cache() try: - task_id = cache.get(get_async_task_tracking_cache_key(entity_id)) + store = get_dsr_cache_store(entity_id) + task_id = store.get_async_execution() + if isinstance(task_id, bytes): + return task_id.decode(CONFIG.security.encoding) return task_id except Exception as exc: logger.error(f"Failed to get cached task ID for entity {entity_id}: {exc}") diff --git a/src/fides/api/tasks/encryption_utils.py b/src/fides/api/tasks/encryption_utils.py index 9b892436db2..060a055f875 100644 --- a/src/fides/api/tasks/encryption_utils.py +++ b/src/fides/api/tasks/encryption_utils.py @@ -2,7 +2,7 @@ from typing import Optional, Union from fides.api.cryptography.cryptographic_util import bytes_to_b64_str -from fides.api.util.cache import get_cache, get_encryption_cache_key +from fides.api.util.cache import get_dsr_cache_store from fides.api.util.encryption.aes_gcm_encryption_scheme import ( encrypt_to_bytes_verify_secrets_length, ) @@ -19,15 +19,17 @@ def encrypt_access_request_results(data: Union[str, bytes], request_id: str) -> Returns: str: The encrypted data as a string """ - cache = get_cache() - encryption_cache_key = get_encryption_cache_key( - privacy_request_id=request_id, - encryption_attr="key", - ) if isinstance(data, bytes): data = data.decode(CONFIG.security.encoding) - encryption_key: Optional[str] = cache.get(encryption_cache_key) + store = get_dsr_cache_store(request_id) + raw = store.get_encryption("key") + if raw is None: + return data + if isinstance(raw, bytes): + encryption_key = raw.decode(CONFIG.security.encoding) + else: + encryption_key = str(raw) if not encryption_key: return data diff --git a/src/fides/api/util/cache.py b/src/fides/api/util/cache.py index 66d51322c9b..067ca796f0c 100644 --- a/src/fides/api/util/cache.py +++ b/src/fides/api/util/cache.py @@ -28,6 +28,8 @@ celery_app, ) from fides.api.util.custom_json_encoder import CustomJSONEncoder, _custom_decoder +from fides.common.cache.dsr_store import DSRCacheStore +from fides.common.cache.manager import RedisCacheManager from fides.config import CONFIG # This constant represents every type a redis key may contain, and can be @@ -318,6 +320,20 @@ def get_cache() -> FidesopsRedis: return _connection +def get_redis_cache_manager() -> RedisCacheManager: + """Return a RedisCacheManager wrapping the default Redis connection.""" + return RedisCacheManager(get_cache()) + + +def get_dsr_cache_store(dsr_id: str) -> DSRCacheStore: + """Return a DSRCacheStore scoped to a single privacy request.""" + return DSRCacheStore( + dsr_id, + get_redis_cache_manager(), + default_ttl_seconds=CONFIG.redis.default_ttl_seconds, + ) + + def get_read_only_cache() -> FidesopsRedis: """ Return a singleton connection to the read-only Redis cache. @@ -406,6 +422,12 @@ def get_all_cache_keys_for_privacy_request(privacy_request_id: str) -> List[Any] def get_async_task_tracking_cache_key(privacy_request_id: str) -> str: + """Return the *legacy* Redis key for async-execution tracking. + + Prefer ``get_dsr_cache_store(dsr_id).get_async_execution()`` for reads and + ``cache_task_tracking_key()`` for writes — both route through the + DSRCacheStore which handles legacy fallback automatically. + """ return f"id-{privacy_request_id}-async-execution" @@ -422,12 +444,11 @@ def cache_task_tracking_key(request_id: str, celery_task_id: str) -> None: :return: None """ - cache: FidesopsRedis = get_cache() - try: - cache.set_with_autoexpire( - get_async_task_tracking_cache_key(request_id), + store = get_dsr_cache_store(request_id) + store.write_async_execution( celery_task_id, + expire_seconds=CONFIG.redis.default_ttl_seconds, ) except DataError: logger.debug( diff --git a/src/fides/common/cache/__init__.py b/src/fides/common/cache/__init__.py index 4757b8e8f66..66441b42367 100644 --- a/src/fides/common/cache/__init__.py +++ b/src/fides/common/cache/__init__.py @@ -8,13 +8,8 @@ """ -from contextlib import contextmanager -from typing import Iterator - -from fides.common.cache.dsr_store import ( - DSR_KEY_PREFIX, - DSRCacheStore, -) +from fides.common.cache.dsr_store import DSRCacheStore +from fides.common.cache.key_mapping import DSR_KEY_PREFIX from fides.common.cache.manager import ( INDEX_KEY_PREFIX, RedisCacheManager, diff --git a/src/fides/common/cache/dsr_store.py b/src/fides/common/cache/dsr_store.py index 921dcee4094..35257d0207d 100644 --- a/src/fides/common/cache/dsr_store.py +++ b/src/fides/common/cache/dsr_store.py @@ -14,7 +14,7 @@ later if we want to avoid index consistency concerns. """ -from typing import List, Optional, Union +from typing import Any, Callable, Dict, Optional, Union from redis import Redis @@ -46,32 +46,38 @@ class DSRCacheStore: def __init__( self, + dsr_id: str, cache_manager: RedisCacheManager, *, + default_ttl_seconds: int = 3600, backfill_index_on_legacy_read: bool = True, migrate_legacy_on_read: bool = True, ) -> None: """ Args: + dsr_id: The privacy request ID this store is scoped to. cache_manager: RedisCacheManager (e.g. from get_redis_cache_manager()). + default_ttl_seconds: Fallback TTL for migrated keys when the legacy + key has no expiration. Default 3600s (1 hour). backfill_index_on_legacy_read: When listing keys and we fall back to KEYS for legacy keys, add those keys to the index. Default True. migrate_legacy_on_read: When a get finds value in legacy key only, write to new key, delete legacy key, add new key to index. Default True. """ + self._dsr_id = dsr_id self._manager = cache_manager self._redis: Redis = cache_manager.redis + self._default_ttl = default_ttl_seconds self._backfill = backfill_index_on_legacy_read self._migrate_on_read = migrate_legacy_on_read def write( self, - dsr_id: str, field_type: str, field_key: str, value: RedisValue, - expire_seconds: Optional[int] = None, + expire_seconds: int, ) -> Optional[bool]: """ Low-level write: set dsr:{dsr_id}:{field_type}:{field_key} and add to index. @@ -79,220 +85,362 @@ def write( stays in one place. """ part = f"{field_type}:{field_key}" if field_key else field_type - return self.set(dsr_id, part, value, expire_seconds) + return self.set(part, value, expire_seconds) def get_with_legacy( self, - dsr_id: str, part: str, legacy_key: str, ) -> Optional[Union[str, bytes]]: """ Get value for part; if missing, try legacy_key. If found in legacy only and migrate_legacy_on_read, copy to new key, delete legacy, add to index. + Propagates the legacy key's remaining TTL to the new key. """ - val = self._redis.get(_dsr_key(dsr_id, part)) + new_key = _dsr_key(self._dsr_id, part) + val = self._redis.get(new_key) if val is not None: return val val = self._redis.get(legacy_key) if val is None: - return None + # Re-check: another reader may have migrated between our two GETs + return self._redis.get(new_key) if self._migrate_on_read: - self.set(dsr_id, part, val) + ttl = self._redis.ttl(legacy_key) + expire = ttl if ttl > 0 else self._default_ttl + self.set(part, val, expire) self._redis.delete(legacy_key) return val - def get(self, dsr_id: str, part: str) -> Optional[Union[str, bytes]]: + def get(self, part: str) -> Optional[Union[str, bytes]]: """Get a value for the given DSR and part. Returns None if missing.""" - return self._redis.get(_dsr_key(dsr_id, part)) + return self._redis.get(_dsr_key(self._dsr_id, part)) def set( self, - dsr_id: str, part: str, value: RedisValue, - expire_seconds: Optional[int] = None, + expire_seconds: int, ) -> Optional[bool]: """ Set a value for the given DSR and part. Registers the key in the DSR index. """ - key = _dsr_key(dsr_id, part) + key = _dsr_key(self._dsr_id, part) return self._manager.set_with_index( - key, value, _dsr_index_prefix(dsr_id), expire_seconds + key, value, _dsr_index_prefix(self._dsr_id), expire_seconds ) - def delete(self, dsr_id: str, part: str) -> None: + def delete(self, part: str) -> None: """Delete a single part and remove it from the DSR index.""" - key = _dsr_key(dsr_id, part) - self._manager.delete_key_and_remove_from_index(key, _dsr_index_prefix(dsr_id)) + key = _dsr_key(self._dsr_id, part) + self._manager.delete_key_and_remove_from_index( + key, _dsr_index_prefix(self._dsr_id) + ) + + # --- Shared get/has helpers --- + + def _get_cached_by_type( + self, + new_infix: str, + legacy_infix: str, + getter: Callable[[str], Optional[Union[str, bytes]]], + ) -> Dict[str, Any]: + """Shared implementation for get_cached_custom_fields/identity_data/drp_request_body.""" + result: Dict[str, Any] = {} + for key in self.get_all_keys(): + if new_infix in key: + field = key.split(":")[-1] + elif legacy_infix in key: + field = key.split(legacy_infix, 1)[-1] + else: + continue + value = getter(field) + if value: # Intentionally drops empty/falsy — matches legacy behavior + result[field] = value + return result + + def _has_cached_by_type(self, new_infix: str, legacy_infix: str) -> bool: + """Shared implementation for has_cached_* methods.""" + return any(new_infix in k or legacy_infix in k for k in self.get_all_keys()) # --- Convenience: custom privacy request fields --- def write_custom_field( self, - dsr_id: str, field_key: str, value: RedisValue, - expire_seconds: Optional[int] = None, + expire_seconds: int, ) -> Optional[bool]: """Write a custom privacy request field. New key: dsr:{id}:custom_field:{field_key}.""" - return self.write(dsr_id, "custom_field", field_key, value, expire_seconds) + return self.write("custom_field", field_key, value, expire_seconds) - def get_custom_field( - self, dsr_id: str, field_key: str - ) -> Optional[Union[str, bytes]]: + def get_custom_field(self, field_key: str) -> Optional[Union[str, bytes]]: """Get custom field; reads from legacy id-{id}-custom-privacy-request-field-{key} if needed.""" part = f"custom_field:{field_key}" return self.get_with_legacy( - dsr_id, part, KeyMapper.custom_field(dsr_id, field_key)[1] + part, KeyMapper.custom_field(self._dsr_id, field_key)[1] + ) + + def cache_custom_fields( + self, + custom_fields: Dict[str, Any], + expire_seconds: int, + ) -> None: + """ + Cache all custom privacy request fields for a DSR. + + Writes each non-None field to dsr:{id}:custom_field:{field_key} format. + """ + for key, value in custom_fields.items(): + if value is not None: + self.write_custom_field(key, value, expire_seconds) + + def get_cached_custom_fields(self) -> Dict[str, Any]: + """ + Retrieve all cached custom fields for a DSR. + + Returns dict with custom field values. Automatically migrates legacy keys on read. + Returns empty dict if no custom fields cached. + """ + return self._get_cached_by_type( + ":custom_field:", + "-custom-privacy-request-field-", + self.get_custom_field, + ) + + def has_cached_custom_fields(self) -> bool: + """ + Check if any custom fields are cached for this DSR. + + Returns True if any custom field keys exist (legacy or new format). + """ + return self._has_cached_by_type( + ":custom_field:", "-custom-privacy-request-field-" ) # --- Convenience: identity --- def write_identity( self, - dsr_id: str, attr: str, value: RedisValue, - expire_seconds: Optional[int] = None, + expire_seconds: int, ) -> Optional[bool]: """Write an identity attribute. New key: dsr:{id}:identity:{attr}.""" - return self.write(dsr_id, "identity", attr, value, expire_seconds) + return self.write("identity", attr, value, expire_seconds) - def get_identity(self, dsr_id: str, attr: str) -> Optional[Union[str, bytes]]: + def get_identity(self, attr: str) -> Optional[Union[str, bytes]]: """Get identity attribute; reads from legacy id-{id}-identity-{attr} if needed.""" part = f"identity:{attr}" - return self.get_with_legacy(dsr_id, part, KeyMapper.identity(dsr_id, attr)[1]) + return self.get_with_legacy(part, KeyMapper.identity(self._dsr_id, attr)[1]) + + def cache_identity_data( + self, + identity_dict: Dict[str, Any], + expire_seconds: int, + ) -> None: + """ + Cache all identity attributes for a DSR. + + Writes each non-None attribute to dsr:{id}:identity:{attr} format. + """ + for key, value in identity_dict.items(): + if value is not None: + self.write_identity(key, value, expire_seconds) + + def get_cached_identity_data(self) -> Dict[str, Any]: + """ + Retrieve all cached identity data for a DSR. + + Returns dict with identity attributes. Automatically migrates legacy keys on read. + Returns empty dict if no identity data cached. + """ + return self._get_cached_by_type(":identity:", "-identity-", self.get_identity) + + def has_cached_identity_data(self) -> bool: + """ + Check if any identity data is cached for this DSR. + + Returns True if any identity keys exist (legacy or new format). + """ + return self._has_cached_by_type(":identity:", "-identity-") # --- Convenience: encryption --- def write_encryption( self, - dsr_id: str, attr: str, value: RedisValue, - expire_seconds: Optional[int] = None, + expire_seconds: int, ) -> Optional[bool]: """Write an encryption attribute. New key: dsr:{id}:encryption:{attr}.""" - return self.write(dsr_id, "encryption", attr, value, expire_seconds) + return self.write("encryption", attr, value, expire_seconds) - def get_encryption(self, dsr_id: str, attr: str) -> Optional[Union[str, bytes]]: + def get_encryption(self, attr: str) -> Optional[Union[str, bytes]]: """Get encryption attribute; reads from legacy id-{id}-encryption-{attr} if needed.""" part = f"encryption:{attr}" - return self.get_with_legacy(dsr_id, part, KeyMapper.encryption(dsr_id, attr)[1]) + return self.get_with_legacy(part, KeyMapper.encryption(self._dsr_id, attr)[1]) # --- Convenience: DRP request body --- def write_drp( self, - dsr_id: str, attr: str, value: RedisValue, - expire_seconds: Optional[int] = None, + expire_seconds: int, ) -> Optional[bool]: """Write DRP request body attribute. New key: dsr:{id}:drp:{attr}.""" - return self.write(dsr_id, "drp", attr, value, expire_seconds) + return self.write("drp", attr, value, expire_seconds) - def get_drp(self, dsr_id: str, attr: str) -> Optional[Union[str, bytes]]: + def get_drp(self, attr: str) -> Optional[Union[str, bytes]]: """Get DRP attribute; reads from legacy id-{id}-drp-{attr} if needed.""" part = f"drp:{attr}" - return self.get_with_legacy(dsr_id, part, KeyMapper.drp(dsr_id, attr)[1]) + return self.get_with_legacy(part, KeyMapper.drp(self._dsr_id, attr)[1]) + + def cache_drp_request_body( + self, + drp_body: Dict[str, Any], + expire_seconds: int, + ) -> None: + """ + Cache all DRP request body fields for a DSR. + Writes each non-None field to dsr:{id}:drp:{field_key} format. + """ + for key, value in drp_body.items(): + if value is not None: + self.write_drp(key, value, expire_seconds) + + def get_cached_drp_request_body(self) -> Dict[str, Any]: + """ + Retrieve all cached DRP request body data for a DSR. + Returns dict with DRP fields. Automatically migrates legacy keys on read. + Returns empty dict if no DRP data cached. + """ + return self._get_cached_by_type(":drp:", "-drp-", self.get_drp) + + def has_cached_drp_request_body(self) -> bool: + """ + Check if any DRP request body data is cached for this DSR. + Checks both new and legacy key formats. + """ + return self._has_cached_by_type(":drp:", "-drp-") # --- Convenience: masking secret --- def write_masking_secret( self, - dsr_id: str, strategy: str, secret_type: str, value: RedisValue, - expire_seconds: Optional[int] = None, + expire_seconds: int, ) -> Optional[bool]: """Write masking secret. New key: dsr:{id}:masking_secret:{strategy}:{secret_type}.""" part = f"masking_secret:{strategy}:{secret_type}" - return self.set(dsr_id, part, value, expire_seconds) + return self.set(part, value, expire_seconds) def get_masking_secret( self, - dsr_id: str, strategy: str, secret_type: str, ) -> Optional[Union[str, bytes]]: """Get masking secret; reads from legacy id-{id}-masking-secret-{strategy}-{type} if needed.""" part = f"masking_secret:{strategy}:{secret_type}" return self.get_with_legacy( - dsr_id, part, - KeyMapper.masking_secret(dsr_id, strategy, secret_type)[1], + KeyMapper.masking_secret(self._dsr_id, strategy, secret_type)[1], ) # --- Convenience: async execution (single value per DSR) --- def write_async_execution( self, - dsr_id: str, value: RedisValue, - expire_seconds: Optional[int] = None, + expire_seconds: int, ) -> Optional[bool]: """Write async task id. New key: dsr:{id}:async_execution.""" - return self.write(dsr_id, "async_execution", "", value, expire_seconds) + return self.write("async_execution", "", value, expire_seconds) - def get_async_execution(self, dsr_id: str) -> Optional[Union[str, bytes]]: + def get_async_execution(self) -> Optional[Union[str, bytes]]: """Get async task id; reads from legacy id-{id}-async-execution if needed.""" part = "async_execution" - return self.get_with_legacy(dsr_id, part, KeyMapper.async_execution(dsr_id)[1]) + return self.get_with_legacy(part, KeyMapper.async_execution(self._dsr_id)[1]) # --- Convenience: retry count --- def write_retry_count( self, - dsr_id: str, value: RedisValue, - expire_seconds: Optional[int] = None, + expire_seconds: int, ) -> Optional[bool]: """Write privacy request retry count. New key: dsr:{id}:retry_count.""" - return self.write(dsr_id, "retry_count", "", value, expire_seconds) + return self.write("retry_count", "", value, expire_seconds) - def get_retry_count(self, dsr_id: str) -> Optional[Union[str, bytes]]: + def get_retry_count(self) -> Optional[Union[str, bytes]]: """Get retry count; reads from legacy id-{id}-privacy-request-retry-count if needed.""" part = "retry_count" - return self.get_with_legacy(dsr_id, part, KeyMapper.retry_count(dsr_id)[1]) + return self.get_with_legacy(part, KeyMapper.retry_count(self._dsr_id)[1]) - # --- List / clear (unchanged) --- + # --- List / clear --- - def get_all_keys(self, dsr_id: str) -> List[str]: + def get_all_keys(self) -> list[str]: """ Return all cache keys for this DSR. - Uses the index first; if empty, falls back to SCAN for legacy keys - and optionally backfills the index. + + Uses the index first. If a migration flag confirms no legacy keys remain, + returns index contents directly. Otherwise, does a one-time SCAN to find + legacy stragglers, backfills them into the index, and sets the migration + flag so future calls skip the SCAN. """ - index_prefix = _dsr_index_prefix(dsr_id) + index_prefix = _dsr_index_prefix(self._dsr_id) keys = self._manager.get_keys_by_index(index_prefix) - if keys: + + # If we've already confirmed no legacy keys remain, index is authoritative + migration_key = f"__migrated:{self._dsr_id}" + if keys and self._redis.exists(migration_key): return keys - legacy_keys = list(self._redis.scan_iter(match=f"*{dsr_id}*", count=500)) - if not legacy_keys: + + # SCAN for all keys (one-time per DSR until migration confirmed) + # Filter out internal keys (__migrated:, __idx:) that match the SCAN pattern + scanned_keys = [ + k + for k in self._redis.scan_iter(match=f"*{self._dsr_id}*", count=500) + if not k.startswith("__migrated:") and not k.startswith("__idx:") + ] + indexed = set(keys) + scanned_set = set(scanned_keys) + all_keys = list(indexed | scanned_set) if keys else scanned_keys + + if not all_keys: return [] + if self._backfill: - for k in legacy_keys: - self._manager.add_key_to_index(index_prefix, k) - return list(legacy_keys) + for k in scanned_keys: + if k not in indexed: + self._manager.add_key_to_index(index_prefix, k) + + # If index existed and no scanned keys found outside it, mark as migrated + if keys and not (scanned_set - indexed): + self._redis.setex(migration_key, 86400, "1") # 24h TTL - def clear(self, dsr_id: str) -> None: + return all_keys + + def clear(self) -> None: """ Delete all cache keys for this DSR and remove the index. Always uses SCAN to find all keys (both indexed and legacy) to ensure - complete cleanup in mixed-key scenarios. + complete cleanup in mixed-key scenarios. Does a second SCAN pass to + catch keys written by concurrent migrations between the first SCAN + and DELETE. """ - # Use SCAN to find ALL keys (indexed + legacy) - all_keys_via_scan = list(self._redis.scan_iter(match=f"*{dsr_id}*", count=500)) - - index_prefix = _dsr_index_prefix(dsr_id) - - # Delete all found keys in batch - if all_keys_via_scan: - self._redis.delete(*all_keys_via_scan) - - # Delete the index itself + all_keys = list(self._redis.scan_iter(match=f"*{self._dsr_id}*", count=500)) + index_prefix = _dsr_index_prefix(self._dsr_id) + if all_keys: + self._redis.delete(*all_keys) self._manager.delete_index(index_prefix) + # Invalidate migration flag so future reads re-scan + self._redis.delete(f"__migrated:{self._dsr_id}") + # Second pass: catch keys written by concurrent migrations + stragglers = list(self._redis.scan_iter(match=f"*{self._dsr_id}*", count=500)) + if stragglers: + self._redis.delete(*stragglers) diff --git a/src/fides/common/cache/manager.py b/src/fides/common/cache/manager.py index 8895c567c1a..8b560ea3486 100644 --- a/src/fides/common/cache/manager.py +++ b/src/fides/common/cache/manager.py @@ -6,9 +6,7 @@ __idx:{index_prefix}; members are the actual cache key names. """ -from typing import List, Optional, Union - -from redis import Redis +from typing import Any, List, Optional, Union # Redis key prefix for index sets. Index key = INDEX_KEY_PREFIX + index_prefix. INDEX_KEY_PREFIX = "__idx:" @@ -31,10 +29,10 @@ class RedisCacheManager: set/delete helpers). """ - def __init__(self, redis_client: Redis) -> None: + def __init__(self, redis_client: Any) -> None: """ Args: - redis_client: Any Redis client (e.g. FidesopsRedis from get_cache()). + redis_client: redis.Redis, RedisCluster, or FidesopsRedis (delegates to underlying client). """ self._redis = redis_client @@ -131,6 +129,6 @@ def delete_key_and_remove_from_index( pipe.execute() @property - def redis(self) -> Redis: + def redis(self) -> Any: """Access the underlying Redis client for operations not on the manager.""" return self._redis diff --git a/tests/common/cache/conftest.py b/tests/common/cache/conftest.py index 3c9a57a610d..1ca2f11792e 100644 --- a/tests/common/cache/conftest.py +++ b/tests/common/cache/conftest.py @@ -3,10 +3,47 @@ real FastAPI app, DB, and Celery worker are not started when running only these tests. """ +import uuid from unittest.mock import MagicMock import pytest +from fides.common.cache.dsr_store import DSRCacheStore +from fides.common.cache.manager import RedisCacheManager +from tests.common.cache.mock_redis import create_mock_redis + +# --- Shared cache test fixtures --- + + +@pytest.fixture +def mock_redis(): + """In-memory autospec'd Redis mock.""" + return create_mock_redis() + + +@pytest.fixture +def manager(mock_redis) -> RedisCacheManager: + """RedisCacheManager backed by mock Redis.""" + return RedisCacheManager(mock_redis) + + +@pytest.fixture +def dsr_store(manager: RedisCacheManager) -> DSRCacheStore: + """DSRCacheStore backed by mock Redis, scoped to default 'pr-1' ID.""" + return DSRCacheStore("pr-1", manager) + + +@pytest.fixture +def pr_id(): + """Generate unique privacy request ID.""" + return f"test-pr-{uuid.uuid4()}" + + +@pytest.fixture +def dsr_id(): + """Alias for pr_id used by migration tests.""" + return f"test-pr-{uuid.uuid4()}" + @pytest.fixture(scope="session") def test_client(): diff --git a/tests/common/cache/mock_redis.py b/tests/common/cache/mock_redis.py new file mode 100644 index 00000000000..17771fb6c75 --- /dev/null +++ b/tests/common/cache/mock_redis.py @@ -0,0 +1,189 @@ +""" +Autospec'd Redis mock with in-memory backing store for cache tests. + +Uses ``create_autospec(redis.Redis)`` so that: +- Method signatures are validated against the real Redis client +- Missing methods surface as clear errors, not silent misbehavior +- New Redis methods used in production code are auto-available +""" + +import fnmatch +from typing import Any +from unittest.mock import MagicMock, create_autospec + +import redis as redis_lib + +__all__ = ["create_mock_redis"] + + +def create_mock_redis() -> MagicMock: + """ + Create an autospec'd ``redis.Redis`` mock with in-memory state. + + The mock validates method signatures against real ``redis.Redis``, + while providing stateful in-memory behavior via side_effects. + + Internal state is accessible for test assertions:: + + mock._data -- dict of string keys to values + mock._sets -- dict of set keys to set[str] + mock._ttls -- dict of keys to TTL seconds + """ + mock = create_autospec(redis_lib.Redis, instance=True) + + _data: dict[str, Any] = {} + _sets: dict[str, set[str]] = {} + _ttls: dict[str, int] = {} + + # Expose state for test assertions + mock._data = _data + mock._sets = _sets + mock._ttls = _ttls + + # --- Core Redis methods --- + + def _get(name): + return _data.get(name) + + def _set(name, value, ex=None, **kwargs): + _data[name] = value + if ex is not None: + _ttls[name] = ex + return True + + def _setex(name, time, value): + _data[name] = value + _ttls[name] = time + return True + + def _delete(*names): + count = 0 + for n in names: + if n in _data: + del _data[n] + count += 1 + if n in _sets: + del _sets[n] + count += 1 + _ttls.pop(n, None) + return count + + def _exists(*names): + return sum(1 for n in names if n in _data or n in _sets) + + def _sadd(name, *values): + _sets.setdefault(name, set()).update(values) + return len(values) + + def _srem(name, *values): + if name in _sets: + for v in values: + _sets[name].discard(v) + if not _sets[name]: + del _sets[name] + return len(values) + + def _smembers(name): + return _sets.get(name, set()).copy() + + def _keys(pattern="*"): + all_keys = set(_data) | set(_sets) + return [k for k in all_keys if fnmatch.fnmatch(k, pattern)] + + def _scan_iter(match="*", count=None, **kwargs): + return iter(_keys(match)) + + def _ttl_fn(name): + if name not in _data and name not in _sets: + return -2 + return _ttls.get(name, -1) + + def _expire(name, time): + if name in _data or name in _sets: + _ttls[name] = time + return True + return False + + def _ping(**kwargs): + return True + + mock.get.side_effect = _get + mock.set.side_effect = _set + mock.setex.side_effect = _setex + mock.delete.side_effect = _delete + mock.exists.side_effect = _exists + mock.sadd.side_effect = _sadd + mock.srem.side_effect = _srem + mock.smembers.side_effect = _smembers + mock.keys.side_effect = _keys + mock.scan_iter.side_effect = _scan_iter + mock.ttl.side_effect = _ttl_fn + mock.expire.side_effect = _expire + mock.ping.side_effect = _ping + + # --- Pipeline --- + + def _make_pipeline(**kwargs): + pipe = MagicMock() + commands: list = [] + + def pipe_set(name, value, ex=None, **kw): + commands.append(("set", name, value, ex)) + return pipe + + def pipe_sadd(name, *values): + commands.append(("sadd", name, values)) + return pipe + + def pipe_delete(*names): + commands.append(("delete", names)) + return pipe + + def pipe_srem(name, *values): + commands.append(("srem", name, values)) + return pipe + + def pipe_execute(**kw): + results = [] + for cmd in commands: + if cmd[0] == "set": + _data[cmd[1]] = cmd[2] + if cmd[3] is not None: + _ttls[cmd[1]] = cmd[3] + results.append(True) + elif cmd[0] == "sadd": + _sets.setdefault(cmd[1], set()).update(cmd[2]) + results.append(len(cmd[2])) + elif cmd[0] == "delete": + for k in cmd[1]: + _data.pop(k, None) + _sets.pop(k, None) + _ttls.pop(k, None) + results.append(len(cmd[1])) + elif cmd[0] == "srem": + for v in cmd[2]: + if cmd[1] in _sets: + _sets[cmd[1]].discard(v) + results.append(len(cmd[2])) + commands.clear() + return results + + pipe.set.side_effect = pipe_set + pipe.sadd.side_effect = pipe_sadd + pipe.delete.side_effect = pipe_delete + pipe.srem.side_effect = pipe_srem + pipe.execute.side_effect = pipe_execute + return pipe + + mock.pipeline.side_effect = _make_pipeline + + # --- FidesopsRedis-specific methods (used in production compatibility tests) --- + + mock.set_with_autoexpire = MagicMock( + side_effect=lambda key, value, ex=None: _set(key, value, ex=ex) + ) + mock.get_keys_by_prefix = MagicMock( + side_effect=lambda prefix: [k for k in _keys() if k.startswith(prefix)] + ) + + return mock diff --git a/tests/common/cache/test_dsr_store.py b/tests/common/cache/test_dsr_store.py index 228560cb9d2..aa9890ad022 100644 --- a/tests/common/cache/test_dsr_store.py +++ b/tests/common/cache/test_dsr_store.py @@ -1,113 +1,14 @@ """ -Tests for DSRCacheStore using an in-memory RedisCacheManager (dict + set). +Tests for DSRCacheStore using an autospec'd Redis mock. No real Redis required. """ -import fnmatch -from typing import Any, Dict, List, Optional, Set, Union - import pytest from fides.common.cache.dsr_store import DSRCacheStore +from fides.common.cache.manager import RedisCacheManager -RedisValue = Union[bytes, float, int, str] - - -class InMemoryRedis: - """Minimal Redis-like interface: get, set, delete, keys (glob pattern).""" - - def __init__(self) -> None: - self._data: Dict[str, RedisValue] = {} - - def get(self, key: str) -> Optional[Union[str, bytes]]: - val = self._data.get(key) - if val is None: - return None - return val if isinstance(val, (str, bytes)) else str(val) - - def set( - self, - key: str, - value: RedisValue, - ex: Optional[int] = None, - ) -> Optional[bool]: - self._data[key] = value - return True - - def delete(self, *keys: str) -> int: - """Remove keys; returns count removed (redis-py compatible).""" - removed = 0 - for key in keys: - if key in self._data: - del self._data[key] - removed += 1 - return removed - - def keys(self, pattern: str) -> List[str]: - """Glob-style: * matches any number of chars.""" - return [k for k in self._data if fnmatch.fnmatch(k, pattern)] - - def scan_iter(self, match: str = "*", count: Optional[int] = None): - """SCAN-compatible iterator; yields keys matching pattern (count ignored in-memory).""" - return iter(self.keys(match)) - - -class InMemoryRedisCacheManager: - """ - In-memory implementation of the RedisCacheManager interface: a dict for - key -> value and a dict of index_prefix -> set of keys for set_with_index. - """ - - def __init__(self) -> None: - self._redis = InMemoryRedis() - self._index: Dict[str, Set[str]] = {} - - def add_key_to_index(self, index_prefix: str, key: str) -> None: - self._index.setdefault(index_prefix, set()).add(key) - - def remove_key_from_index(self, index_prefix: str, key: str) -> None: - s = self._index.get(index_prefix) - if s is not None: - s.discard(key) - - def get_keys_by_index(self, index_prefix: str) -> List[str]: - return list(self._index.get(index_prefix, set())) - - def delete_index(self, index_prefix: str) -> None: - self._index.pop(index_prefix, None) - - def set_with_index( - self, - key: str, - value: RedisValue, - index_prefix: str, - expire_seconds: Optional[int] = None, - ) -> Optional[bool]: - result = self._redis.set(key, value, ex=expire_seconds) - self.add_key_to_index(index_prefix, key) - return result - - def delete_key_and_remove_from_index( - self, - key: str, - index_prefix: str, - ) -> None: - self._redis.delete(key) - self.remove_key_from_index(index_prefix, key) - - @property - def redis(self) -> InMemoryRedis: - return self._redis - - -@pytest.fixture -def in_memory_manager() -> InMemoryRedisCacheManager: - return InMemoryRedisCacheManager() - - -@pytest.fixture -def dsr_store(in_memory_manager: InMemoryRedisCacheManager) -> DSRCacheStore: - return DSRCacheStore(in_memory_manager) +_TTL = 3600 # Test TTL @pytest.mark.unit @@ -115,131 +16,117 @@ class TestDSRCacheStoreWithInMemoryManager: """DSRCacheStore behavior with an in-memory RedisCacheManager.""" def test_set_and_get(self, dsr_store: DSRCacheStore) -> None: - dsr_store.set("pr-1", "identity:email", "user@example.com") - assert dsr_store.get("pr-1", "identity:email") == "user@example.com" + dsr_store.set("identity:email", "user@example.com", _TTL) + assert dsr_store.get("identity:email") == "user@example.com" def test_get_missing_returns_none(self, dsr_store: DSRCacheStore) -> None: - assert dsr_store.get("pr-1", "identity:email") is None + assert dsr_store.get("identity:email") is None def test_set_with_index_registers_key_in_index( - self, dsr_store: DSRCacheStore, in_memory_manager: InMemoryRedisCacheManager + self, dsr_store: DSRCacheStore, mock_redis ) -> None: - dsr_store.set("pr-1", "custom_field:foo", "bar") - keys = in_memory_manager.get_keys_by_index("dsr:pr-1") + dsr_store.set("custom_field:foo", "bar", _TTL) + keys = mock_redis.smembers("__idx:dsr:pr-1") assert "dsr:pr-1:custom_field:foo" in keys assert len(keys) == 1 - def test_get_all_keys_returns_indexed_keys( - self, dsr_store: DSRCacheStore, in_memory_manager: InMemoryRedisCacheManager - ) -> None: - dsr_store.write_custom_field("pr-1", "f1", "v1") - dsr_store.write_identity("pr-1", "email", "e@x.com") - keys = dsr_store.get_all_keys("pr-1") + def test_get_all_keys_returns_indexed_keys(self, dsr_store: DSRCacheStore) -> None: + dsr_store.write_custom_field("f1", "v1", _TTL) + dsr_store.write_identity("email", "e@x.com", _TTL) + keys = dsr_store.get_all_keys() assert set(keys) == { "dsr:pr-1:custom_field:f1", "dsr:pr-1:identity:email", } - def test_clear_removes_all_keys_and_index( - self, dsr_store: DSRCacheStore, in_memory_manager: InMemoryRedisCacheManager - ) -> None: - dsr_store.write_custom_field("pr-1", "f1", "v1") - dsr_store.write_identity("pr-1", "email", "e@x.com") - dsr_store.clear("pr-1") - assert dsr_store.get_all_keys("pr-1") == [] - assert dsr_store.get("pr-1", "custom_field:f1") is None - assert dsr_store.get("pr-1", "identity:email") is None - - def test_delete_removes_key_and_index_entry( - self, dsr_store: DSRCacheStore, in_memory_manager: InMemoryRedisCacheManager - ) -> None: - dsr_store.set("pr-1", "identity:email", "e@x.com") - dsr_store.delete("pr-1", "identity:email") - assert dsr_store.get("pr-1", "identity:email") is None - assert "dsr:pr-1:identity:email" not in dsr_store.get_all_keys("pr-1") + def test_clear_removes_all_keys_and_index(self, dsr_store: DSRCacheStore) -> None: + dsr_store.write_custom_field("f1", "v1", _TTL) + dsr_store.write_identity("email", "e@x.com", _TTL) + dsr_store.clear() + assert dsr_store.get_all_keys() == [] + assert dsr_store.get("custom_field:f1") is None + assert dsr_store.get("identity:email") is None + + def test_delete_removes_key_and_index_entry(self, dsr_store: DSRCacheStore) -> None: + dsr_store.set("identity:email", "e@x.com", _TTL) + dsr_store.delete("identity:email") + assert dsr_store.get("identity:email") is None + assert "dsr:pr-1:identity:email" not in dsr_store.get_all_keys() def test_get_with_legacy_reads_new_key_first( - self, dsr_store: DSRCacheStore, in_memory_manager: InMemoryRedisCacheManager + self, dsr_store: DSRCacheStore ) -> None: - dsr_store.write_identity("pr-1", "email", "new@example.com") + dsr_store.write_identity("email", "new@example.com", _TTL) # Legacy key not set; should still get from new key - assert dsr_store.get_identity("pr-1", "email") == "new@example.com" + assert dsr_store.get_identity("email") == "new@example.com" def test_get_with_legacy_migrates_from_legacy_key( - self, dsr_store: DSRCacheStore, in_memory_manager: InMemoryRedisCacheManager + self, dsr_store: DSRCacheStore, mock_redis ) -> None: # Simulate legacy data only (no new key) - in_memory_manager.redis.set("id-pr-1-identity-email", "legacy@example.com") - result = dsr_store.get_identity("pr-1", "email") + mock_redis.set("id-pr-1-identity-email", "legacy@example.com") + result = dsr_store.get_identity("email") assert result == "legacy@example.com" # After migrate: new key should exist and legacy should be gone - assert dsr_store.get("pr-1", "identity:email") == "legacy@example.com" - assert in_memory_manager.redis.get("id-pr-1-identity-email") is None + assert dsr_store.get("identity:email") == "legacy@example.com" + assert mock_redis.get("id-pr-1-identity-email") is None def test_write_custom_field_and_get_custom_field( self, dsr_store: DSRCacheStore ) -> None: - dsr_store.write_custom_field("pr-1", "my_field", "my_value") - assert dsr_store.get_custom_field("pr-1", "my_field") == "my_value" + dsr_store.write_custom_field("my_field", "my_value", _TTL) + assert dsr_store.get_custom_field("my_field") == "my_value" def test_convenience_async_execution(self, dsr_store: DSRCacheStore) -> None: - dsr_store.write_async_execution("pr-1", "celery-task-id-xyz") - assert dsr_store.get_async_execution("pr-1") == "celery-task-id-xyz" + dsr_store.write_async_execution("celery-task-id-xyz", _TTL) + assert dsr_store.get_async_execution() == "celery-task-id-xyz" - def test_retry_count( - self, dsr_store: DSRCacheStore, in_memory_manager: InMemoryRedisCacheManager - ) -> None: + def test_retry_count(self, dsr_store: DSRCacheStore, mock_redis, manager) -> None: """Mirrors cache.py get/increment/reset_privacy_request_retry_count.""" - assert dsr_store.get_retry_count("pr-1") is None - dsr_store.write_retry_count("pr-1", "3", expire_seconds=86400) - assert dsr_store.get_retry_count("pr-1") == "3" - dsr_store.delete("pr-1", "retry_count") - assert dsr_store.get_retry_count("pr-1") is None - # Legacy key migration - in_memory_manager.redis.set("id-pr-2-privacy-request-retry-count", "1") - assert dsr_store.get_retry_count("pr-2") == "1" - assert ( - in_memory_manager.redis.get("id-pr-2-privacy-request-retry-count") is None - ) - - def test_drp( - self, dsr_store: DSRCacheStore, in_memory_manager: InMemoryRedisCacheManager - ) -> None: + assert dsr_store.get_retry_count() is None + dsr_store.write_retry_count("3", expire_seconds=86400) + assert dsr_store.get_retry_count() == "3" + dsr_store.delete("retry_count") + assert dsr_store.get_retry_count() is None + # Legacy key migration (different DSR) + mock_redis.set("id-pr-2-privacy-request-retry-count", "1") + store2 = DSRCacheStore("pr-2", manager) + assert store2.get_retry_count() == "1" + assert mock_redis.get("id-pr-2-privacy-request-retry-count") is None + + def test_drp(self, dsr_store: DSRCacheStore, mock_redis, manager) -> None: """Mirrors privacy_request.py DRP body cache (get_drp_request_body_cache_key).""" - dsr_store.write_drp("pr-1", "address", "encrypted-body", expire_seconds=300) - assert dsr_store.get_drp("pr-1", "address") == "encrypted-body" - assert dsr_store.get_drp("pr-1", "email") is None - # Legacy key migration - in_memory_manager.redis.set("id-pr-2-drp-email", "legacy-drp") - assert dsr_store.get_drp("pr-2", "email") == "legacy-drp" - assert in_memory_manager.redis.get("id-pr-2-drp-email") is None - - def test_encryption( - self, dsr_store: DSRCacheStore, in_memory_manager: InMemoryRedisCacheManager - ) -> None: + dsr_store.write_drp("address", "encrypted-body", expire_seconds=300) + assert dsr_store.get_drp("address") == "encrypted-body" + assert dsr_store.get_drp("email") is None + # Legacy key migration (different DSR) + mock_redis.set("id-pr-2-drp-email", "legacy-drp") + store2 = DSRCacheStore("pr-2", manager) + assert store2.get_drp("email") == "legacy-drp" + assert mock_redis.get("id-pr-2-drp-email") is None + + def test_encryption(self, dsr_store: DSRCacheStore, mock_redis, manager) -> None: """Mirrors privacy_request.py / encryption_utils.py encryption key cache.""" - dsr_store.write_encryption("pr-1", "key", "enc-key-123", expire_seconds=3600) - assert dsr_store.get_encryption("pr-1", "key") == "enc-key-123" - assert dsr_store.get_encryption("pr-1", "other") is None - # Legacy key migration - in_memory_manager.redis.set("id-pr-2-encryption-key", "legacy-enc") - assert dsr_store.get_encryption("pr-2", "key") == "legacy-enc" - assert in_memory_manager.redis.get("id-pr-2-encryption-key") is None + dsr_store.write_encryption("key", "enc-key-123", expire_seconds=3600) + assert dsr_store.get_encryption("key") == "enc-key-123" + assert dsr_store.get_encryption("other") is None + # Legacy key migration (different DSR) + mock_redis.set("id-pr-2-encryption-key", "legacy-enc") + store2 = DSRCacheStore("pr-2", manager) + assert store2.get_encryption("key") == "legacy-enc" + assert mock_redis.get("id-pr-2-encryption-key") is None def test_masking_secret( - self, dsr_store: DSRCacheStore, in_memory_manager: InMemoryRedisCacheManager + self, dsr_store: DSRCacheStore, mock_redis, manager ) -> None: """Mirrors secrets_util.get_masking_secret cache read (and write path).""" dsr_store.write_masking_secret( - "pr-1", "hash", "salt", "encoded-secret", expire_seconds=600 - ) - assert dsr_store.get_masking_secret("pr-1", "hash", "salt") == "encoded-secret" - assert dsr_store.get_masking_secret("pr-1", "hash", "other") is None - # Legacy key migration - in_memory_manager.redis.set( - "id-pr-2-masking-secret-hash-pepper", "legacy-masking" - ) - assert ( - dsr_store.get_masking_secret("pr-2", "hash", "pepper") == "legacy-masking" + "hash", "salt", "encoded-secret", expire_seconds=600 ) - assert in_memory_manager.redis.get("id-pr-2-masking-secret-hash-pepper") is None + assert dsr_store.get_masking_secret("hash", "salt") == "encoded-secret" + assert dsr_store.get_masking_secret("hash", "other") is None + # Legacy key migration (different DSR) + mock_redis.set("id-pr-2-masking-secret-hash-pepper", "legacy-masking") + store2 = DSRCacheStore("pr-2", manager) + assert store2.get_masking_secret("hash", "pepper") == "legacy-masking" + assert mock_redis.get("id-pr-2-masking-secret-hash-pepper") is None diff --git a/tests/common/cache/test_dsr_store_clear_integration.py b/tests/common/cache/test_dsr_store_clear_integration.py new file mode 100644 index 00000000000..090eaeb9b20 --- /dev/null +++ b/tests/common/cache/test_dsr_store_clear_integration.py @@ -0,0 +1,106 @@ +""" +Tests for privacy_request.clear_cached_values() integration with DSRCacheStore. + +Verifies that clearing uses the store and handles both legacy and new cache keys. +""" + +import uuid +from unittest.mock import MagicMock, patch + +import pytest + +from fides.api.models.privacy_request.privacy_request import PrivacyRequest +from fides.common.cache.dsr_store import DSRCacheStore +from fides.common.cache.manager import RedisCacheManager +from tests.common.cache.mock_redis import create_mock_redis + +_TTL = 3600 # Test TTL + + +@pytest.mark.unit +class TestPrivacyRequestClearCachedValues: + """Test clear_cached_values() with DSR store.""" + + def test_clear_removes_legacy_keys(self): + """clear_cached_values removes legacy cache keys.""" + mock_redis = create_mock_redis() + pr_id = f"test-pr-{uuid.uuid4()}" + + # Simulate legacy cached data + mock_redis.set(f"id-{pr_id}-identity-email", "test@example.com") + mock_redis.set(f"id-{pr_id}-identity-phone_number", "+1234567890") + mock_redis.set(f"id-{pr_id}-encryption-key", "encryption-key") + + # Mock privacy request + pr = MagicMock() + pr.id = pr_id + + # Patch get_cache in the api.util.cache module where get_dsr_cache_store calls it + with patch("fides.api.util.cache.get_cache", return_value=mock_redis): + # Import here to avoid app initialization + PrivacyRequest.clear_cached_values(pr) + + # Verify all keys deleted + assert len(mock_redis.keys(f"*{pr_id}*")) == 0 + + def test_clear_removes_new_keys(self): + """clear_cached_values removes new-format cache keys.""" + mock_redis = create_mock_redis() + pr_id = f"test-pr-{uuid.uuid4()}" + + # Simulate new cached data via store + manager = RedisCacheManager(mock_redis) + store = DSRCacheStore(pr_id, manager) + store.write_identity("email", "test@example.com", _TTL) + store.write_encryption("key", "encryption-key", _TTL) + + pr = MagicMock() + pr.id = pr_id + + with patch("fides.api.util.cache.get_cache", return_value=mock_redis): + PrivacyRequest.clear_cached_values(pr) + + assert len(mock_redis.keys(f"*{pr_id}*")) == 0 + + def test_clear_removes_mixed_keys(self): + """clear_cached_values removes both legacy and new keys.""" + mock_redis = create_mock_redis() + pr_id = f"test-pr-{uuid.uuid4()}" + + # Mixed: legacy identity, new encryption + mock_redis.set(f"id-{pr_id}-identity-email", "legacy@example.com") + mock_redis.set(f"id-{pr_id}-custom-privacy-request-field-dept", "Engineering") + + manager = RedisCacheManager(mock_redis) + store = DSRCacheStore(pr_id, manager) + store.write_encryption("key", "new-encryption-key", _TTL) + store.write_async_execution("task-123", _TTL) + + pr = MagicMock() + pr.id = pr_id + + with patch("fides.api.util.cache.get_cache", return_value=mock_redis): + PrivacyRequest.clear_cached_values(pr) + + assert len(mock_redis.keys(f"*{pr_id}*")) == 0 + + def test_clear_removes_index(self): + """clear_cached_values removes the DSR index.""" + mock_redis = create_mock_redis() + pr_id = f"test-pr-{uuid.uuid4()}" + + manager = RedisCacheManager(mock_redis) + store = DSRCacheStore(pr_id, manager) + store.write_identity("email", "test@example.com", _TTL) + + # Verify index exists + assert len(mock_redis.smembers(f"__idx:dsr:{pr_id}")) > 0 + + pr = MagicMock() + pr.id = pr_id + + with patch("fides.api.util.cache.get_cache", return_value=mock_redis): + PrivacyRequest.clear_cached_values(pr) + + # Index should be deleted + assert len(mock_redis.smembers(f"__idx:dsr:{pr_id}")) == 0 diff --git a/tests/common/cache/test_dsr_store_custom_fields_integration.py b/tests/common/cache/test_dsr_store_custom_fields_integration.py new file mode 100644 index 00000000000..7eba5994125 --- /dev/null +++ b/tests/common/cache/test_dsr_store_custom_fields_integration.py @@ -0,0 +1,138 @@ +""" +Tests for custom fields and encryption cache operations in DSRCacheStore. + +Tests the service layer directly with MockRedis - no patching needed. +""" + +import json + +import pytest + +from fides.common.cache.dsr_store import DSRCacheStore + +# Mark all tests as unit tests +pytestmark = pytest.mark.unit + +_TTL = 3600 # Test TTL + + +class TestDSRCacheStoreCustomFields: + """Test custom fields cache operations in DSRCacheStore.""" + + def test_cache_custom_fields_writes_all_fields(self, manager, mock_redis, pr_id): + """cache_custom_fields writes all fields to new-format keys.""" + store = DSRCacheStore(pr_id, manager) + custom_fields = { + "department": json.dumps("Engineering"), + "employee_id": json.dumps("E12345"), + } + + store.cache_custom_fields(custom_fields, expire_seconds=3600) + + # All keys written in new format + assert mock_redis.get(f"dsr:{pr_id}:custom_field:department") == json.dumps( + "Engineering" + ) + assert mock_redis.get(f"dsr:{pr_id}:custom_field:employee_id") == json.dumps( + "E12345" + ) + + # Legacy keys do NOT exist + assert ( + mock_redis.get(f"id-{pr_id}-custom-privacy-request-field-department") + is None + ) + + def test_get_cached_custom_fields_reads_all_fields(self, manager, pr_id): + """get_cached_custom_fields reads all fields from new-format keys.""" + store = DSRCacheStore(pr_id, manager) + custom_fields = { + "department": json.dumps("Engineering"), + "employee_id": json.dumps("E12345"), + } + store.cache_custom_fields(custom_fields, _TTL) + + result = store.get_cached_custom_fields() + + assert result["department"] == json.dumps("Engineering") + assert result["employee_id"] == json.dumps("E12345") + + def test_get_cached_custom_fields_migrates_legacy_keys( + self, manager, mock_redis, pr_id + ): + """get_cached_custom_fields reads and migrates legacy keys on first access.""" + store = DSRCacheStore(pr_id, manager) + # Write legacy format + mock_redis.set( + f"id-{pr_id}-custom-privacy-request-field-department", + json.dumps("Engineering"), + ) + mock_redis.set( + f"id-{pr_id}-custom-privacy-request-field-employee_id", json.dumps("E12345") + ) + + result = store.get_cached_custom_fields() + + # Values are returned correctly + assert result["department"] == json.dumps("Engineering") + assert result["employee_id"] == json.dumps("E12345") + + # Legacy keys migrated to new format + assert mock_redis.get(f"dsr:{pr_id}:custom_field:department") is not None + assert ( + mock_redis.get(f"id-{pr_id}-custom-privacy-request-field-department") + is None + ) + + def test_has_cached_custom_fields_detects_both_formats( + self, manager, mock_redis, pr_id + ): + """has_cached_custom_fields detects fields in both legacy and new formats.""" + store = DSRCacheStore(pr_id, manager) + # Empty initially + assert store.has_cached_custom_fields() is False + + # Add legacy key + mock_redis.set( + f"id-{pr_id}-custom-privacy-request-field-department", + json.dumps("Engineering"), + ) + assert store.has_cached_custom_fields() is True + + # Clear and test new format + store.clear() + store.write_custom_field("department", json.dumps("Engineering"), _TTL) + assert store.has_cached_custom_fields() is True + + +class TestDSRCacheStoreEncryption: + """Test encryption key cache operations in DSRCacheStore.""" + + def test_write_encryption_writes_key(self, manager, mock_redis, pr_id): + """write_encryption writes encryption key to new-format key.""" + store = DSRCacheStore(pr_id, manager) + store.write_encryption("key", "test-encryption-key-12345", expire_seconds=3600) + + assert ( + mock_redis.get(f"dsr:{pr_id}:encryption:key") == "test-encryption-key-12345" + ) + + # Legacy key does NOT exist + assert mock_redis.get(f"id-{pr_id}-encryption-key") is None + + def test_get_encryption_migrates_legacy_key(self, manager, mock_redis, pr_id): + """get_encryption reads and migrates legacy encryption keys.""" + store = DSRCacheStore(pr_id, manager) + # Write legacy format + mock_redis.set(f"id-{pr_id}-encryption-key", "test-encryption-key-12345") + + # Read via store + value = store.get_encryption("key") + + assert value == "test-encryption-key-12345" + + # Legacy key migrated + assert ( + mock_redis.get(f"dsr:{pr_id}:encryption:key") == "test-encryption-key-12345" + ) + assert mock_redis.get(f"id-{pr_id}-encryption-key") is None diff --git a/tests/common/cache/test_dsr_store_drp_integration.py b/tests/common/cache/test_dsr_store_drp_integration.py new file mode 100644 index 00000000000..0544bdb5411 --- /dev/null +++ b/tests/common/cache/test_dsr_store_drp_integration.py @@ -0,0 +1,148 @@ +""" +Tests for DSRCacheStore DRP request body caching. + +Focuses on service-layer methods for DRP data management, including: +- Writing DRP fields in new format +- Reading DRP fields from both new and legacy formats +- Automatic migration on read +""" + +import pytest + +from fides.common.cache.dsr_store import DSRCacheStore + +# Mark all tests as unit tests +pytestmark = pytest.mark.unit + +_TTL = 3600 # Test TTL + + +class TestDSRCacheStoreDRP: + """Test DSRCacheStore DRP request body methods.""" + + def test_cache_drp_request_body_writes_all_fields(self, manager, pr_id): + """cache_drp_request_body writes all fields to new-format keys.""" + store = DSRCacheStore(pr_id, manager) + drp_body = { + "meta": "metadata_value", + "regime": "gdpr", + "exercise": "access", + "identity": '{"email": "user@example.com"}', + } + + store.cache_drp_request_body(drp_body, expire_seconds=3600) + + # Verify all fields written to new format + assert store.get_drp("meta") == "metadata_value" + assert store.get_drp("regime") == "gdpr" + assert store.get_drp("exercise") == "access" + assert store.get_drp("identity") == '{"email": "user@example.com"}' + + def test_cache_drp_request_body_skips_none_values(self, manager, pr_id): + """cache_drp_request_body skips None values.""" + store = DSRCacheStore(pr_id, manager) + drp_body = { + "meta": "metadata_value", + "regime": None, + "exercise": "access", + } + + store.cache_drp_request_body(drp_body, _TTL) + + # Only non-None fields should be written + assert store.get_drp("meta") == "metadata_value" + assert store.get_drp("regime") is None + assert store.get_drp("exercise") == "access" + + def test_get_cached_drp_request_body_reads_all_fields(self, manager, pr_id): + """get_cached_drp_request_body reads all fields from new-format keys.""" + store = DSRCacheStore(pr_id, manager) + drp_body = { + "meta": "metadata_value", + "regime": "gdpr", + "exercise": "access", + } + store.cache_drp_request_body(drp_body, _TTL) + + result = store.get_cached_drp_request_body() + + assert result == { + "meta": "metadata_value", + "regime": "gdpr", + "exercise": "access", + } + + def test_get_cached_drp_request_body_migrates_legacy_keys( + self, manager, mock_redis, pr_id + ): + """get_cached_drp_request_body reads and migrates legacy keys on first access.""" + store = DSRCacheStore(pr_id, manager) + # Write legacy format directly + mock_redis.set(f"id-{pr_id}-drp-meta", "legacy_metadata") + mock_redis.set(f"id-{pr_id}-drp-regime", "ccpa") + + result = store.get_cached_drp_request_body() + + assert result == { + "meta": "legacy_metadata", + "regime": "ccpa", + } + + # Verify migration happened (new keys exist, legacy keys deleted) + assert mock_redis.get(f"dsr:{pr_id}:drp:meta") == "legacy_metadata" + assert mock_redis.get(f"dsr:{pr_id}:drp:regime") == "ccpa" + assert mock_redis.get(f"id-{pr_id}-drp-meta") is None + assert mock_redis.get(f"id-{pr_id}-drp-regime") is None + + def test_has_cached_drp_request_body_detects_both_formats( + self, manager, mock_redis, pr_id + ): + """has_cached_drp_request_body detects DRP data in both legacy and new formats.""" + store = DSRCacheStore(pr_id, manager) + # Empty initially + assert store.has_cached_drp_request_body() is False + + # Write new format + store.write_drp("meta", "metadata", _TTL) + assert store.has_cached_drp_request_body() is True + + # Clear and test legacy format + store.clear() + assert store.has_cached_drp_request_body() is False + + mock_redis.set(f"id-{pr_id}-drp-regime", "gdpr") + assert store.has_cached_drp_request_body() is True + + def test_get_cached_drp_request_body_returns_empty_dict_when_no_data( + self, manager, pr_id + ): + """get_cached_drp_request_body returns empty dict when no DRP data cached.""" + store = DSRCacheStore(pr_id, manager) + result = store.get_cached_drp_request_body() + assert result == {} + + def test_drp_migration_then_new_writes(self, manager, mock_redis, pr_id): + """After migrating legacy keys, new writes use indexed format.""" + store = DSRCacheStore(pr_id, manager) + # Start with legacy keys + mock_redis.set(f"id-{pr_id}-drp-meta", "legacy_metadata") + + # Read triggers migration + result1 = store.get_cached_drp_request_body() + assert result1["meta"] == "legacy_metadata" + + # Now write new fields - should use indexed format + store.write_drp("regime", "gdpr", _TTL) + store.write_drp("exercise", "access", _TTL) + + # Read all - should get both migrated and new + result2 = store.get_cached_drp_request_body() + assert result2["meta"] == "legacy_metadata" + assert result2["regime"] == "gdpr" + assert result2["exercise"] == "access" + + # Verify all keys are now indexed + all_keys = store.get_all_keys() + assert f"dsr:{pr_id}:drp:meta" in all_keys + assert f"dsr:{pr_id}:drp:regime" in all_keys + assert f"dsr:{pr_id}:drp:exercise" in all_keys diff --git a/tests/common/cache/test_dsr_store_identity_integration.py b/tests/common/cache/test_dsr_store_identity_integration.py new file mode 100644 index 00000000000..12e549acee2 --- /dev/null +++ b/tests/common/cache/test_dsr_store_identity_integration.py @@ -0,0 +1,103 @@ +""" +Tests for identity cache operations in DSRCacheStore. + +Tests the service layer directly with MockRedis - no patching needed. +""" + +import json + +import pytest + +from fides.common.cache.dsr_store import DSRCacheStore + + +@pytest.fixture +def identity_data(): + """Sample identity data for tests.""" + return { + "email": "user@example.com", + "phone_number": "+1234567890", + } + + +# Mark all tests as unit tests +pytestmark = pytest.mark.unit + +_TTL = 3600 # Test TTL + + +class TestDSRCacheStoreIdentity: + """Test identity cache operations in DSRCacheStore.""" + + def test_cache_identity_data_writes_all_attributes( + self, manager, mock_redis, pr_id + ): + """cache_identity_data writes all identity attributes to new-format keys.""" + store = DSRCacheStore(pr_id, manager) + identity_data = { + "email": json.dumps("user@example.com"), + "phone_number": json.dumps("+1234567890"), + } + + store.cache_identity_data(identity_data, expire_seconds=3600) + + # All keys written in new format + assert mock_redis.get(f"dsr:{pr_id}:identity:email") == json.dumps( + "user@example.com" + ) + assert mock_redis.get(f"dsr:{pr_id}:identity:phone_number") == json.dumps( + "+1234567890" + ) + + # Legacy keys do NOT exist + assert mock_redis.get(f"id-{pr_id}-identity-email") is None + + def test_get_cached_identity_data_reads_all_attributes( + self, manager, pr_id, identity_data + ): + """get_cached_identity_data reads all identity attributes from new-format keys.""" + store = DSRCacheStore(pr_id, manager) + # Write via store + encoded_data = {k: json.dumps(v) for k, v in identity_data.items()} + store.cache_identity_data(encoded_data, _TTL) + + result = store.get_cached_identity_data() + + assert result["email"] == json.dumps("user@example.com") + assert result["phone_number"] == json.dumps("+1234567890") + + def test_get_cached_identity_data_migrates_legacy_keys( + self, manager, mock_redis, pr_id, identity_data + ): + """get_cached_identity_data reads and migrates legacy keys on first access.""" + store = DSRCacheStore(pr_id, manager) + # Write legacy format with JSON encoding + for key, value in identity_data.items(): + mock_redis.set(f"id-{pr_id}-identity-{key}", json.dumps(value)) + + result = store.get_cached_identity_data() + + # Values are returned correctly + assert result["email"] == json.dumps("user@example.com") + assert result["phone_number"] == json.dumps("+1234567890") + + # Legacy keys migrated to new format + assert mock_redis.get(f"dsr:{pr_id}:identity:email") is not None + assert mock_redis.get(f"id-{pr_id}-identity-email") is None + + def test_has_cached_identity_data_detects_both_formats( + self, manager, mock_redis, pr_id + ): + """has_cached_identity_data detects identity data in both legacy and new formats.""" + store = DSRCacheStore(pr_id, manager) + # Empty initially + assert store.has_cached_identity_data() is False + + # Add legacy key + mock_redis.set(f"id-{pr_id}-identity-email", json.dumps("test@example.com")) + assert store.has_cached_identity_data() is True + + # Clear and test new format + store.clear() + store.write_identity("email", json.dumps("test@example.com"), _TTL) + assert store.has_cached_identity_data() is True diff --git a/tests/common/cache/test_dsr_store_migration.py b/tests/common/cache/test_dsr_store_migration.py index 1a8daf8ea72..0088848c276 100644 --- a/tests/common/cache/test_dsr_store_migration.py +++ b/tests/common/cache/test_dsr_store_migration.py @@ -4,104 +4,14 @@ Verifies existing cached data (legacy format) is correctly read, migrated, and cleared. """ -import fnmatch import uuid -from typing import Any, Callable, Dict, List, Optional, Set, Union import pytest from fides.common.cache.dsr_store import DSRCacheStore from fides.common.cache.manager import RedisCacheManager -RedisValue = Union[bytes, float, int, str] - - -class MockPipeline: - """Minimal Redis pipeline: buffers ops and runs them on execute().""" - - def __init__(self, redis: "MockRedis") -> None: - self._redis = redis - self._ops: List[Callable[[], Any]] = [] - - def set( - self, key: str, value: RedisValue, ex: Optional[int] = None - ) -> "MockPipeline": - def op() -> bool: - return self._redis.set(key, value, ex=ex) - - self._ops.append(op) - return self - - def sadd(self, key: str, *members: Union[str, bytes]) -> "MockPipeline": - def op() -> int: - return self._redis.sadd(key, *members) - - self._ops.append(op) - return self - - def delete(self, *keys: str) -> "MockPipeline": - def op() -> int: - return self._redis.delete(*keys) - - self._ops.append(op) - return self - - def srem(self, key: str, *members: Union[str, bytes]) -> "MockPipeline": - def op() -> int: - return self._redis.srem(key, *members) - - self._ops.append(op) - return self - - def execute(self) -> List[Any]: - return [op() for op in self._ops] - - -class MockRedis: - """Mock Redis with minimal interface for DSRCacheStore.""" - - def __init__(self) -> None: - self._data: Dict[str, RedisValue] = {} - self._sets: Dict[str, Set[Union[str, bytes]]] = {} - - def get(self, key: str) -> Optional[Union[str, bytes]]: - val = self._data.get(key) - return val if isinstance(val, (str, bytes)) else str(val) if val else None - - def set(self, key: str, value: RedisValue, ex: Optional[int] = None) -> bool: - self._data[key] = value - return True - - def delete(self, *keys: str) -> int: - deleted = sum( - 1 for k in keys if self._data.pop(k, None) or self._sets.pop(k, None) - ) - return deleted - - def keys(self, pattern: str) -> List[str]: - return [k for k in self._data if fnmatch.fnmatch(k, pattern)] - - def scan_iter(self, match: str = "*", count: Optional[int] = None): - return iter(self.keys(match)) - - def sadd(self, key: str, *members: Union[str, bytes]) -> int: - s = self._sets.setdefault(key, set()) - before = len(s) - s.update(members) - return len(s) - before - - def srem(self, key: str, *members: Union[str, bytes]) -> int: - if key not in self._sets: - return 0 - before = len(self._sets[key]) - self._sets[key].difference_update(members) - return before - len(self._sets[key]) - - def smembers(self, key: str) -> Set[Union[str, bytes]]: - return self._sets.get(key, set()).copy() - - def pipeline(self) -> MockPipeline: - return MockPipeline(self) +_TTL = 3600 # Test TTL # Test data factories @@ -122,21 +32,6 @@ def make_new_key(dsr_id: str, part: str) -> str: return f"dsr:{dsr_id}:{part}" -@pytest.fixture -def mock_redis(): - return MockRedis() - - -@pytest.fixture -def dsr_store(mock_redis): - return DSRCacheStore(RedisCacheManager(mock_redis)) - - -@pytest.fixture -def dsr_id(): - return make_dsr_id() - - @pytest.mark.unit class TestLegacyKeyMigration: """Test legacy key formats are readable and migrated correctly.""" @@ -154,27 +49,29 @@ class TestLegacyKeyMigration: ], ) def test_legacy_keys_readable( - self, mock_redis, dsr_store, dsr_id, field_type, getter, field_key, value + self, mock_redis, manager, dsr_id, field_type, getter, field_key, value ): """Legacy keys are readable via store convenience methods.""" + store = DSRCacheStore(dsr_id, manager) legacy_key = make_legacy_key(dsr_id, field_type, field_key) mock_redis.set(legacy_key, value) # Call appropriate getter if getter == "get_masking_secret": - result = dsr_store.get_masking_secret(dsr_id, "hash", field_key) + result = store.get_masking_secret("hash", field_key) elif field_key: - result = getattr(dsr_store, getter)(dsr_id, field_key) + result = getattr(store, getter)(field_key) else: - result = getattr(dsr_store, getter)(dsr_id) + result = getattr(store, getter)() assert result == value - def test_legacy_key_migrated_on_read(self, mock_redis, dsr_store, dsr_id): + def test_legacy_key_migrated_on_read(self, mock_redis, manager, dsr_id): """Legacy key is migrated to new format on first read.""" + store = DSRCacheStore(dsr_id, manager) mock_redis.set(make_legacy_key(dsr_id, "identity", "email"), "migrate@test.com") - email = dsr_store.get_identity(dsr_id, "email") + email = store.get_identity("email") assert email == "migrate@test.com" # New key exists, legacy deleted, index updated @@ -186,10 +83,11 @@ def test_legacy_key_migrated_on_read(self, mock_redis, dsr_store, dsr_id): f"__idx:dsr:{dsr_id}" ) - def test_new_writes_create_indexed_keys_only(self, mock_redis, dsr_store, dsr_id): + def test_new_writes_create_indexed_keys_only(self, mock_redis, manager, dsr_id): """New writes create new-format keys and index them; no legacy keys written.""" - dsr_store.write_identity(dsr_id, "email", "new@example.com") - dsr_store.write_custom_field(dsr_id, "department", "Sales") + store = DSRCacheStore(dsr_id, manager) + store.write_identity("email", "new@example.com", _TTL) + store.write_custom_field("department", "Sales", _TTL) assert ( mock_redis.get(make_new_key(dsr_id, "identity:email")) == "new@example.com" @@ -205,14 +103,15 @@ def test_new_writes_create_indexed_keys_only(self, mock_redis, dsr_store, dsr_id is None ) - def test_clear_removes_mixed_keys(self, mock_redis, dsr_store, dsr_id): + def test_clear_removes_mixed_keys(self, mock_redis, manager, dsr_id): """clear() removes both legacy and new keys using SCAN.""" + store = DSRCacheStore(dsr_id, manager) mock_redis.set(make_legacy_key(dsr_id, "identity", "email"), "legacy@test.com") mock_redis.set(make_legacy_key(dsr_id, "encryption", "key"), "legacy-key") - dsr_store.write_identity(dsr_id, "phone_number", "+1234567890") - dsr_store.write_custom_field(dsr_id, "department", "Engineering") + store.write_identity("phone_number", "+1234567890", _TTL) + store.write_custom_field("department", "Engineering", _TTL) - dsr_store.clear(dsr_id) + store.clear() assert len(mock_redis.keys(f"*{dsr_id}*")) == 0 @@ -224,9 +123,11 @@ def test_index_backfill(self, mock_redis, dsr_id): ) store = DSRCacheStore( - RedisCacheManager(mock_redis), backfill_index_on_legacy_read=True + dsr_id, + RedisCacheManager(mock_redis), + backfill_index_on_legacy_read=True, ) - keys = store.get_all_keys(dsr_id) + keys = store.get_all_keys() assert len(keys) == 2 assert len(mock_redis.smembers(f"__idx:dsr:{dsr_id}")) == 2 @@ -239,36 +140,41 @@ class TestMultipleRequestIsolation: def test_mixed_dsr_states(self, mock_redis): """Operations on one DSR don't affect others (legacy, new, mixed).""" dsr1, dsr2, dsr3 = make_dsr_id(), make_dsr_id(), make_dsr_id() - store = DSRCacheStore(RedisCacheManager(mock_redis)) + mgr = RedisCacheManager(mock_redis) + store1 = DSRCacheStore(dsr1, mgr) + store2 = DSRCacheStore(dsr2, mgr) + store3 = DSRCacheStore(dsr3, mgr) # DSR1: legacy, DSR2: new, DSR3: mixed mock_redis.set(make_legacy_key(dsr1, "identity", "email"), "dsr1@test.com") - store.write_identity(dsr2, "email", "dsr2@test.com") + store2.write_identity("email", "dsr2@test.com", _TTL) mock_redis.set(make_legacy_key(dsr3, "identity", "email"), "dsr3@test.com") - store.write_identity(dsr3, "phone_number", "+1234567890") + store3.write_identity("phone_number", "+1234567890", _TTL) # Verify all readable - assert store.get_identity(dsr1, "email") == "dsr1@test.com" - assert store.get_identity(dsr2, "email") == "dsr2@test.com" - assert store.get_identity(dsr3, "email") == "dsr3@test.com" - assert store.get_identity(dsr3, "phone_number") == "+1234567890" + assert store1.get_identity("email") == "dsr1@test.com" + assert store2.get_identity("email") == "dsr2@test.com" + assert store3.get_identity("email") == "dsr3@test.com" + assert store3.get_identity("phone_number") == "+1234567890" # Clear DSR2 doesn't affect others - store.clear(dsr2) - assert store.get_identity(dsr1, "email") == "dsr1@test.com" - assert store.get_identity(dsr3, "email") == "dsr3@test.com" - assert store.get_identity(dsr2, "email") is None - assert store.get_all_keys(dsr2) == [] + store2.clear() + assert store1.get_identity("email") == "dsr1@test.com" + assert store3.get_identity("email") == "dsr3@test.com" + assert store2.get_identity("email") is None + assert store2.get_all_keys() == [] def test_clear_isolation(self, mock_redis): """Clearing one DSR doesn't delete another's keys.""" dsr1, dsr2 = make_dsr_id(), make_dsr_id() - store = DSRCacheStore(RedisCacheManager(mock_redis)) + mgr = RedisCacheManager(mock_redis) + store1 = DSRCacheStore(dsr1, mgr) + store2 = DSRCacheStore(dsr2, mgr) - store.write_identity(dsr1, "email", "dsr1@test.com") - store.write_identity(dsr2, "email", "dsr2@test.com") + store1.write_identity("email", "dsr1@test.com", _TTL) + store2.write_identity("email", "dsr2@test.com", _TTL) - store.clear(dsr1) + store1.clear() assert mock_redis.get(make_new_key(dsr1, "identity:email")) is None assert mock_redis.get(make_new_key(dsr2, "identity:email")) == "dsr2@test.com" diff --git a/tests/common/cache/test_dsr_store_production_compatibility.py b/tests/common/cache/test_dsr_store_production_compatibility.py new file mode 100644 index 00000000000..504c72b3d50 --- /dev/null +++ b/tests/common/cache/test_dsr_store_production_compatibility.py @@ -0,0 +1,219 @@ +""" +Production compatibility tests for DSR cache migration. + +These tests simulate a production deployment scenario where: +1. Old code has cached DSR data using legacy key formats (id-{id}-*) +2. New code is deployed and must read/process those in-flight requests +3. New code continues to work correctly with legacy keys + +This validates that the migration won't break production requests that are +already in-flight when the new code is deployed. +""" + +import json +import uuid +from unittest.mock import MagicMock, patch + +import pytest + +from fides.api.models.privacy_request import PrivacyRequest +from fides.api.tasks.encryption_utils import encrypt_access_request_results +from fides.api.util.cache import ( + get_cache, + get_custom_privacy_request_field_cache_key, + get_drp_request_body_cache_key, + get_dsr_cache_store, + get_encryption_cache_key, + get_identity_cache_key, +) + +_TTL = 3600 # Test TTL + + +@pytest.mark.unit +class TestInFlightDSRLifecycle: + """ + Simulate a full in-flight DSR that was cached by old code, then processed + and cleared by new code after a deployment. This is the "Steps to Confirm" + scenario from the PR: volume of in-flight DSR processing, then upgrading + in the middle of it. + """ + + def test_full_lifecycle_legacy_request_processed_by_new_code(self, mock_redis): + """ + End-to-end: old code caches a complete DSR (identity, encryption, + custom fields, DRP body). New code reads everything, processes the + request, and clears the cache. + """ + pr_id = f"pri_{uuid.uuid4()}" + + # --- Phase 1: "Old code" caches a full DSR using legacy key format --- + with ( + patch("fides.api.util.cache.get_cache", return_value=mock_redis), + patch("fides.api.util.cache._connection", mock_redis), + ): + cache = get_cache() + + # Identity + cache.set_with_autoexpire( + get_identity_cache_key(pr_id, "email"), json.dumps("user@example.com") + ) + cache.set_with_autoexpire( + get_identity_cache_key(pr_id, "phone_number"), + json.dumps("+1234567890"), + ) + + # Encryption key + cache.set_with_autoexpire( + get_encryption_cache_key(pr_id, "key"), "0123456789abcdef" + ) + + # Custom fields + cache.set_with_autoexpire( + get_custom_privacy_request_field_cache_key(pr_id, "department"), + json.dumps("Engineering"), + ) + cache.set_with_autoexpire( + get_custom_privacy_request_field_cache_key(pr_id, "tenant_id"), + json.dumps("tenant-42"), + ) + + # DRP body + cache.set_with_autoexpire( + get_drp_request_body_cache_key(pr_id, "meta"), "DrpMeta(version='0.5')" + ) + cache.set_with_autoexpire( + get_drp_request_body_cache_key(pr_id, "regime"), "ccpa" + ) + + # Verify legacy keys exist before "deployment" + legacy_keys = [k for k in mock_redis.keys("*") if pr_id in k] + assert len(legacy_keys) == 7 # 2 identity + 1 encryption + 2 custom + 2 DRP + + # --- Phase 2: "New code deployed" — read everything via PrivacyRequest --- + with ( + patch("fides.api.util.cache.get_cache", return_value=mock_redis), + patch("fides.api.util.cache._connection", mock_redis), + ): + pr = MagicMock() + pr.id = pr_id + + # Read identity (triggers migration) + identity_data = PrivacyRequest.get_cached_identity_data(pr) + assert identity_data["email"] == "user@example.com" + assert identity_data["phone_number"] == "+1234567890" + + # Read encryption key (triggers migration) + encryption_key = PrivacyRequest.get_cached_encryption_key(pr) + assert encryption_key == "0123456789abcdef" + + # Encrypt data using the cached key + encrypted = encrypt_access_request_results("sensitive data", pr_id) + assert encrypted != "sensitive data" # Actually encrypted + + # Read custom fields (triggers migration) + store = get_dsr_cache_store(pr_id) + custom_fields = store.get_cached_custom_fields() + assert custom_fields["department"] == json.dumps("Engineering") + assert custom_fields["tenant_id"] == json.dumps("tenant-42") + + # Read DRP body (triggers migration) + drp_body = store.get_cached_drp_request_body() + assert drp_body["meta"] == "DrpMeta(version='0.5')" + assert drp_body["regime"] == "ccpa" + + # --- Phase 3: Verify migration happened --- + # All legacy keys should be gone + remaining_legacy = [ + k for k in mock_redis.keys("*") if k.startswith(f"id-{pr_id}") + ] + assert remaining_legacy == [], ( + f"Legacy keys not migrated: {remaining_legacy}" + ) + + # New-format keys should exist + assert store.get_identity("email") == json.dumps("user@example.com") + assert store.get_encryption("key") == "0123456789abcdef" + + # --- Phase 4: "Request complete" — clear the cache --- + store.clear() + + # Everything gone + all_keys = [k for k in mock_redis.keys("*") if pr_id in k] + assert all_keys == [], f"Keys remain after clear: {all_keys}" + + def test_multiple_in_flight_requests_mixed_formats(self, mock_redis): + """ + Simulate 3 concurrent requests: one fully legacy, one fully new, + one partially migrated. All should be independently readable and + clearable after "deployment". + """ + legacy_id = f"pri_{uuid.uuid4()}" + new_id = f"pri_{uuid.uuid4()}" + mixed_id = f"pri_{uuid.uuid4()}" + + with ( + patch("fides.api.util.cache.get_cache", return_value=mock_redis), + patch("fides.api.util.cache._connection", mock_redis), + ): + cache = get_cache() + + # Request 1: fully legacy + cache.set_with_autoexpire( + get_identity_cache_key(legacy_id, "email"), + json.dumps("legacy@example.com"), + ) + cache.set_with_autoexpire( + get_encryption_cache_key(legacy_id, "key"), "legacy-key-1234567" + ) + + # Request 2: fully new format + store_new = get_dsr_cache_store(new_id) + store_new.write_identity("email", json.dumps("new@example.com"), _TTL) + store_new.write_encryption("key", "new-key-123456789", _TTL) + + # Request 3: mixed (legacy identity, new encryption) + cache.set_with_autoexpire( + get_identity_cache_key(mixed_id, "email"), + json.dumps("mixed@example.com"), + ) + store_mixed = get_dsr_cache_store(mixed_id) + store_mixed.write_encryption("key", "mixed-key-12345678", _TTL) + + # --- "New code deployed" — read all three --- + with ( + patch("fides.api.util.cache.get_cache", return_value=mock_redis), + patch("fides.api.util.cache._connection", mock_redis), + ): + for pr_id, expected_email, expected_key in [ + (legacy_id, "legacy@example.com", "legacy-key-1234567"), + (new_id, "new@example.com", "new-key-123456789"), + (mixed_id, "mixed@example.com", "mixed-key-12345678"), + ]: + pr = MagicMock() + pr.id = pr_id + identity = PrivacyRequest.get_cached_identity_data(pr) + assert identity["email"] == expected_email, f"Failed for {pr_id}" + enc_key = PrivacyRequest.get_cached_encryption_key(pr) + assert enc_key == expected_key, f"Failed for {pr_id}" + + # Clear one, others unaffected + store_legacy = get_dsr_cache_store(legacy_id) + store_legacy.clear() + + pr_new = MagicMock() + pr_new.id = new_id + assert ( + PrivacyRequest.get_cached_identity_data(pr_new)["email"] + == "new@example.com" + ) + + pr_mixed = MagicMock() + pr_mixed.id = mixed_id + assert ( + PrivacyRequest.get_cached_identity_data(pr_mixed)["email"] + == "mixed@example.com" + ) + + # Legacy request fully cleared + assert store_legacy.get_all_keys() == [] diff --git a/tests/common/cache/test_manager.py b/tests/common/cache/test_manager.py index f90db5e215a..fee4e95da10 100644 --- a/tests/common/cache/test_manager.py +++ b/tests/common/cache/test_manager.py @@ -1,139 +1,7 @@ -import fnmatch - import pytest from fides.common.cache.manager import INDEX_TTL_EXTRA_SECONDS, RedisCacheManager - -class MockPipeline: - """In-memory pipeline that batches commands and executes atomically.""" - - def __init__(self, data: dict, sets: dict) -> None: - self._data = data - self._sets = sets - self._commands: list = [] - - def set(self, key: str, value, ex=None) -> "MockPipeline": - self._commands.append(("set", (key, value, ex))) - return self - - def sadd(self, key: str, member: str) -> "MockPipeline": - self._commands.append(("sadd", (key, member))) - return self - - def delete(self, *keys: str) -> "MockPipeline": - self._commands.append(("delete", keys)) - return self - - def srem(self, key: str, member: str) -> "MockPipeline": - self._commands.append(("srem", (key, member))) - return self - - def execute(self) -> list: - results = [] - for op, args in self._commands: - if op == "set": - key, value, _ = args - self._data[key] = value - results.append(True) - elif op == "sadd": - key, member = args - if key not in self._sets: - self._sets[key] = set() - self._sets[key].add(member) - results.append(1) - elif op == "delete": - for k in args: - self._data.pop(k, None) - self._sets.pop(k, None) - results.append(len(args)) - elif op == "srem": - key, member = args - if key in self._sets: - self._sets[key].discard(member) - if not self._sets[key]: - del self._sets[key] - results.append(1) - self._commands = [] - return results - - -class MockRedis: - """In-memory Redis mock for RedisCacheManager tests.""" - - def __init__(self) -> None: - self._data: dict = {} - self._sets: dict = {} - self._ttl: dict = {} # key -> seconds until expiry (simplified; no decay) - - def get(self, key: str): - return self._data.get(key) - - def set(self, key: str, value, ex=None) -> bool: - self._data[key] = value - return True - - def delete(self, *keys: str) -> int: - count = 0 - for k in keys: - if k in self._data: - del self._data[k] - count += 1 - if k in self._sets: - del self._sets[k] - count += 1 - self._ttl.pop(k, None) - return count - - def sadd(self, key: str, member: str) -> int: - if key not in self._sets: - self._sets[key] = set() - self._sets[key].add(member) - return 1 - - def srem(self, key: str, member: str) -> int: - if key in self._sets: - self._sets[key].discard(member) - if not self._sets[key]: - del self._sets[key] - return 1 - return 0 - - def smembers(self, key: str) -> set: - return self._sets.get(key, set()).copy() - - def keys(self, pattern: str = "*") -> list: - all_keys = set(self._data) | set(self._sets) - return [k for k in all_keys if fnmatch.fnmatch(k, pattern)] - - def ttl(self, key: str) -> int: - if key not in self._data and key not in self._sets: - return -2 - return self._ttl.get(key, -1) - - def expire(self, key: str, seconds: int) -> bool: - if key in self._data or key in self._sets: - self._ttl[key] = seconds - return True - return False - - def pipeline(self) -> MockPipeline: - return MockPipeline(self._data, self._sets) - - -# --- Fixtures --- - - -@pytest.fixture -def mock_redis() -> MockRedis: - return MockRedis() - - -@pytest.fixture -def manager(mock_redis: MockRedis) -> RedisCacheManager: - return RedisCacheManager(mock_redis) - - # --- Tests --- @@ -142,7 +10,7 @@ class TestRedisCacheManagerPipeline: """Tests for RedisCacheManager pipeline-based index operations.""" def test_set_with_index_uses_pipeline_and_returns_set_result( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """set_with_index stores key, adds to index, and returns SET result.""" result = manager.set_with_index("k1", "v1", "idx1") @@ -152,7 +20,7 @@ def test_set_with_index_uses_pipeline_and_returns_set_result( assert "k1" in mock_redis.smembers("__idx:idx1") def test_set_with_index_with_expiry( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """set_with_index with expire_seconds stores value and adds to index.""" result = manager.set_with_index("k2", "v2", "idx2", expire_seconds=60) @@ -162,7 +30,7 @@ def test_set_with_index_with_expiry( assert "k2" in mock_redis.smembers("__idx:idx2") def test_delete_key_and_remove_from_index_atomic( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """delete_key_and_remove_from_index removes key and index entry atomically.""" manager.set_with_index("k3", "v3", "idx3") @@ -175,7 +43,7 @@ def test_delete_key_and_remove_from_index_atomic( assert "k3" not in mock_redis.smembers("__idx:idx3") def test_delete_keys_by_index_batches_deletes( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """delete_keys_by_index removes all indexed keys and the index in one pipeline.""" manager.set_with_index("k4a", "v4a", "idx4") @@ -190,7 +58,7 @@ def test_delete_keys_by_index_batches_deletes( assert mock_redis.smembers("__idx:idx4") == set() def test_delete_keys_by_index_empty_index( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """delete_keys_by_index on empty index deletes index set without error.""" manager.delete_keys_by_index("idx5") @@ -203,7 +71,7 @@ class TestRedisCacheManagerIndexOperations: """Tests for add_key_to_index, remove_key_from_index, get_keys_by_index, delete_index.""" def test_add_key_to_index_registers_key( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Ensure that add_key_to_index adds the key and creates the index set if it doesn't exist.""" manager.add_key_to_index("myidx", "cache_key_1") @@ -211,7 +79,7 @@ def test_add_key_to_index_registers_key( assert "cache_key_1" in mock_redis.smembers("__idx:myidx") def test_add_key_to_index_multiple_keys( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Ensure that add_key_to_index can add multiple keys to the same index.""" manager.add_key_to_index("idx6", "key_a") @@ -222,7 +90,7 @@ def test_add_key_to_index_multiple_keys( assert members == {"key_a", "key_b", "key_c"} def test_remove_key_from_index_idempotent( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Ensure that remove_key_from_index is idempotent and does not error when the specified key is not in the index.""" manager.set_with_index("key_a", "value_a", "idx6") @@ -243,7 +111,7 @@ def test_remove_key_from_index_idempotent( assert mock_redis.get("key_b") == "value_b" def test_remove_key_from_index_unregisters_key( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Ensure that remove_key_from_index removes a key from the index and does not remove other keys.""" manager.add_key_to_index("idx7", "keep") @@ -254,7 +122,7 @@ def test_remove_key_from_index_unregisters_key( assert mock_redis.smembers("__idx:idx7") == {"keep"} def test_remove_key_from_index_does_not_error_when_missing( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Ensure that remove_key_from_index does not error when the specified key is not in the index, and does not remove other keys.""" manager.add_key_to_index("idx8", "existing") @@ -264,7 +132,7 @@ def test_remove_key_from_index_does_not_error_when_missing( assert mock_redis.smembers("__idx:idx8") == {"existing"} def test_get_keys_by_index_returns_empty_for_missing_index( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Ensure that get_keys_by_index returns an empty list when the specified index does not exist.""" keys = manager.get_keys_by_index("never_used") @@ -272,7 +140,7 @@ def test_get_keys_by_index_returns_empty_for_missing_index( assert keys == [] def test_get_keys_by_index_returns_registered_keys( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Ensure get_keys_by_index returns all the keys in the index.""" manager.add_key_to_index("idx9", "k1") @@ -284,7 +152,7 @@ def test_get_keys_by_index_returns_registered_keys( assert len(keys) == 2 def test_delete_index_removes_index_set_only( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Ensure that delete_index removes the index set but NOT the data keys that are still in the cache.""" mock_redis.set("data_key_1", "value1") @@ -296,7 +164,7 @@ def test_delete_index_removes_index_set_only( assert mock_redis.get("data_key_1") == "value1" def test_delete_index_does_not_error_when_empty( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Ensure that delete_index does not error when the specified index does not exist.""" manager.delete_index("nonexistent_idx") @@ -307,7 +175,7 @@ class TestRedisCacheManagerIndexTTL: """Tests for optional index TTL (index_ttl_enabled).""" def test_index_ttl_disabled_by_default( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Without index_ttl_enabled, index has no TTL.""" manager.set_with_index("k", "v", "idx", expire_seconds=60) @@ -315,7 +183,7 @@ def test_index_ttl_disabled_by_default( assert mock_redis.ttl("__idx:idx") == -1 def test_index_ttl_applied_when_enabled( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """With index_ttl_enabled, index gets TTL matching key.""" manager.set_with_index( @@ -325,7 +193,7 @@ def test_index_ttl_applied_when_enabled( assert mock_redis.ttl("__idx:idx") == 120 + INDEX_TTL_EXTRA_SECONDS def test_index_ttl_extended_when_key_ttl_farther_out( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Index TTL is pushed out when adding key with longer TTL.""" manager.set_with_index( @@ -340,7 +208,7 @@ def test_index_ttl_extended_when_key_ttl_farther_out( assert mock_redis.ttl("__idx:idx") == 300 + INDEX_TTL_EXTRA_SECONDS def test_index_ttl_not_shortened_when_key_ttl_shorter( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """Index TTL is NOT shortened when adding key with shorter TTL.""" manager.set_with_index( @@ -355,7 +223,7 @@ def test_index_ttl_not_shortened_when_key_ttl_shorter( assert mock_redis.ttl("__idx:idx") == 300 + INDEX_TTL_EXTRA_SECONDS def test_index_ttl_ignored_when_no_expire_seconds( - self, manager: RedisCacheManager, mock_redis: MockRedis + self, manager: RedisCacheManager, mock_redis ) -> None: """index_ttl_enabled has no effect when expire_seconds is not set.""" manager.set_with_index("k", "v", "idx", index_ttl_enabled=True) diff --git a/tests/ops/api/v1/endpoints/privacy_request/test_privacy_request_endpoints.py b/tests/ops/api/v1/endpoints/privacy_request/test_privacy_request_endpoints.py index e06d2943232..748439d170d 100644 --- a/tests/ops/api/v1/endpoints/privacy_request/test_privacy_request_endpoints.py +++ b/tests/ops/api/v1/endpoints/privacy_request/test_privacy_request_endpoints.py @@ -71,7 +71,6 @@ EMBEDDED_EXECUTION_LOG_LIMIT, ) from fides.api.tasks import DSR_QUEUE_NAME, MESSAGING_QUEUE_NAME -from fides.api.util.cache import get_encryption_cache_key from fides.api.util.data_category import get_user_data_categories from fides.api.util.encryption.secrets_util import SecretsUtil from fides.api.util.fuzzy_search_utils import ( @@ -788,7 +787,6 @@ def test_create_privacy_request_caches_encryption_keys( db, api_client: TestClient, policy, - cache, ): identity = {"email": "test@example.com"} data = [ @@ -804,11 +802,7 @@ def test_create_privacy_request_caches_encryption_keys( response_data = resp.json()["succeeded"] assert len(response_data) == 1 pr = PrivacyRequest.get(db=db, object_id=response_data[0]["id"]) - encryption_key = get_encryption_cache_key( - privacy_request_id=pr.id, - encryption_attr="key", - ) - assert cache.get(encryption_key) == "test--encryption" + assert pr.get_cached_encryption_key() == "test--encryption" pr.delete(db=db) assert run_access_request_mock.called @@ -8179,7 +8173,6 @@ def test_create_privacy_request_caches_encryption_keys( generate_auth_header, api_client: TestClient, policy, - cache, ): identity = {"email": "test@example.com"} data = [ @@ -8196,11 +8189,7 @@ def test_create_privacy_request_caches_encryption_keys( response_data = resp.json()["succeeded"] assert len(response_data) == 1 pr = PrivacyRequest.get(db=db, object_id=response_data[0]["id"]) - encryption_key = get_encryption_cache_key( - privacy_request_id=pr.id, - encryption_attr="key", - ) - assert cache.get(encryption_key) == "test--encryption" + assert pr.get_cached_encryption_key() == "test--encryption" assert run_access_request_mock.called def test_create_privacy_request_no_identities( diff --git a/tests/ops/api/v1/endpoints/test_drp_endpoints.py b/tests/ops/api/v1/endpoints/test_drp_endpoints.py index 7c6a15f37e6..f4ba2d7c78a 100644 --- a/tests/ops/api/v1/endpoints/test_drp_endpoints.py +++ b/tests/ops/api/v1/endpoints/test_drp_endpoints.py @@ -18,7 +18,7 @@ PrivacyRequestDRPStatus, PrivacyRequestStatus, ) -from fides.api.util.cache import cache_task_tracking_key, get_drp_request_body_cache_key +from fides.api.util.cache import cache_task_tracking_key, get_dsr_cache_store from fides.common.scope_registry import ( POLICY_READ, PRIVACY_REQUEST_READ, @@ -50,7 +50,6 @@ def test_create_drp_privacy_request( db, api_client: TestClient, policy_drp_action, - cache, ): TEST_EMAIL = "test@example.com" TEST_PHONE_NUMBER = "+12345678910" @@ -76,26 +75,15 @@ def test_create_drp_privacy_request( pr = PrivacyRequest.get(db=db, object_id=response_data["request_id"]) # test appropriate data is cached - meta_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="meta", - ) - assert cache.get(meta_key) == "DrpMeta(version='0.5')" - regime_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="regime", - ) - assert cache.get(regime_key) == "ccpa" - exercise_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="exercise", - ) - assert cache.get(exercise_key) == "['access']" - identity_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="identity", - ) - assert cache.get(identity_key) == encoded_identity + store = get_dsr_cache_store(pr.id) + meta_value = store.get_drp("meta") + assert meta_value == "DrpMeta(version='0.5')" + regime_value = store.get_drp("regime") + assert regime_value == "ccpa" + exercise_value = store.get_drp("exercise") + assert exercise_value == "['access']" + identity_value = store.get_drp("identity") + assert identity_value == encoded_identity assert pr.get_cached_identity_data()["email"] == identity["email"] persisted_identity = pr.get_persisted_identity() @@ -115,7 +103,6 @@ def test_create_drp_privacy_request_unsupported_identity_props( db, api_client: TestClient, policy_drp_action, - cache, ): identity = {"email": "test@example.com", "address": "something"} encoded_identity: str = jwt.encode( @@ -136,26 +123,15 @@ def test_create_drp_privacy_request_unsupported_identity_props( pr = PrivacyRequest.get(db=db, object_id=response_data["request_id"]) # test appropriate data is cached - meta_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="meta", - ) - assert cache.get(meta_key) == "DrpMeta(version='0.5')" - regime_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="regime", - ) - assert cache.get(regime_key) == "ccpa" - exercise_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="exercise", - ) - assert cache.get(exercise_key) == "['access']" - identity_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="identity", - ) - assert cache.get(identity_key) == encoded_identity + store = get_dsr_cache_store(pr.id) + meta_value = store.get_drp("meta") + assert meta_value == "DrpMeta(version='0.5')" + regime_value = store.get_drp("regime") + assert regime_value == "ccpa" + exercise_value = store.get_drp("exercise") + assert exercise_value == "['access']" + identity_value = store.get_drp("identity") + assert identity_value == encoded_identity assert pr.get_cached_identity_data()["email"] == identity["email"] assert "address" not in pr.get_cached_identity_data().keys() @@ -305,7 +281,6 @@ def test_create_drp_privacy_request_error_notification( url, db, api_client: TestClient, - cache, policy_drp_action, ): TEST_EMAIL = "test@example.com" @@ -357,26 +332,15 @@ def test_create_drp_privacy_request_error_notification( pr = PrivacyRequest.get(db=db, object_id=response_data["request_id"]) # test appropriate data is cached - meta_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="meta", - ) - assert cache.get(meta_key) == "DrpMeta(version='0.5')" - regime_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="regime", - ) - assert cache.get(regime_key) == "ccpa" - exercise_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="exercise", - ) - assert cache.get(exercise_key) == "['access']" - identity_key = get_drp_request_body_cache_key( - privacy_request_id=pr.id, - identity_attribute="identity", - ) - assert cache.get(identity_key) == encoded_identity + store = get_dsr_cache_store(pr.id) + meta_value = store.get_drp("meta") + assert meta_value == "DrpMeta(version='0.5')" + regime_value = store.get_drp("regime") + assert regime_value == "ccpa" + exercise_value = store.get_drp("exercise") + assert exercise_value == "['access']" + identity_value = store.get_drp("identity") + assert identity_value == encoded_identity assert pr.get_cached_identity_data()["email"] == identity["email"] persisted_identity = pr.get_persisted_identity() diff --git a/tests/ops/models/privacy_request/test_privacy_request.py b/tests/ops/models/privacy_request/test_privacy_request.py index 6891d6aee16..801b5525745 100644 --- a/tests/ops/models/privacy_request/test_privacy_request.py +++ b/tests/ops/models/privacy_request/test_privacy_request.py @@ -47,6 +47,7 @@ FidesopsRedis, get_cache, get_custom_privacy_request_field_cache_key, + get_dsr_cache_store, get_identity_cache_key, ) from fides.api.util.constants import API_DATE_FORMAT @@ -257,34 +258,29 @@ def test_delete_privacy_request_removes_cached_data( privacy_request.delete(db) from_db = PrivacyRequest.get(db=db, object_id=privacy_request.id) assert from_db is None - assert cache.get(key) is None + # privacy_request.delete() calls clear_cached_values(), so cache is already cleared def test_cache_identity_fallback_to_db( db: Session, privacy_request_with_email_identity: PrivacyRequest, - cache: FidesopsRedis, loguru_caplog, ) -> None: identity = privacy_request_with_email_identity.get_persisted_identity() privacy_request_with_email_identity.cache_identity(identity) - key = get_identity_cache_key( - privacy_request_id=privacy_request_with_email_identity.id, - identity_attribute="email", - ) cached_identity_data = ( privacy_request_with_email_identity.get_cached_identity_data() ) assert cached_identity_data != {} - cache.delete(key) - assert cache.get(key) is None + store = get_dsr_cache_store(privacy_request_with_email_identity.id) + store.delete("identity:email") assert ( privacy_request_with_email_identity.get_cached_identity_data() == cached_identity_data ) assert ( f"Cache miss for request {privacy_request_with_email_identity.id}, falling back to DB" - in loguru_caplog.messages[-1] + in loguru_caplog.text ) @@ -312,7 +308,7 @@ def test_cache_identity_fallback_to_db_no_persisted_identity( assert privacy_request.get_cached_identity_data() == {} assert ( f"Cache miss for request {privacy_request.id}, falling back to DB" - in loguru_caplog.messages[-1] + in loguru_caplog.text ) @@ -341,8 +337,9 @@ def test_custom_privacy_request_fields_fallback_to_db( privacy_request.get_cached_custom_privacy_request_fields() ) assert cached_custom_privacy_request_fields is not None - cache.delete(key) - assert cache.get(key) is None + # Delete using DSR store to clear the cached custom field + store = get_dsr_cache_store(privacy_request.id) + store.delete(f"custom_field:{custom_privacy_request_field.label}") assert ( privacy_request.get_cached_custom_privacy_request_fields() == cached_custom_privacy_request_fields @@ -1332,6 +1329,7 @@ def test_old_cache_can_be_read(self, privacy_request): We need to make sure we can still read these old values using the new `get_cached_identity_data` function. """ + privacy_request.clear_cached_values() def cache_identity(identity: Identity, privacy_request_id: str) -> None: """Old function for caching identity""" diff --git a/tests/ops/service/privacy_request/test_request_service.py b/tests/ops/service/privacy_request/test_request_service.py index 2cbe0e28b08..08ecebb6657 100644 --- a/tests/ops/service/privacy_request/test_request_service.py +++ b/tests/ops/service/privacy_request/test_request_service.py @@ -341,10 +341,9 @@ def very_short_request_task_expiration(): @pytest.fixture(scope="function") def very_short_redis_cache_expiration(): - original_value: float = CONFIG.redis.default_ttl_seconds - CONFIG.redis.default_ttl_seconds = ( - 0.01 # Set redis cache to expire very quickly for testing purposes - ) + original_value: int = CONFIG.redis.default_ttl_seconds + # Redis SET ex= must be int or timedelta (not float). Use 2s to avoid flakiness on slow CI. + CONFIG.redis.default_ttl_seconds = 2 yield CONFIG CONFIG.redis.default_ttl_seconds = original_value @@ -355,7 +354,7 @@ class TestRemoveSavedCustomerData: ) def test_no_request_tasks(self, db, privacy_request): assert not privacy_request.request_tasks.count() - time.sleep(1) + time.sleep(3) # Mainly asserting this runs without error remove_saved_dsr_data.delay().get() @@ -381,7 +380,7 @@ def test_privacy_request_incomplete(self, db, privacy_request): privacy_request.save(db) assert privacy_request.request_tasks.count() - time.sleep(1) + time.sleep(3) remove_saved_dsr_data.delay().get() @@ -409,7 +408,7 @@ def test_customer_data_removed_from_old_request_tasks_and_privacy_requests( privacy_request.save(db) assert privacy_request.request_tasks.count() - time.sleep(1) + time.sleep(3) remove_saved_dsr_data.delay().get() @@ -590,16 +589,18 @@ def test_get_cached_task_id_none_when_not_cached(self, privacy_request): result = get_cached_task_id(privacy_request.id) assert result is None - @mock.patch("fides.api.service.privacy_request.request_service.get_cache") + @mock.patch("fides.api.service.privacy_request.request_service.get_dsr_cache_store") @mock.patch("fides.api.service.privacy_request.request_service.logger") def test_get_cached_task_id_cache_exception( - self, mock_logger, mock_get_cache, privacy_request + self, mock_logger, mock_get_store, privacy_request ): """Test that function logs error and re-raises exceptions from cache operations.""" - # Mock cache to raise exception - mock_cache = mock.Mock() - mock_cache.get.side_effect = Exception("Redis connection failed") - mock_get_cache.return_value = mock_cache + # Mock store to raise exception on get_async_execution + mock_store = mock.Mock() + mock_store.get_async_execution.side_effect = Exception( + "Redis connection failed" + ) + mock_get_store.return_value = mock_store # Function should log error and re-raise exception with pytest.raises(Exception, match="Redis connection failed"): diff --git a/tests/ops/tasks/test_encryption_utils.py b/tests/ops/tasks/test_encryption_utils.py index 71648068319..fb3bec87174 100644 --- a/tests/ops/tasks/test_encryption_utils.py +++ b/tests/ops/tasks/test_encryption_utils.py @@ -7,29 +7,31 @@ @pytest.fixture def mock_cache(): - with patch("fides.api.tasks.encryption_utils.get_cache") as mock_get_cache: - cache = MagicMock() - mock_get_cache.return_value = cache - yield cache + with patch( + "fides.api.tasks.encryption_utils.get_dsr_cache_store" + ) as mock_get_store: + store = MagicMock() + mock_get_store.return_value = store + yield store def test_encrypt_access_request_results_no_encryption_key(mock_cache): """Test that data is returned unencrypted when no encryption key is found in cache.""" - mock_cache.get.return_value = None + mock_cache.get_encryption.return_value = None test_data = "test_data" request_id = "test_request_id" result = encrypt_access_request_results(test_data, request_id) assert result == test_data - mock_cache.get.assert_called_once() + mock_cache.get_encryption.assert_called_once_with("key") def test_encrypt_access_request_results_with_encryption_key(mock_cache): """Test that data is encrypted when encryption key is found in cache.""" # Use a 16-byte key (128 bits) for AES-GCM encryption_key = "0123456789abcdef" # 16 bytes - mock_cache.get.return_value = encryption_key + mock_cache.get_encryption.return_value = encryption_key test_data = "test_data" request_id = "test_request_id" @@ -38,13 +40,13 @@ def test_encrypt_access_request_results_with_encryption_key(mock_cache): # The result should be a base64 encoded string containing the nonce and encrypted data assert isinstance(result, str) assert len(result) > 0 - mock_cache.get.assert_called_once() + mock_cache.get_encryption.assert_called_once_with("key") def test_encrypt_access_request_results_with_bytes_input(mock_cache): """Test that bytes input is properly handled and encrypted.""" encryption_key = "0123456789abcdef" # 16 bytes - mock_cache.get.return_value = encryption_key + mock_cache.get_encryption.return_value = encryption_key test_data = b"test_data" request_id = "test_request_id" @@ -52,13 +54,13 @@ def test_encrypt_access_request_results_with_bytes_input(mock_cache): assert isinstance(result, str) assert len(result) > 0 - mock_cache.get.assert_called_once() + mock_cache.get_encryption.assert_called_once_with("key") def test_encrypt_access_request_results_empty_data(mock_cache): """Test that empty data is handled correctly.""" encryption_key = "0123456789abcdef" # 16 bytes - mock_cache.get.return_value = encryption_key + mock_cache.get_encryption.return_value = encryption_key test_data = "" request_id = "test_request_id" @@ -66,13 +68,13 @@ def test_encrypt_access_request_results_empty_data(mock_cache): assert isinstance(result, str) assert len(result) > 0 - mock_cache.get.assert_called_once() + mock_cache.get_encryption.assert_called_once_with("key") def test_encrypt_access_request_results_special_characters(mock_cache): """Test that data with special characters is properly encrypted.""" encryption_key = "0123456789abcdef" # 16 bytes - mock_cache.get.return_value = encryption_key + mock_cache.get_encryption.return_value = encryption_key test_data = "test_data!@#$%^&*()_+" request_id = "test_request_id" @@ -80,4 +82,4 @@ def test_encrypt_access_request_results_special_characters(mock_cache): assert isinstance(result, str) assert len(result) > 0 - mock_cache.get.assert_called_once() + mock_cache.get_encryption.assert_called_once_with("key") diff --git a/tests/ops/test_helpers/cache_secrets_helper.py b/tests/ops/test_helpers/cache_secrets_helper.py index c0694d98657..b19639980a9 100644 --- a/tests/ops/test_helpers/cache_secrets_helper.py +++ b/tests/ops/test_helpers/cache_secrets_helper.py @@ -1,5 +1,10 @@ from fides.api.schemas.masking.masking_secrets import MaskingSecretCache -from fides.api.util.cache import FidesopsRedis, get_cache, get_masking_secret_cache_key +from fides.api.util.cache import ( + FidesopsRedis, + get_cache, + get_dsr_cache_store, + get_masking_secret_cache_key, +) def cache_secret(masking_secret_cache: MaskingSecretCache, request_id: str) -> None: @@ -20,9 +25,15 @@ def clear_cache_secrets(request_id: str) -> None: def clear_cache_identities(request_id: str) -> None: - """Testing helper just removes some cached identities from the Privacy Request for testing. + """Testing helper that removes cached identities from the Privacy Request. Some of our Privacy Request fixtures automatically cache identities - + this clears them using the DSR cache store. The get_cached_identity_data + call migrates any legacy keys before deletion. """ - cache: FidesopsRedis = get_cache() - cache.delete_keys_by_prefix(f"id-{request_id}-identity-") + store = get_dsr_cache_store(request_id) + # get_cached_identity_data triggers migration (legacy → new), so all + # identity keys will be in new format after this call. + identity_data = store.get_cached_identity_data() + for attr in identity_data: + store.delete(f"identity:{attr}") diff --git a/tests/ops/util/test_cache.py b/tests/ops/util/test_cache.py index 7914f9e90ce..b379646a055 100644 --- a/tests/ops/util/test_cache.py +++ b/tests/ops/util/test_cache.py @@ -209,7 +209,12 @@ def test_cache_tracking_key_has_ttl(self, privacy_request): cache_task_tracking_key(privacy_request.id, "test_1234") raw_cache = get_cache() - ttl = raw_cache.ttl(get_async_task_tracking_cache_key(privacy_request.id)) + # Check new-format key; fall back to legacy key for backward compat + new_key = f"dsr:{privacy_request.id}:async_execution" + legacy_key = get_async_task_tracking_cache_key(privacy_request.id) + ttl = raw_cache.ttl(new_key) + if ttl == -2: + ttl = raw_cache.ttl(legacy_key) assert ttl > 0 def test_cache_tracking_key_request_task(self, request_task):