diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..19bdef3 --- /dev/null +++ b/src/config.py @@ -0,0 +1,22 @@ +# src/config.py +from pydantic import BaseSettings +from typing import Optional +from dotenv import load_dotenv +import os + +# Load .env in development only +load_dotenv(override=False) + +class Settings(BaseSettings): + # Expect full PEM strings (including header/footer) + PRIVATE_KEY_PEM: Optional[str] = None + PUBLIC_KEY_PEM: Optional[str] = None + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + # prevent pydantic from ever printing values in exceptions + keep_untouched = (os.environ,) + +# global settings instance +settings = Settings() diff --git a/src/key_manager.py b/src/key_manager.py new file mode 100644 index 0000000..082e60f --- /dev/null +++ b/src/key_manager.py @@ -0,0 +1,75 @@ +# src/key_manager.py +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import ed25519, rsa +from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key +from cryptography.exceptions import UnsupportedAlgorithm +from typing import Optional, Tuple +from .config import settings +import hashlib +import logging + +logger = logging.getLogger(__name__) + +class KeyLoadError(Exception): + pass + +def _sha256_fingerprint(data: bytes) -> str: + """Return a short safe fingerprint string for logging (hex).""" + h = hashlib.sha256(data).hexdigest() + # return short fingerprint to avoid leaking too much + return h[:16] + +def _to_bytes(pem_str: str) -> bytes: + if isinstance(pem_str, str): + return pem_str.encode("utf-8") + return pem_str + +def load_private_key_from_env() -> Optional[serialization.PrivateFormat]: + pem = settings.PRIVATE_KEY_PEM + if not pem: + logger.debug("No PRIVATE_KEY_PEM provided in environment.") + return None + try: + pem_bytes = _to_bytes(pem) + # pass password if using encrypted PEM in production (not covered here) + key = load_pem_private_key(pem_bytes, password=None) + # log only fingerprint + try: + pub_bytes = key.public_key().public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.SubjectPublicKeyInfo + ) + logger.info("Private key loaded, public fingerprint=%s", _sha256_fingerprint(pub_bytes)) + except Exception: + logger.info("Private key loaded (public fingerprint unavailable).") + return key + except (ValueError, TypeError, UnsupportedAlgorithm) as e: + # sanitize error so we don't leak key contents + logger.exception("Failed loading private key from environment (sanitized).") + raise KeyLoadError("Invalid private key in environment") from e + +def load_public_key_from_env() -> Optional[serialization.PublicFormat]: + pem = settings.PUBLIC_KEY_PEM + if not pem: + logger.debug("No PUBLIC_KEY_PEM provided in environment.") + return None + try: + pem_bytes = _to_bytes(pem) + key = load_pem_public_key(pem_bytes) + # log only fingerprint + try: + pub_bytes = key.public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.SubjectPublicKeyInfo + ) + logger.info("Public key loaded, fingerprint=%s", _sha256_fingerprint(pub_bytes)) + except Exception: + logger.info("Public key loaded.") + return key + except (ValueError, TypeError, UnsupportedAlgorithm) as e: + logger.exception("Failed loading public key from environment (sanitized).") + raise KeyLoadError("Invalid public key in environment") from e + +# Single place to fetch keys — you can cache them if needed +PRIVATE_KEY = load_private_key_from_env() +PUBLIC_KEY = load_public_key_from_env() diff --git a/src/manager.py b/src/manager.py new file mode 100644 index 0000000..28e0969 --- /dev/null +++ b/src/manager.py @@ -0,0 +1,52 @@ +# app/manager.py +import asyncio +from starlette.websockets import WebSocket +from typing import Set +import logging + +logger = logging.getLogger("ticket_scans.manager") + +class TicketScanManager: + def __init__(self): + # set of active WebSocket connections + self.active_connections: Set[WebSocket] = set() + self._lock = asyncio.Lock() + + async def connect(self, websocket: WebSocket): + await websocket.accept() + async with self._lock: + self.active_connections.add(websocket) + logger.info("WebSocket connected. total connections=%d", len(self.active_connections)) + + async def disconnect(self, websocket: WebSocket): + async with self._lock: + if websocket in self.active_connections: + self.active_connections.remove(websocket) + logger.info("WebSocket disconnected. total connections=%d", len(self.active_connections)) + + async def broadcast_scan(self, scan_payload: dict): + """ + Broadcasts a scan payload (dict) to all connected clients. + This awaits send_json on each WebSocket. It removes any dead websockets. + """ + async with self._lock: + connections = list(self.active_connections) + + if not connections: + logger.info("Broadcast called but no active connections.") + return + + logger.info("Broadcasting scan to %d connection(s).", len(connections)) + to_remove = [] + for ws in connections: + try: + await ws.send_json(scan_payload) + except Exception as e: + logger.exception("Error sending to websocket; scheduling removal. Error: %s", e) + to_remove.append(ws) + + if to_remove: + async with self._lock: + for ws in to_remove: + self.active_connections.discard(ws) + logger.info("Removed %d dead connection(s) after broadcast.", len(to_remove)) diff --git a/src/schemas.py b/src/schemas.py new file mode 100644 index 0000000..44919ed --- /dev/null +++ b/src/schemas.py @@ -0,0 +1,11 @@ +# app/schemas.py +from pydantic import BaseModel +from typing import Optional +from datetime import datetime + +class TicketScan(BaseModel): + ticket_id: str + event_id: str + scanner_id: Optional[str] = None + timestamp: datetime + meta: Optional[dict] = None diff --git a/src/signer.py b/src/signer.py new file mode 100644 index 0000000..275fbe9 --- /dev/null +++ b/src/signer.py @@ -0,0 +1,84 @@ +# src/signer.py +from .key_manager import PRIVATE_KEY, PUBLIC_KEY, _sha256_fingerprint +from cryptography.hazmat.primitives.asymmetric import ed25519, padding +from cryptography.hazmat.primitives import hashes +import base64 +import logging +from typing import Optional + +logger = logging.getLogger(__name__) + +def _b64u_encode(b: bytes) -> str: + return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii") + +def _b64u_decode(s: str) -> bytes: + padding_needed = (-len(s)) % 4 + s_padded = s + ("=" * padding_needed) + return base64.urlsafe_b64decode(s_padded.encode("ascii")) + +def sign(payload: bytes, private_key=None) -> str: + """Sign payload bytes and return base64url signature string. + Will use PRIVATE_KEY from key_manager by default. + """ + key = private_key or PRIVATE_KEY + if key is None: + raise RuntimeError("No private key available for signing (service misconfigured).") + # Ed25519 + if isinstance(key, ed25519.Ed25519PrivateKey): + sig = key.sign(payload) + logger.debug("Signed payload with Ed25519 key (pub fingerprint=%s)", + _safe_pub_fingerprint(key)) + return _b64u_encode(sig) + + # RSA fallback (PKCS#1 v1.5 + SHA256) + try: + sig = key.sign( + payload, + padding.PKCS1v15(), + hashes.SHA256() + ) + logger.debug("Signed payload with RSA key (pub fingerprint=%s)", _safe_pub_fingerprint(key)) + return _b64u_encode(sig) + except Exception as e: + logger.exception("Signing failed (sanitized).") + raise + +def verify(payload: bytes, signature_b64u: str, public_key=None) -> bool: + key = public_key or PUBLIC_KEY + if key is None: + raise RuntimeError("No public key available for verification (service misconfigured).") + sig = _b64u_decode(signature_b64u) + # Ed25519 + if isinstance(key, ed25519.Ed25519PublicKey): + try: + key.verify(sig, payload) + logger.debug("Ed25519 verification success (pub fingerprint=%s)", _safe_pub_fingerprint(key)) + return True + except Exception: + logger.debug("Ed25519 verification failed (pub fingerprint=%s)", _safe_pub_fingerprint(key)) + return False + + # RSA fallback + try: + key.verify( + sig, + payload, + padding.PKCS1v15(), + hashes.SHA256() + ) + logger.debug("RSA verification success (pub fingerprint=%s)", _safe_pub_fingerprint(key)) + return True + except Exception: + logger.debug("RSA verification failed (pub fingerprint=%s)", _safe_pub_fingerprint(key)) + return False + +def _safe_pub_fingerprint(key) -> str: + """Return a short public fingerprint for safe logging. If obtaining bytes fails, return 'unknown'.""" + try: + pub_bytes = key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ) + return _sha256_fingerprint(pub_bytes) + except Exception: + return "unknown" diff --git a/src/websocket.py b/src/websocket.py new file mode 100644 index 0000000..35d5276 --- /dev/null +++ b/src/websocket.py @@ -0,0 +1,62 @@ +# app/main.py +import logging +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, APIRouter +from fastapi.responses import JSONResponse +from app.manager import TicketScanManager, logger as manager_logger +from app.schemas import TicketScan +from datetime import datetime + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s [%(name)s] %(message)s" +) +logger = logging.getLogger("ticket_scans.app") +# Keep manager logger consistent +manager_logger.setLevel(logging.INFO) + +app = FastAPI(title="Ticket Scans WebSocket Service") +router = APIRouter() +manager = TicketScanManager() + +@router.websocket("/ws/ticket-scans") +async def websocket_ticket_scans(ws: WebSocket): + """ + WebSocket endpoint that keeps connection open and sends scans when they are broadcast. + """ + await manager.connect(ws) + try: + # Keep the connection alive; optionally handle incoming messages if needed. + while True: + # Wait for any message from client; if you don't expect messages you can await ws.receive_text() + # but we will use receive to detect disconnects from client side. + try: + data = await ws.receive_text() + # For now, we simply ignore messages from clients but log them + logger.info("Received message from client (ignored): %s", data) + except Exception: + # The client may close the connection — break to disconnect and cleanup + break + except WebSocketDisconnect: + logger.info("Client disconnected via WebSocketDisconnect.") + except Exception as e: + logger.exception("Unexpected error in websocket loop: %s", e) + finally: + await manager.disconnect(ws) + +@router.post("/scans", response_class=JSONResponse) +async def post_scan(scan: TicketScan): + """ + POST endpoint to accept a scan and broadcast it to clients. + In production, scanning devices/services would typically call this API when a ticket is scanned, + or you would call manager.broadcast_scan from inside your event pipeline. + """ + payload = scan.dict() + # Optionally add server-received timestamp + payload.setdefault("server_received_at", datetime.utcnow().isoformat()) + # Broadcast but don't block the response when there are many clients (we await because manager.broadcast_scan is async) + await manager.broadcast_scan(payload) + logger.info("Received scan for ticket_id=%s event_id=%s", scan.ticket_id, scan.event_id) + return {"ok": True} + +app.include_router(router) diff --git a/tests/test_sign_verify.py b/tests/test_sign_verify.py new file mode 100644 index 0000000..2e111f9 --- /dev/null +++ b/tests/test_sign_verify.py @@ -0,0 +1,37 @@ +# tests/test_sign_verify.py +import tempfile +import os +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography.hazmat.primitives import serialization +from src.signer import sign, verify +from importlib import reload +import src.key_manager as key_manager + +def test_ed25519_sign_verify_ephemeral(monkeypatch): + # generate ephemeral keys + priv = ed25519.Ed25519PrivateKey.generate() + pub = priv.public_key() + priv_pem = priv.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ).decode("utf-8") + pub_pem = pub.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo + ).decode("utf-8") + + # monkeypatch environment variables for the config loader + monkeypatch.setenv("PRIVATE_KEY_PEM", priv_pem) + monkeypatch.setenv("PUBLIC_KEY_PEM", pub_pem) + + # reload modules to pick up monkeypatched env + reload(key_manager) + from src import signer as signer_module + + payload = b"unit test payload" + signature = signer_module.sign(payload) + assert signer_module.verify(payload, signature) is True + + # tamper payload -> verify should fail + assert signer_module.verify(payload + b"x", signature) is False diff --git a/tests/test_websocket.py b/tests/test_websocket.py new file mode 100644 index 0000000..2629914 --- /dev/null +++ b/tests/test_websocket.py @@ -0,0 +1,68 @@ +# tests/test_websocket.py +import pytest +from fastapi.testclient import TestClient +from app.main import app, manager +from datetime import datetime +import time +import threading +import json +import logging + +@pytest.fixture +def client(): + return TestClient(app) + +def _trigger_scan_via_post(client, scan): + client.post("/scans", json=scan) + +def test_websocket_receives_broadcast(client): + scan = { + "ticket_id": "TICKET-123", + "event_id": "EVENT-1", + "scanner_id": "GATE-A", + "timestamp": datetime.utcnow().isoformat(), + "meta": {"seat": "A1"} + } + + with client.websocket_connect("/ws/ticket-scans") as websocket: + # Start a thread that calls POST /scans after a short delay, + # to simulate an external scan arriving while WS is open. + t = threading.Timer(0.1, _trigger_scan_via_post, args=(client, scan)) + t.start() + + # Receive JSON message from websocket; timeout will raise if not received + data = websocket.receive_json(timeout=5) + # Verify payload contains expected fields + assert data["ticket_id"] == scan["ticket_id"] + assert data["event_id"] == scan["event_id"] + assert data["meta"]["seat"] == "A1" + t.cancel() + +def test_connection_logging_is_emitted(client, caplog): + caplog.set_level(logging.INFO) + scan = { + "ticket_id": "TICKET-888", + "event_id": "EVENT-LOG", + "timestamp": datetime.utcnow().isoformat(), + } + + with client.websocket_connect("/ws/ticket-scans"): + # When connection established, manager logs a connect message + # Give small time for logger to emit + time.sleep(0.05) + # Ensure connect logged + found_connect = any("WebSocket connected" in rec.message for rec in caplog.records) + assert found_connect, "connect log not found; logs: %s" % [r.message for r in caplog.records] + + # After context manager exits, disconnect log should exist + found_disconnect = any("WebSocket disconnected" in rec.message for rec in caplog.records) + assert found_disconnect, "disconnect log not found; logs: %s" % [r.message for r in caplog.records] + + # Now connect again and trigger a scan; check broadcast log exists + with client.websocket_connect("/ws/ticket-scans"): + # trigger a scan + client.post("/scans", json=scan) + time.sleep(0.05) + + found_broadcast = any("Broadcasting scan" in rec.message or "Received scan for ticket_id" in rec.message for rec in caplog.records) + assert found_broadcast, "broadcast not logged; logs: %s" % [r.message for r in caplog.records]