diff --git a/rfc9421.py b/rfc9421.py deleted file mode 100644 index 64c024e..0000000 --- a/rfc9421.py +++ /dev/null @@ -1,374 +0,0 @@ -import base64 -import datetime as dt -import email.utils -import json -from typing import Any, List, Tuple, cast - -import pytz -from cryptography.hazmat.primitives import hashes, serialization -from cryptography.hazmat.primitives.asymmetric import ec, ed25519, padding, rsa -from http_sf import parse -from http_sf.types import ( - BareItemType, - InnerListType, - ItemType, - ParamsType, -) -from multiformats import multibase, multicodec - -from apsig.draft.tools import calculate_digest -from apsig.exceptions import MissingSignature, VerificationFailed - - -class RFC9421Signer: - def __init__(self, private_key: rsa.RSAPrivateKey, key_id: str): - self.private_key = private_key - self.key_id = key_id - self.sign_headers = [ - "date", - "@method", - "@path", - "@authority", - "content-type", - "content-length", - ] - - def __build_signature_base( - self, special_keys: dict[str, str], headers: dict[str, str] - ) -> bytes: - headers_new = [] - headers = headers.copy() - for h in self.sign_headers: - if h in ["@method", "@path", "@authority"]: - v = special_keys.get(h) - - if v: - headers_new.append(f'"{h}": {v}') - else: - raise ValueError(f"Missing Value: {h}") - elif h == "@signature-params": - v = special_keys.get(h) - - if v: - headers_new.append(f'"{h}": {self.__generate_sig_input()}') - else: - raise ValueError(f"Missing Value: {h}") - else: - v_raw = headers.get(h) - if v_raw is not None: - v = v_raw.strip() - headers_new.append(f'"{h}": {v}') - else: - raise ValueError(f"Missing Header Value: {h}") - headers_new.append(f'"@signature-params": {self.__generate_sig_input()}') - return ("\n".join(headers_new)).encode("utf-8") - - def generate_signature_header(self, signature: bytes) -> str: - return base64.b64encode(signature).decode("utf-8") - - def __generate_sig_input(self): - param = "(" - target_len = len(self.sign_headers) - timestamp = dt.datetime.now(dt.UTC) - for p in self.sign_headers: - param += f'"{p}"' - if p != self.sign_headers[target_len - 1]: - param += " " - param += ");" - param += f"created={int(timestamp.timestamp())};" - param += 'alg="rsa-v1_5-sha256";' - param += f'keyid="{self.key_id}"' - return param - - def sign( - self, - method: str, - path: str, - host: str, - headers: dict, - body: bytes | dict = b"", - ): - if isinstance(body, dict): - body = json.dumps(body).encode("utf-8") - - headers = {k.lower(): v for k, v in headers.items()} - if not headers.get("date"): - headers["date"] = email.utils.formatdate(usegmt=True) - if not headers.get("content-length"): - headers["content-length"] = str(len(body)) - - special_keys = { - "@method": method.upper(), - "@path": path, - "@authority": host, - } - - base = self.__build_signature_base(special_keys, headers) - signed = self.private_key.sign(base, padding.PKCS1v15(), hashes.SHA256()) - headers_req = headers.copy() - headers_req["Signature"] = f"sig1=:{self.generate_signature_header(signed)}:" - headers_req["content-digest"] = f"sha-256=:{calculate_digest(body)}:" - headers_req["Signature-Input"] = f"sig1={self.__generate_sig_input()}" - return headers_req - - -class RFC9421Verifier: - def __init__( - self, - public_key: ed25519.Ed25519PublicKey - | rsa.RSAPublicKey - | ec.EllipticCurvePublicKey - | str, - method: str, - path: str, - host: str, - headers: dict[str, str], - body: bytes | dict | None = None, - clock_skew: int = 300, - ): - self.public_key: ( - ed25519.Ed25519PublicKey | rsa.RSAPublicKey | ec.EllipticCurvePublicKey - ) - - if isinstance(public_key, str): - codec, data = multicodec.unwrap(multibase.decode(public_key)) - match codec.name: - case "ed25519-pub": - self.public_key: ed25519.Ed25519PublicKey = ( - ed25519.Ed25519PublicKey.from_public_bytes(data) - ) - case "rsa-pub": - pubkey = serialization.load_pem_public_key(data) - if not isinstance(pubkey, rsa.RSAPublicKey): - raise TypeError("PublicKey must be ed25519 or RSA or ECDSA.") - self.public_key = pubkey - case "p256-pub": - pubkey = serialization.load_pem_public_key(data) - if not isinstance(pubkey, ec.EllipticCurvePublicKey): - raise TypeError("PublicKey must be ed25519 or RSA or ECDSA.") - self.public_key = pubkey - case "p384-pub": - pubkey = serialization.load_pem_public_key(data) - if not isinstance(pubkey, ec.EllipticCurvePublicKey): - raise TypeError("PublicKey must be ed25519 or RSA or ECDSA.") - self.public_key = pubkey - case _: - raise TypeError("PublicKey must be ed25519 or RSA or ECDSA.") - else: - self.public_key: ed25519.Ed25519PublicKey | rsa.RSAPublicKey = public_key - self.clock_skew = clock_skew - self.method = method.upper() - self.path = path - self.host = host - self.headers = {key.lower(): value for key, value in headers.items()} - - def __expect_value_and_params_member( - self, - member: Any, - ) -> Tuple[ItemType | InnerListType, ParamsType]: - if not isinstance(member, tuple) or len(member) != 2: - raise ValueError("expected a (value, params) tuple") - value, params = member - if not isinstance(params, dict): - raise ValueError("expected params to be a dict") - return cast(Tuple[ItemType | InnerListType, ParamsType], (value, params)) - - def __generate_sig_input( - self, headers: List[BareItemType], params: ParamsType - ) -> str: - created = params.get("created") - alg = params.get("alg") - keyid = params.get("keyid") - - if isinstance(created, dt.datetime): - created_timestamp = created - elif isinstance(created, int): - created_timestamp = dt.datetime.fromtimestamp(created) - elif isinstance(created, str): - created_timestamp = dt.datetime.fromtimestamp(int(created)) - else: - raise ValueError("Unknown created value") - request_time = created_timestamp.astimezone(pytz.utc) - current_time = dt.datetime.now(dt.UTC) - if abs((current_time - request_time).total_seconds()) > self.clock_skew: - raise VerificationFailed( - f"property created is too far from current time ({current_time}): {request_time}" - ) - - param = "(" - target_len = len(headers) - for p in headers: - param += f'"{p}"' - if p != headers[target_len - 1]: - param += " " - param += ");" - param += f"created={int(created_timestamp.timestamp())};" - param += f'alg="{alg}";' - param += f'keyid="{keyid}"' - return param - - def __rebuild_sigbase( - self, headers: List[BareItemType], params: ParamsType - ) -> bytes: - special_keys = { - "@method": self.method, - "@path": self.path, - "@authority": self.host, - } - base = [] - for h in cast(List[str], headers): - if h in ["@method", "@path", "@authority"]: - base.append(f'"{h}": {special_keys.get(h)}') - else: - v_raw = self.headers.get(h) - if v_raw is not None: - v = v_raw.strip() - base.append(f'"{h}": {v}') - else: - raise ValueError(f"Missing Header Value: {h}") - base.append( - f'"@signature-params": {self.__generate_sig_input(headers=headers, params=params)}' - ) - return ("\n".join(base)).encode("utf-8") - - def verify(self, raise_on_fail: bool = False) -> str | None: - signature = self.headers.get("signature") - if not signature: - if raise_on_fail: - raise MissingSignature("Signature header is missing") - return None - - signature_input = self.headers.get("signature-input") - if not signature_input: - if raise_on_fail: - raise MissingSignature("Signature-Input header is missing") - return None - - signature_input_parsed = parse( - signature_input.encode("utf-8"), tltype="dictionary" - ) - signature_parsed = parse(signature.encode("utf-8"), tltype="dictionary") - - if not isinstance(signature_input_parsed, dict): - raise VerificationFailed( - f"Unsupported Signature-Input type: {type(signature_input_parsed)}" - ) - - if not isinstance(signature_parsed, dict): - raise VerificationFailed( - f"Unsupported Signature type: {type(signature_parsed)}" - ) - - for k, v in signature_input_parsed.items(): - try: - value, params = self.__expect_value_and_params_member(v) - if isinstance(value, list): - headers: List[BareItemType] = [ - itm[0] if isinstance(itm, tuple) else itm for itm in value - ] - else: - raise ValueError( - "expected the value to be an inner-list (list of items)" - ) - - created = params.get("created") - key_id = str(params.get("keyid")) - alg = params.get("alg") - - if not created: - raise VerificationFailed("created not found.") - if not key_id: - raise VerificationFailed("keyid not found.") - if not alg: - raise VerificationFailed("alg not found.") - if alg not in [ - "ed25519", - "rsa-v1_5-sha256", - "rsa-v1_5-sha512", - "rsa-pss-sha512", - ]: - raise VerificationFailed(f"Unsupported algorithm: {alg}") - - sigi = self.__rebuild_sigbase(headers, params) - signature_bytes = signature_parsed.get(k) - if not isinstance(signature_bytes, tuple): - raise VerificationFailed( - f"Unknown Signature: {type(signature_bytes)}" - ) - - sig_val = None - for sig in cast(InnerListType, signature_bytes): - if isinstance(sig, bytes): - sig_val = sig - break - if sig_val is None: - raise ValueError("No Signature found.") - try: - match alg: - case "ed25519": - if not isinstance( - self.public_key, ed25519.Ed25519PublicKey - ): - raise VerificationFailed("Algorithm missmatch.") - self.public_key.verify(sig_val, sigi) - case "rsa-v1_5-sha256": - if not isinstance(self.public_key, rsa.RSAPublicKey): - raise VerificationFailed("Algorithm missmatch.") - self.public_key.verify( - sig_val, - sigi, - padding.PKCS1v15(), - hashes.SHA256(), - ) - case "rsa-v1_5-sha512": - if not isinstance(self.public_key, rsa.RSAPublicKey): - raise VerificationFailed("Algorithm missmatch.") - self.public_key.verify( - sig_val, - sigi, - padding.PKCS1v15(), - hashes.SHA512(), - ) - case "rsa-pss-sha512": - if not isinstance(self.public_key, rsa.RSAPublicKey): - raise VerificationFailed("Algorithm missmatch.") - self.public_key.verify( - sig_val, - sigi, - padding.PSS( - mgf=padding.MGF1(hashes.SHA512()), - salt_length=hashes.SHA512().digest_size, - ), - hashes.SHA512(), - ) - case "ecdsa-p256-sha256": - if not isinstance( - self.public_key, ec.EllipticCurvePublicKey - ): - raise VerificationFailed("Algorithm missmatch.") - self.public_key.verify( - sig_val, - sigi, - ec.ECDSA(hashes.SHA256()), - ) - case "ecdsa-p384-sha384": - if not isinstance( - self.public_key, ec.EllipticCurvePublicKey - ): - raise VerificationFailed("Algorithm missmatch.") - self.public_key.verify( - sig_val, - sigi, - ec.ECDSA(hashes.SHA384()), - ) - return key_id - except Exception as e: - if raise_on_fail: - raise VerificationFailed(str(e)) - return None - except ValueError: - continue - - if raise_on_fail: - raise VerificationFailed("RFC9421 Signature verification failed.") - return None diff --git a/src/apkit/_version.py b/src/apkit/_version.py index bea3035..878175f 100644 --- a/src/apkit/_version.py +++ b/src/apkit/_version.py @@ -28,7 +28,7 @@ commit_id: COMMIT_ID __commit_id__: COMMIT_ID -__version__ = version = '0.3.3.post1.dev36+gc65fc9498' -__version_tuple__ = version_tuple = (0, 3, 3, 'post1', 'dev36', 'gc65fc9498') +__version__ = version = '0.3.3.post1.dev58+ge5df88041.d20260127' +__version_tuple__ = version_tuple = (0, 3, 3, 'post1', 'dev58', 'ge5df88041.d20260127') __commit_id__ = commit_id = None diff --git a/src/apkit/cache.py b/src/apkit/cache.py index 975ddd9..eecff16 100644 --- a/src/apkit/cache.py +++ b/src/apkit/cache.py @@ -1,4 +1,3 @@ -import time from typing import Any, Generic, Optional from .kv import KT, VT, KeyValueStore @@ -6,18 +5,9 @@ class Cache(Generic[KT, VT]): """ - A generic cache wrapper that uses a KeyValueStore as a backend - and adds Time-To-Live (TTL) support. + A generic cache wrapper that uses a KeyValueStore as a backend. """ - class _CacheItem(Generic[VT]): - value: VT - expiration: float - - def __init__(self, value: VT, expiration: float): - self.value = value - self.expiration = expiration - def __init__(self, store: Optional[KeyValueStore[KT, Any]]): self._store = store @@ -26,15 +16,8 @@ def get(self, key: KT) -> VT | None: Gets an item from the cache, returning None if it's expired or doesn't exist. """ if self._store: - item = self._store.get(key) - if not (hasattr(item, "value") and hasattr(item, "expiration")): - return None - - if time.time() > item.expiration: - self._store.delete(key) - return None - - return item.value + return self._store.get(key) + return None def set(self, key: KT, value: VT, ttl: float | None) -> None: """ @@ -42,15 +25,15 @@ def set(self, key: KT, value: VT, ttl: float | None) -> None: If ttl is None, the item will not expire. """ if self._store: - if ttl is not None: - if ttl <= 0: - self._store.delete(key) - return - expiration = time.time() + ttl - else: - expiration = float("inf") + if ttl is not None and ttl <= 0: + self._store.delete(key) + return + + ttl_int = int(ttl) if ttl is not None else None + if ttl is not None and ttl > 0 and ttl_int == 0: + ttl_int = 1 - self._store.set(key, self._CacheItem(value, expiration)) + self._store.set(key, value, ttl_seconds=ttl_int) def delete(self, key: KT) -> None: """Deletes an item from the cache.""" @@ -62,31 +45,16 @@ def exists(self, key: KT) -> bool: Checks if a non-expired item exists in the cache. """ if self._store: - item = self._store.get(key) - if not (hasattr(item, "value") and hasattr(item, "expiration")): - return False - - if time.time() > item.expiration: - self._store.delete(key) - return False - return True - else: - return False + return self._store.exists(key) + return False async def async_get(self, key: KT) -> VT | None: """ Gets an item from the cache, returning None if it's expired or doesn't exist. """ if self._store: - item = self._store.get(key) - if not (hasattr(item, "value") and hasattr(item, "expiration")): - return None - - if time.time() > item.expiration: - self._store.delete(key) - return None - - return item.value + return await self._store.async_get(key) + return None async def async_set(self, key: KT, value: VT, ttl: float | None) -> None: """ @@ -94,33 +62,25 @@ async def async_set(self, key: KT, value: VT, ttl: float | None) -> None: If ttl is None, the item will not expire. """ if self._store: - if ttl is not None: - if ttl <= 0: - self._store.delete(key) - return - expiration = time.time() + ttl - else: - expiration = float("inf") + if ttl is not None and ttl <= 0: + await self._store.async_delete(key) + return + + ttl_int = int(ttl) if ttl is not None else None + if ttl is not None and ttl > 0 and ttl_int == 0: + ttl_int = 1 - self._store.set(key, self._CacheItem(value, expiration)) + await self._store.async_set(key, value, ttl_seconds=ttl_int) async def async_delete(self, key: KT) -> None: """Deletes an item from the cache.""" if self._store: - self._store.delete(key) + await self._store.async_delete(key) async def async_exists(self, key: KT) -> bool: """ Checks if a non-expired item exists in the cache. """ if self._store: - item = self._store.get(key) - if not (hasattr(item, "value") and hasattr(item, "expiration")): - return False - - if time.time() > item.expiration: - self._store.delete(key) - return False - - return True + return await self._store.async_exists(key) return False diff --git a/src/apkit/helper/inbox.py b/src/apkit/helper/inbox.py index b45ef3a..fd93428 100644 --- a/src/apkit/helper/inbox.py +++ b/src/apkit/helper/inbox.py @@ -6,8 +6,10 @@ import http_sf from apmodel import Activity from apmodel.core.link import Link +from apmodel.extra.cid import Multikey +from apmodel.extra.security import CryptographicKey from apmodel.vocab.actor import Actor -from apsig import KeyUtil, LDSignature, ProofVerifier +from apsig import LDSignature, ProofVerifier from apsig.draft.verify import Verifier from apsig.exceptions import ( MissingSignatureError, @@ -51,13 +53,49 @@ def __get_draft_signature_parts(self, signature: str) -> dict[Any, Any]: signature_parts[key.strip()] = value.strip().strip('"') return signature_parts - async def __get_signature_from_kv(self, key_id: str) -> tuple[Optional[str], bool]: + async def __get_signature_from_kv( + self, key_id: str + ) -> tuple[ + Optional[RSAPublicKey | Ed25519PublicKey | ec.EllipticCurvePublicKey], + bool, + ]: cache = False - public_key = await self.config.kv.async_get(f"signature:{key_id}") + public_key = await self.config.cache.async_get(f"signature:{key_id}") if public_key: self.logger.debug("Use existing cached keys") cache = True - return public_key, cache + if isinstance(public_key, bytes): + try: + key = serialization.load_der_public_key(public_key) + if isinstance( + key, + ( + RSAPublicKey, + Ed25519PublicKey, + ec.EllipticCurvePublicKey, + ), + ): + return key, cache + except Exception: + self.logger.warning( + f"Failed to load cached key {key_id}", exc_info=True + ) + return None, False + + async def __save_signature_to_kv(self, key_id: str, public_key: Any) -> None: + if isinstance( + public_key, + ( + rsa.RSAPublicKey, + ec.EllipticCurvePublicKey, + ed25519.Ed25519PublicKey, + ), + ): + der = public_key.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) + await self.config.kv.async_set(f"signature:{key_id}", der) async def __verify_rfc9421( self, @@ -100,22 +138,7 @@ async def _verify_with_key( verified_key_id = verifier.verify(raise_on_fail=True) if verified_key_id: if not is_cache: - if isinstance( - public_key_obj, - (rsa.RSAPublicKey, ec.EllipticCurvePublicKey), - ): - pem = public_key_obj.public_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PublicFormat.SubjectPublicKeyInfo, - ) - await self.config.kv.async_set( - f"signature:{key_id}", pem.decode("utf-8") - ) - elif isinstance(public_key_obj, ed25519.Ed25519PublicKey): - ku = KeyUtil(public_key_obj) - await self.config.kv.async_set( - f"signature:{key_id}", ku.encode_multibase() - ) + await self.__save_signature_to_kv(key_id, public_key_obj) return True except VerificationFailedError as e: if is_cache: @@ -134,20 +157,9 @@ async def _verify_with_key( # Try with cached key first public_key_obj = None - cache = False if not no_check_cache: - public_key_pem, cache = await self.__get_signature_from_kv(key_id) - if public_key_pem: - try: - ku = KeyUtil() - public_key_obj = ku.decode_multibase(public_key_pem) - except Exception: - try: - public_key_obj = serialization.load_pem_public_key( - public_key_pem.encode("utf-8") - ) - except ValueError: - self.logger.warning(f"Failed to load cached key {key_id}") + public_key_obj, _ = await self.__get_signature_from_kv(key_id) + if public_key_obj and await _verify_with_key(key_id, public_key_obj, True): return True @@ -187,13 +199,19 @@ async def __verify_draft( public_key = public_keys.get_key(key_id) else: public_key = None - if ( - public_key - and not isinstance(public_key, str) - and isinstance(public_key.public_key, RSAPublicKey) - ): + if public_key: + if isinstance(public_key, str): + public_key = serialization.load_pem_public_key( + public_key.encode("utf-8") + ) + elif isinstance(public_key, (CryptographicKey, Multikey)): + public_key = public_key.public_key + + if not isinstance(public_key, RSAPublicKey): + raise TypeError(f"Unsupported key type: {type(public_key)}") + verifier = Verifier( - public_key.public_key, + public_key, method, str(url), headers, @@ -201,14 +219,8 @@ async def __verify_draft( ) try: verifier.verify(raise_on_fail=True) - if isinstance(public_key.public_key, rsa.RSAPublicKey): - await self.config.kv.async_set( - f"signature:{key_id}", - public_key.public_key.public_bytes( - serialization.Encoding.PEM, - serialization.PublicFormat.SubjectPublicKeyInfo, - ).decode("utf-8"), - ) + if not cache: + await self.__save_signature_to_kv(key_id, public_key) return True except Exception as e: if not cache: @@ -252,12 +264,9 @@ async def __verify_proof(self, body: bytes, no_check_cache: bool = False) -> boo try: proof.verify(body_json) if not cache: - if isinstance(public_key, ed25519.Ed25519PublicKey): - ku = KeyUtil(public_key) - await self.config.kv.async_set( - f"signature:{verification_method}", - ku.encode_multibase(), - ) + await self.__save_signature_to_kv( + verification_method, public_key + ) return True except Exception as e: if not cache: @@ -295,24 +304,15 @@ async def __verify_ld(self, body: bytes, no_check_cache: bool = False) -> bool: public_keys = await self.__fetch_actor(activity) if public_keys and creator: public_key = public_keys.get_key(creator) - if public_key: + if public_key and creator: try: public_key = ( - public_key.public_key - if not isinstance(public_key, str) - else public_key + public_key if not isinstance(public_key, str) else public_key ) - if public_key and not isinstance(public_key, Ed25519PublicKey): + if public_key and isinstance(public_key, RSAPublicKey): ld.verify(body_json, public_key, raise_on_fail=True) if not cache: - if isinstance(public_key, rsa.RSAPublicKey): - await self.config.kv.async_set( - f"signature:{creator}", - public_key.public_bytes( - serialization.Encoding.PEM, - serialization.PublicFormat.SubjectPublicKeyInfo, - ), - ) + await self.__save_signature_to_kv(creator, public_key) return True raise VerificationFailedError("publicKey does not exist.") except ( diff --git a/src/apkit/kv/inmemory.py b/src/apkit/kv/inmemory.py index 6fdddf2..9cf86df 100644 --- a/src/apkit/kv/inmemory.py +++ b/src/apkit/kv/inmemory.py @@ -1,4 +1,5 @@ import time +from collections import OrderedDict from typing import Any from . import KeyValueStore @@ -6,11 +7,60 @@ class InMemoryKV(KeyValueStore[Any, Any]): """ - An in-memory key-value store implementation with TTL support. + An in-memory key-value store implementation with TTL and LRU support. """ def __init__(self) -> None: self._store: dict[Any, tuple[Any, float | None]] = {} + self._lru_configs: dict[str, int | None] = {} + self._lru_keys: dict[str, OrderedDict[Any, None]] = {} + + def configure_lru(self, namespace: str, max_size: int | None = None) -> None: + """Configures LRU settings for a specific namespace.""" + self._lru_configs[namespace] = max_size + if namespace not in self._lru_keys: + self._lru_keys[namespace] = OrderedDict() + + if max_size is not None: + self._enforce_lru(namespace, max_size) + + def _get_namespace(self, key: Any) -> str: + if isinstance(key, str) and ":" in key: + return key.split(":", 1)[0] + return "default" + + def _update_lru_on_access(self, key: Any) -> None: + namespace = self._get_namespace(key) + if namespace in self._lru_keys and key in self._lru_keys[namespace]: + self._lru_keys[namespace].move_to_end(key) + + def _update_lru_on_set(self, key: Any) -> None: + namespace = self._get_namespace(key) + max_size = self._lru_configs.get(namespace) + + if namespace not in self._lru_keys: + self._lru_keys[namespace] = OrderedDict() + + self._lru_keys[namespace][key] = None + self._lru_keys[namespace].move_to_end(key) + + if max_size is not None: + self._enforce_lru(namespace, max_size) + + def _remove_from_lru(self, key: Any) -> None: + namespace = self._get_namespace(key) + if namespace in self._lru_keys and key in self._lru_keys[namespace]: + del self._lru_keys[namespace][key] + + def _enforce_lru(self, namespace: str, max_size: int) -> None: + keys = self._lru_keys.get(namespace) + if keys is None: + return + + while len(keys) > max_size: + oldest_key, _ = keys.popitem(last=False) + if oldest_key in self._store: + del self._store[oldest_key] def get(self, key: Any) -> Any | None: """Gets a value from the in-memory store, checking for TTL.""" @@ -19,20 +69,23 @@ def get(self, key: Any) -> Any | None: value, expires_at = self._store[key] if expires_at is not None and expires_at < time.time(): - del self._store[key] + self.delete(key) return None + self._update_lru_on_access(key) return value def set(self, key: Any, value: Any, ttl_seconds: int | None = 3600) -> None: """Sets a value in the in-memory store with an optional TTL.""" expires_at = time.time() + ttl_seconds if ttl_seconds is not None else None - self._store[key] = (value, (expires_at)) + self._store[key] = (value, expires_at) + self._update_lru_on_set(key) def delete(self, key: Any) -> None: """Deletes a key from the in-memory store.""" if key in self._store: del self._store[key] + self._remove_from_lru(key) def exists(self, key: Any) -> bool: """Checks if a key exists in the in-memory store, considering TTL.""" @@ -41,9 +94,10 @@ def exists(self, key: Any) -> bool: _, expires_at = self._store[key] if expires_at is not None and expires_at < time.time(): - del self._store[key] + self.delete(key) return False + self._update_lru_on_access(key) return True async def async_get(self, key: Any) -> Any | None: diff --git a/tests/helper/test_inbox.py b/tests/helper/test_inbox.py index 54d7472..1abe5e5 100644 --- a/tests/helper/test_inbox.py +++ b/tests/helper/test_inbox.py @@ -6,6 +6,7 @@ from apsig import draft from cryptography.hazmat.primitives import serialization as crypto_serialization from cryptography.hazmat.primitives.asymmetric import rsa +import logging def _prepare_signed_request(): @@ -69,3 +70,18 @@ async def test_verify_draft_http_signature(): result = await inbox_verifier.verify(body, url, method, headers) assert result + +@pytest.mark.asyncio +async def test_verify_draft_http_signature_repeated(): + logging.getLogger("activitypub.server.inbox.helper").setLevel(logging.DEBUG) + + (body, url, method, headers) = _prepare_signed_request() + + config = AppConfig() + inbox_verifier = InboxVerifier(config) + + result = await inbox_verifier.verify(body, url, method, headers) + assert result + + result = await inbox_verifier.verify(body, url, method, headers) + assert result diff --git a/tests/helper/test_inbox_cache.py b/tests/helper/test_inbox_cache.py new file mode 100644 index 0000000..950e09b --- /dev/null +++ b/tests/helper/test_inbox_cache.py @@ -0,0 +1,85 @@ +import pytest +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa, ed25519, ec +from apkit.config import AppConfig +from apkit.kv.inmemory import InMemoryKV +from apkit.helper.inbox import InboxVerifier + +@pytest.fixture +def app_config(): + kv = InMemoryKV() + return AppConfig(kv=kv) + +@pytest.fixture +def verifier(app_config): + return InboxVerifier(app_config) + +@pytest.mark.asyncio +async def test_save_and_get_rsa_key_der(verifier, app_config): + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + public_key = private_key.public_key() + key_id = "https://example.com/actor#main-key" + + await verifier._InboxVerifier__save_signature_to_kv(key_id, public_key) + + stored_data = await app_config.kv.async_get(f"signature:{key_id}") + assert isinstance(stored_data, bytes) + + loaded_key_direct = serialization.load_der_public_key(stored_data) + assert isinstance(loaded_key_direct, rsa.RSAPublicKey) + assert loaded_key_direct.public_numbers() == public_key.public_numbers() + + retrieved_key, is_cache = await verifier._InboxVerifier__get_signature_from_kv(key_id) + + assert is_cache is True + assert isinstance(retrieved_key, rsa.RSAPublicKey) + assert retrieved_key.public_numbers() == public_key.public_numbers() + +@pytest.mark.asyncio +async def test_save_and_get_ed25519_key_der(verifier, app_config): + private_key = ed25519.Ed25519PrivateKey.generate() + public_key = private_key.public_key() + key_id = "https://example.com/actor#ed-key" + + await verifier._InboxVerifier__save_signature_to_kv(key_id, public_key) + + stored_data = await app_config.kv.async_get(f"signature:{key_id}") + assert isinstance(stored_data, bytes) + + loaded_key_direct = serialization.load_der_public_key(stored_data) + assert isinstance(loaded_key_direct, ed25519.Ed25519PublicKey) + + retrieved_key, is_cache = await verifier._InboxVerifier__get_signature_from_kv(key_id) + + assert is_cache is True + assert isinstance(retrieved_key, ed25519.Ed25519PublicKey) + assert retrieved_key.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw + ) == public_key.public_bytes( + encoding=serialization.Encoding.Raw, + format=serialization.PublicFormat.Raw + ) + +@pytest.mark.asyncio +async def test_get_signature_invalid_data(verifier, app_config): + key_id = "https://example.com/actor#invalid" + + await app_config.kv.async_set(f"signature:{key_id}", b"invalid-der-data") + + retrieved_key, is_cache = await verifier._InboxVerifier__get_signature_from_kv(key_id) + + assert retrieved_key is None + assert is_cache is False + +@pytest.mark.asyncio +async def test_get_signature_not_found(verifier): + key_id = "https://example.com/actor#missing" + + retrieved_key, is_cache = await verifier._InboxVerifier__get_signature_from_kv(key_id) + + assert retrieved_key is None + assert is_cache is False diff --git a/tests/test_cache.py b/tests/test_cache.py index b675959..05a834c 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -3,42 +3,13 @@ import pytest from apkit.kv import KeyValueStore +from apkit.kv.inmemory import InMemoryKV from apkit.cache import Cache -class FakeKeyValueStore(KeyValueStore[Any, Any]): - """Minimal in-memory KeyValueStore for tests.""" - - def __init__(self): - self._data = {} - - def get(self, key): - return self._data.get(key) - - def set(self, key, value, ttl_seconds = None): - self._data[key] = value - - def delete(self, key): - self._data.pop(key, None) - - def exists(self, key): - return key in self._data - - async def async_get(self, key): - return self.get(key) - - async def async_set(self, key, value, ttl_seconds = None): - self.set(key, value) - - async def async_delete(self, key): - self.delete(key) - - async def async_exists(self, key): - return self.exists(key) - @pytest.fixture def store(): - return FakeKeyValueStore() + return InMemoryKV() @pytest.fixture @@ -107,7 +78,7 @@ async def test_async_set_and_get(cache): @pytest.mark.asyncio -async def test_async_get_expired(cache, monkeypatch): +async def test_async_get_expired(cache: Cache, monkeypatch): monkeypatch.setattr(time, "time", lambda: 1000.0) await cache.async_set("a", "value", ttl=1) diff --git a/uv.lock b/uv.lock index 5f0a7e9..b092f0b 100644 --- a/uv.lock +++ b/uv.lock @@ -1651,18 +1651,18 @@ wheels = [ [[package]] name = "pyrefly" -version = "0.49.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b4/19/8ad522587672c6bb013e284ee8a326136f6511c74784141f3fd550b99aee/pyrefly-0.49.0.tar.gz", hash = "sha256:d4e9a978d55253d2cd24c0354bd4cf087026d07bd374388c2ae12a3bc26f93fc", size = 4822135, upload-time = "2026-01-20T15:13:48.061Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5b/47/8c34be1fd5fb3ca74608a71dfece40c4b9d382a8899db8418be9b326ba3f/pyrefly-0.49.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:1cd5516ddab7c745e195fe1470629251962498482025bf2a9a9d53d5bde73729", size = 11644108, upload-time = "2026-01-20T15:13:25.358Z" }, - { url = "https://files.pythonhosted.org/packages/57/01/f492c92b4df963dbfda8d8e1cf57477704df8cdecf907568580af60193fe/pyrefly-0.49.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:5a998a37dc1465a648c03076545080a8bd2a421c67cac27686eca43244e8ac69", size = 11246465, upload-time = "2026-01-20T15:13:27.845Z" }, - { url = "https://files.pythonhosted.org/packages/d1/0b/89da00960e9c43ae7aa5f50886e9f87457137c444e513c00b714fdc6ba1e/pyrefly-0.49.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a96b1452fa61d7db6d5ae6b6297f50ba8c006ba7ce420233ebd33eaf95d04cfd", size = 31723528, upload-time = "2026-01-20T15:13:31.686Z" }, - { url = "https://files.pythonhosted.org/packages/f7/69/43a2a1a6bc00037879643d7d5257215fea1988dd2ef3168b5fe3cd55dcf0/pyrefly-0.49.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:97f1b5fb1be6f8f4868fe40e7ebeed055c8483012212267e182d58a8e50723e7", size = 33924099, upload-time = "2026-01-20T15:13:35.056Z" }, - { url = "https://files.pythonhosted.org/packages/f4/df/e475cd37d40221571e25465f0a39dd14123b8a3498f103e39e5938a2645f/pyrefly-0.49.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7ee11eefd1d551629ce1b25888814dbf758aac1a10279537d9425bc53f2d41c", size = 35026928, upload-time = "2026-01-20T15:13:38.403Z" }, - { url = "https://files.pythonhosted.org/packages/54/e2/fe9588b2cb4685c410ebf106bf1d28c66ed2727a5eeeabcfb51fec714143/pyrefly-0.49.0-py3-none-win32.whl", hash = "sha256:6196cb9b20ee977f64fa1fe87e06d3f7a222c5155031d21139fc60464a7a4b9c", size = 10675311, upload-time = "2026-01-20T15:13:40.99Z" }, - { url = "https://files.pythonhosted.org/packages/1a/dc/65fba26966bc2d9a9cbef620ef2a957f72bf3551822d6c250e3d36c2d0ee/pyrefly-0.49.0-py3-none-win_amd64.whl", hash = "sha256:15333b5550fd32a8f9a971ad124714d75f1906a67e48033dcc203258525bc7fd", size = 11418250, upload-time = "2026-01-20T15:13:43.321Z" }, - { url = "https://files.pythonhosted.org/packages/54/3c/9b0af11cbbfd57c5487af2d5d7322c30e7d73179171e1ffa4dda758dd286/pyrefly-0.49.0-py3-none-win_arm64.whl", hash = "sha256:4a57eebced37836791b681626a4be004ebd27221bc208f8200e1e2ca8a8b9510", size = 10962081, upload-time = "2026-01-20T15:13:45.82Z" }, +version = "0.50.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/fd/3de73c11f5f5f9bc493840d54bdac70c7ae7862f4afe3ad6c07b64e21917/pyrefly-0.50.0.tar.gz", hash = "sha256:55daafb02d8cfde54de5f6872a20059a9e34350bff47ec12b8b4f279eac3b8f5", size = 4890579, upload-time = "2026-01-26T21:04:12.475Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/bf/3a/a0267459efc61a7bb6e5281ab5a41c4a16a10dce8acbd7376f2956a59b2e/pyrefly-0.50.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:c997844857f72e9edf6365c05b58ac1b9176572b7d4a86153e95cebcf1b06dda", size = 11826217, upload-time = "2026-01-26T21:03:53.96Z" }, + { url = "https://files.pythonhosted.org/packages/f1/20/3bc1f05efabe36e0cfbce9cdd8043261e4237c3af0feabd60a985aad4645/pyrefly-0.50.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f630a47bfb65cf0baa094daee19c0d6c1ee18800b598353accca2a3bb347d65c", size = 11407127, upload-time = "2026-01-26T21:03:56.239Z" }, + { url = "https://files.pythonhosted.org/packages/c5/67/c161542c45d8f37666b8f55fcf5a096e9f90bef0682227f2713135e5ac5f/pyrefly-0.50.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e521bfbb730aa143e415457a4c11f9486ede5cd5f142b8b2446d4a6a1a22aef", size = 32317816, upload-time = "2026-01-26T21:03:58.73Z" }, + { url = "https://files.pythonhosted.org/packages/53/80/9887e4d3036184485a64b0353529d83938eefdc43ea60b9b5ce34ea782df/pyrefly-0.50.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:457b3c8267749fa82fe9555813c18707278fd3d11442aa3b85008b60c53fbfc1", size = 34569414, upload-time = "2026-01-26T21:04:01.211Z" }, + { url = "https://files.pythonhosted.org/packages/e7/07/95ebd93237ee646cc14a310380ec2a59fa8a87e5cefc91a832e902f88356/pyrefly-0.50.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92d6c908f63a9e484a3865f5995c0a9d4dd6f8e66aec6e911f2051b51d18c148", size = 35695334, upload-time = "2026-01-26T21:04:04.036Z" }, + { url = "https://files.pythonhosted.org/packages/c5/97/fc5f992a12713459c41124d7762df23ed9a78eb796a1adf7b1ea2c0b6104/pyrefly-0.50.0-py3-none-win32.whl", hash = "sha256:1ebbc5796b6d6b8b6937500c3c51ef22b4d607e5f100e170c104ea2832c22bbe", size = 10828039, upload-time = "2026-01-26T21:04:06.554Z" }, + { url = "https://files.pythonhosted.org/packages/02/fd/8aefef009268346b60cfa02c087efb8a587cf4bdc630ce5a072c59a765e4/pyrefly-0.50.0-py3-none-win_amd64.whl", hash = "sha256:dae33a7023fd85acbf8ba8b4d8488bc897e92f7439016db10d8e38c3de21ba30", size = 11585740, upload-time = "2026-01-26T21:04:08.558Z" }, + { url = "https://files.pythonhosted.org/packages/3e/8b/4ffcab526a92611b3d5c9ca3eab8d98b6a935ee11e58ee7cdbe9499bd1d9/pyrefly-0.50.0-py3-none-win_arm64.whl", hash = "sha256:7ce692c8262ef9bc877b735e6b4ec053dac119ed64d4cad51aa9d8c285cfb549", size = 11119646, upload-time = "2026-01-26T21:04:10.515Z" }, ] [[package]]