From 97fc9d5f58146c3f91815cefc3266b9fe5353633 Mon Sep 17 00:00:00 2001 From: Walid Ladeb Date: Wed, 18 Mar 2026 23:22:24 +0100 Subject: [PATCH] fix(identity): atomic write + safe Windows file locking --- sdk/python/aaip/exceptions.py | 5 + sdk/python/aaip/identity/__init__.py | 200 ++++++++++++++++++++++++--- 2 files changed, 188 insertions(+), 17 deletions(-) diff --git a/sdk/python/aaip/exceptions.py b/sdk/python/aaip/exceptions.py index f3dfed1..faa84a5 100644 --- a/sdk/python/aaip/exceptions.py +++ b/sdk/python/aaip/exceptions.py @@ -10,4 +10,9 @@ class AAIPError(Exception): class IdentityDecryptionError(AAIPError): """Raised when identity decryption fails (e.g., wrong passphrase).""" + pass + + +class IdentityCorruptedError(AAIPError): + """Raised when identity file is corrupted or unreadable.""" pass \ No newline at end of file diff --git a/sdk/python/aaip/identity/__init__.py b/sdk/python/aaip/identity/__init__.py index abcff8f..ea617ec 100644 --- a/sdk/python/aaip/identity/__init__.py +++ b/sdk/python/aaip/identity/__init__.py @@ -21,6 +21,8 @@ import logging import os import secrets +import sys +import tempfile import time from pathlib import Path @@ -42,6 +44,72 @@ def _has_cryptography() -> bool: HAS_CRYPTOGRAPHY = _has_cryptography() +# --------------------------------------------------------------------------- +# Atomic file operations and file locking +# --------------------------------------------------------------------------- + +def _atomic_write(path: Path, content: str) -> None: + """ + Write content to path atomically using tempfile + os.replace. + + This ensures that concurrent processes never see a partially-written file. + """ + parent = path.parent + parent.mkdir(parents=True, exist_ok=True) + fd, tmp_path = tempfile.mkstemp(dir=str(parent), prefix=".aaip-tmp-") + try: + with os.fdopen(fd, 'w', encoding='utf-8') as f: + f.write(content) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, str(path)) + except Exception: + # Clean up temp file on failure + try: + os.unlink(tmp_path) + except OSError: + pass + raise + + +def _acquire_lock(file_obj, exclusive: bool = True) -> None: + """ + Acquire advisory file lock. + + Uses fcntl.flock on Unix and msvcrt.locking on Windows. + """ + if sys.platform == "win32": + import msvcrt + pos = file_obj.tell() + try: + file_obj.seek(0) + mode = msvcrt.LK_NBLCK if exclusive else msvcrt.LK_NBRLCK + msvcrt.locking(file_obj.fileno(), mode, 1) + finally: + file_obj.seek(pos) + else: + import fcntl + mode = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH + fcntl.flock(file_obj.fileno(), mode) + + +def _release_lock(file_obj) -> None: + """ + Release advisory file lock. + """ + if sys.platform == "win32": + import msvcrt + pos = file_obj.tell() + try: + file_obj.seek(0) + msvcrt.locking(file_obj.fileno(), msvcrt.LK_UNLCK, 1) + finally: + file_obj.seek(pos) + else: + import fcntl + fcntl.flock(file_obj.fileno(), fcntl.LOCK_UN) + + # --------------------------------------------------------------------------- # AgentIdentity # --------------------------------------------------------------------------- @@ -80,13 +148,30 @@ def generate(cls) -> "AgentIdentity": @classmethod def load_or_create(cls, path: str = IDENTITY_FILE) -> "AgentIdentity": + # Check AAIP_IDENTITY_PATH env var for custom path + env_path = os.environ.get("AAIP_IDENTITY_PATH") + if env_path: + path = env_path p = Path(path) if p.exists(): + # Use file locking for safe concurrent reads try: - d = json.loads(p.read_text()) + with open(p, 'r', encoding='utf-8') as f: + _acquire_lock(f, exclusive=False) + try: + content = f.read() + finally: + _release_lock(f) + except OSError as e: + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( + f"Cannot read identity file: {e}" + ) from e + try: + d = json.loads(content) except json.JSONDecodeError as e: - from ..exceptions import IdentityDecryptionError - raise IdentityDecryptionError( + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( f"Identity file contains invalid JSON: {e}" ) from e # Check if identity is encrypted @@ -102,8 +187,8 @@ def load_or_create(cls, path: str = IDENTITY_FILE) -> "AgentIdentity": required: tuple[str, ...] = ("public_key_hex",) for field in required: if field not in d: - from ..exceptions import IdentityDecryptionError - raise IdentityDecryptionError( + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( f"Encrypted identity missing required field: {field}" ) # Decrypt the seed @@ -118,8 +203,8 @@ def load_or_create(cls, path: str = IDENTITY_FILE) -> "AgentIdentity": try: pub = bytes.fromhex(d["public_key_hex"]) except ValueError as e: - from ..exceptions import IdentityDecryptionError - raise IdentityDecryptionError( + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( f"Invalid hex in public_key_hex: {e}" ) from e identity = cls(seed, pub) @@ -130,22 +215,22 @@ def load_or_create(cls, path: str = IDENTITY_FILE) -> "AgentIdentity": required = ("private_key_hex", "public_key_hex") for field in required: if field not in d: - from ..exceptions import IdentityDecryptionError - raise IdentityDecryptionError( + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( f"Plaintext identity missing required field: {field}" ) try: seed = bytes.fromhex(d["private_key_hex"]) except ValueError as e: - from ..exceptions import IdentityDecryptionError - raise IdentityDecryptionError( + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( f"Invalid hex in private_key_hex: {e}" ) from e try: pub = bytes.fromhex(d["public_key_hex"]) except ValueError as e: - from ..exceptions import IdentityDecryptionError - raise IdentityDecryptionError( + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( f"Invalid hex in public_key_hex: {e}" ) from e identity = cls(seed, pub) @@ -161,9 +246,90 @@ def load_or_create(cls, path: str = IDENTITY_FILE) -> "AgentIdentity": "Set AAIP_IDENTITY_PASSPHRASE for production security." ) return identity - identity = cls.generate() - identity.save(path) - return identity + # Create new identity with exclusive lock to prevent race conditions + # Double-check file existence after acquiring lock + p.parent.mkdir(parents=True, exist_ok=True) + try: + with open(p, 'a+') as f: + _acquire_lock(f, exclusive=True) + try: + f.seek(0) + content = f.read() + if content: + # File was created by another process while we waited + try: + d = json.loads(content) + except json.JSONDecodeError as e: + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( + f"Identity file contains invalid JSON: {e}" + ) from e + # Process the existing file + if "private_key_encrypted" in d: + passphrase = os.environ.get("AAIP_IDENTITY_PASSPHRASE") + if not passphrase or passphrase.strip() == "": + from ..exceptions import IdentityDecryptionError + raise IdentityDecryptionError( + "Identity is encrypted but AAIP_IDENTITY_PASSPHRASE is not set." + ) + required = ("public_key_hex",) + for field in required: + if field not in d: + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( + f"Encrypted identity missing required field: {field}" + ) + from ._encryption import decrypt_seed + try: + seed = decrypt_seed(d, passphrase) + except Exception as e: + from ..exceptions import IdentityDecryptionError + raise IdentityDecryptionError( + f"Decryption failed: {e}" + ) from e + try: + pub = bytes.fromhex(d["public_key_hex"]) + except ValueError as e: + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( + f"Invalid hex in public_key_hex: {e}" + ) from e + return cls(seed, pub) + else: + required = ("private_key_hex", "public_key_hex") + for field in required: + if field not in d: + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( + f"Plaintext identity missing required field: {field}" + ) + try: + seed = bytes.fromhex(d["private_key_hex"]) + except ValueError as e: + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( + f"Invalid hex in private_key_hex: {e}" + ) from e + try: + pub = bytes.fromhex(d["public_key_hex"]) + except ValueError as e: + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( + f"Invalid hex in public_key_hex: {e}" + ) from e + return cls(seed, pub) + else: + # File is empty, create new identity + identity = cls.generate() + identity.save(path) + return identity + finally: + _release_lock(f) + except OSError as e: + from ..exceptions import IdentityCorruptedError + raise IdentityCorruptedError( + f"Cannot create identity file: {e}" + ) from e def save(self, path: str = IDENTITY_FILE) -> None: passphrase = os.environ.get("AAIP_IDENTITY_PASSPHRASE") @@ -192,7 +358,7 @@ def save(self, path: str = IDENTITY_FILE) -> None: "Private key stored without encryption. " "Set AAIP_IDENTITY_PASSPHRASE for production security." ) - Path(path).write_text(json.dumps(data, indent=2)) + _atomic_write(Path(path), json.dumps(data, indent=2)) # ── sign / verify ────────────────────────────────────────────────