Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sdk/python/aaip/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,9 @@ class AAIPError(Exception):

class IdentityDecryptionError(AAIPError):
"""Raised when identity decryption fails (e.g., wrong passphrase)."""
pass


class IdentityCorruptedError(AAIPError):
"""Raised when identity file is corrupted or unreadable."""
pass
200 changes: 183 additions & 17 deletions sdk/python/aaip/identity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import logging
import os
import secrets
import sys
import tempfile
import time
from pathlib import Path

Expand All @@ -42,6 +44,72 @@ def _has_cryptography() -> bool:
HAS_CRYPTOGRAPHY = _has_cryptography()


# ---------------------------------------------------------------------------
# Atomic file operations and file locking
# ---------------------------------------------------------------------------

def _atomic_write(path: Path, content: str) -> None:
"""
Write content to path atomically using tempfile + os.replace.

This ensures that concurrent processes never see a partially-written file.
"""
parent = path.parent
parent.mkdir(parents=True, exist_ok=True)
fd, tmp_path = tempfile.mkstemp(dir=str(parent), prefix=".aaip-tmp-")
try:
with os.fdopen(fd, 'w', encoding='utf-8') as f:
f.write(content)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, str(path))
except Exception:
# Clean up temp file on failure
try:
os.unlink(tmp_path)
except OSError:
pass
raise


def _acquire_lock(file_obj, exclusive: bool = True) -> None:
"""
Acquire advisory file lock.

Uses fcntl.flock on Unix and msvcrt.locking on Windows.
"""
if sys.platform == "win32":
import msvcrt
pos = file_obj.tell()
try:
file_obj.seek(0)
mode = msvcrt.LK_NBLCK if exclusive else msvcrt.LK_NBRLCK
msvcrt.locking(file_obj.fileno(), mode, 1)
finally:
file_obj.seek(pos)
else:
import fcntl
mode = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH
fcntl.flock(file_obj.fileno(), mode)


def _release_lock(file_obj) -> None:
"""
Release advisory file lock.
"""
if sys.platform == "win32":
import msvcrt
pos = file_obj.tell()
try:
file_obj.seek(0)
msvcrt.locking(file_obj.fileno(), msvcrt.LK_UNLCK, 1)
finally:
file_obj.seek(pos)
else:
import fcntl
fcntl.flock(file_obj.fileno(), fcntl.LOCK_UN)


# ---------------------------------------------------------------------------
# AgentIdentity
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -80,13 +148,30 @@ def generate(cls) -> "AgentIdentity":

@classmethod
def load_or_create(cls, path: str = IDENTITY_FILE) -> "AgentIdentity":
# Check AAIP_IDENTITY_PATH env var for custom path
env_path = os.environ.get("AAIP_IDENTITY_PATH")
if env_path:
path = env_path
p = Path(path)
if p.exists():
# Use file locking for safe concurrent reads
try:
d = json.loads(p.read_text())
with open(p, 'r', encoding='utf-8') as f:
_acquire_lock(f, exclusive=False)
try:
content = f.read()
finally:
_release_lock(f)
except OSError as e:
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Cannot read identity file: {e}"
) from e
try:
d = json.loads(content)
except json.JSONDecodeError as e:
from ..exceptions import IdentityDecryptionError
raise IdentityDecryptionError(
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Identity file contains invalid JSON: {e}"
) from e
# Check if identity is encrypted
Expand All @@ -102,8 +187,8 @@ def load_or_create(cls, path: str = IDENTITY_FILE) -> "AgentIdentity":
required: tuple[str, ...] = ("public_key_hex",)
for field in required:
if field not in d:
from ..exceptions import IdentityDecryptionError
raise IdentityDecryptionError(
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Encrypted identity missing required field: {field}"
)
# Decrypt the seed
Expand All @@ -118,8 +203,8 @@ def load_or_create(cls, path: str = IDENTITY_FILE) -> "AgentIdentity":
try:
pub = bytes.fromhex(d["public_key_hex"])
except ValueError as e:
from ..exceptions import IdentityDecryptionError
raise IdentityDecryptionError(
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Invalid hex in public_key_hex: {e}"
) from e
identity = cls(seed, pub)
Expand All @@ -130,22 +215,22 @@ def load_or_create(cls, path: str = IDENTITY_FILE) -> "AgentIdentity":
required = ("private_key_hex", "public_key_hex")
for field in required:
if field not in d:
from ..exceptions import IdentityDecryptionError
raise IdentityDecryptionError(
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Plaintext identity missing required field: {field}"
)
try:
seed = bytes.fromhex(d["private_key_hex"])
except ValueError as e:
from ..exceptions import IdentityDecryptionError
raise IdentityDecryptionError(
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Invalid hex in private_key_hex: {e}"
) from e
try:
pub = bytes.fromhex(d["public_key_hex"])
except ValueError as e:
from ..exceptions import IdentityDecryptionError
raise IdentityDecryptionError(
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Invalid hex in public_key_hex: {e}"
) from e
identity = cls(seed, pub)
Expand All @@ -161,9 +246,90 @@ def load_or_create(cls, path: str = IDENTITY_FILE) -> "AgentIdentity":
"Set AAIP_IDENTITY_PASSPHRASE for production security."
)
return identity
identity = cls.generate()
identity.save(path)
return identity
# Create new identity with exclusive lock to prevent race conditions
# Double-check file existence after acquiring lock
p.parent.mkdir(parents=True, exist_ok=True)
try:
with open(p, 'a+') as f:
_acquire_lock(f, exclusive=True)
try:
f.seek(0)
content = f.read()
if content:
# File was created by another process while we waited
try:
d = json.loads(content)
except json.JSONDecodeError as e:
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Identity file contains invalid JSON: {e}"
) from e
# Process the existing file
if "private_key_encrypted" in d:
passphrase = os.environ.get("AAIP_IDENTITY_PASSPHRASE")
if not passphrase or passphrase.strip() == "":
from ..exceptions import IdentityDecryptionError
raise IdentityDecryptionError(
"Identity is encrypted but AAIP_IDENTITY_PASSPHRASE is not set."
)
required = ("public_key_hex",)
for field in required:
if field not in d:
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Encrypted identity missing required field: {field}"
)
from ._encryption import decrypt_seed
try:
seed = decrypt_seed(d, passphrase)
except Exception as e:
from ..exceptions import IdentityDecryptionError
raise IdentityDecryptionError(
f"Decryption failed: {e}"
) from e
try:
pub = bytes.fromhex(d["public_key_hex"])
except ValueError as e:
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Invalid hex in public_key_hex: {e}"
) from e
return cls(seed, pub)
else:
required = ("private_key_hex", "public_key_hex")
for field in required:
if field not in d:
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Plaintext identity missing required field: {field}"
)
try:
seed = bytes.fromhex(d["private_key_hex"])
except ValueError as e:
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Invalid hex in private_key_hex: {e}"
) from e
try:
pub = bytes.fromhex(d["public_key_hex"])
except ValueError as e:
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Invalid hex in public_key_hex: {e}"
) from e
return cls(seed, pub)
else:
# File is empty, create new identity
identity = cls.generate()
identity.save(path)
return identity
finally:
_release_lock(f)
except OSError as e:
from ..exceptions import IdentityCorruptedError
raise IdentityCorruptedError(
f"Cannot create identity file: {e}"
) from e

def save(self, path: str = IDENTITY_FILE) -> None:
passphrase = os.environ.get("AAIP_IDENTITY_PASSPHRASE")
Expand Down Expand Up @@ -192,7 +358,7 @@ def save(self, path: str = IDENTITY_FILE) -> None:
"Private key stored without encryption. "
"Set AAIP_IDENTITY_PASSPHRASE for production security."
)
Path(path).write_text(json.dumps(data, indent=2))
_atomic_write(Path(path), json.dumps(data, indent=2))

# ── sign / verify ────────────────────────────────────────────────

Expand Down
Loading