Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# src/config.py
from pydantic import BaseSettings
from typing import Optional
from dotenv import load_dotenv
import os

# Load .env in development only
load_dotenv(override=False)

class Settings(BaseSettings):
# Expect full PEM strings (including header/footer)
PRIVATE_KEY_PEM: Optional[str] = None
PUBLIC_KEY_PEM: Optional[str] = None

class Config:
env_file = ".env"
env_file_encoding = "utf-8"
# prevent pydantic from ever printing values in exceptions
keep_untouched = (os.environ,)

# global settings instance
settings = Settings()
75 changes: 75 additions & 0 deletions src/key_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# src/key_manager.py
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import ed25519, rsa
from cryptography.hazmat.primitives.serialization import load_pem_private_key, load_pem_public_key
from cryptography.exceptions import UnsupportedAlgorithm
from typing import Optional, Tuple
from .config import settings
import hashlib
import logging

logger = logging.getLogger(__name__)

class KeyLoadError(Exception):
pass

def _sha256_fingerprint(data: bytes) -> str:
"""Return a short safe fingerprint string for logging (hex)."""
h = hashlib.sha256(data).hexdigest()
# return short fingerprint to avoid leaking too much
return h[:16]

def _to_bytes(pem_str: str) -> bytes:
if isinstance(pem_str, str):
return pem_str.encode("utf-8")
return pem_str

def load_private_key_from_env() -> Optional[serialization.PrivateFormat]:
pem = settings.PRIVATE_KEY_PEM
if not pem:
logger.debug("No PRIVATE_KEY_PEM provided in environment.")
return None
try:
pem_bytes = _to_bytes(pem)
# pass password if using encrypted PEM in production (not covered here)
key = load_pem_private_key(pem_bytes, password=None)
# log only fingerprint
try:
pub_bytes = key.public_key().public_bytes(
serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo
)
logger.info("Private key loaded, public fingerprint=%s", _sha256_fingerprint(pub_bytes))
except Exception:
logger.info("Private key loaded (public fingerprint unavailable).")
return key
except (ValueError, TypeError, UnsupportedAlgorithm) as e:
# sanitize error so we don't leak key contents
logger.exception("Failed loading private key from environment (sanitized).")
raise KeyLoadError("Invalid private key in environment") from e

def load_public_key_from_env() -> Optional[serialization.PublicFormat]:
pem = settings.PUBLIC_KEY_PEM
if not pem:
logger.debug("No PUBLIC_KEY_PEM provided in environment.")
return None
try:
pem_bytes = _to_bytes(pem)
key = load_pem_public_key(pem_bytes)
# log only fingerprint
try:
pub_bytes = key.public_bytes(
serialization.Encoding.PEM,
serialization.PublicFormat.SubjectPublicKeyInfo
)
logger.info("Public key loaded, fingerprint=%s", _sha256_fingerprint(pub_bytes))
except Exception:
logger.info("Public key loaded.")
return key
except (ValueError, TypeError, UnsupportedAlgorithm) as e:
logger.exception("Failed loading public key from environment (sanitized).")
raise KeyLoadError("Invalid public key in environment") from e

# Single place to fetch keys — you can cache them if needed
PRIVATE_KEY = load_private_key_from_env()
PUBLIC_KEY = load_public_key_from_env()
52 changes: 52 additions & 0 deletions src/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# app/manager.py
import asyncio
from starlette.websockets import WebSocket
from typing import Set
import logging

logger = logging.getLogger("ticket_scans.manager")

class TicketScanManager:
def __init__(self):
# set of active WebSocket connections
self.active_connections: Set[WebSocket] = set()
self._lock = asyncio.Lock()

async def connect(self, websocket: WebSocket):
await websocket.accept()
async with self._lock:
self.active_connections.add(websocket)
logger.info("WebSocket connected. total connections=%d", len(self.active_connections))

async def disconnect(self, websocket: WebSocket):
async with self._lock:
if websocket in self.active_connections:
self.active_connections.remove(websocket)
logger.info("WebSocket disconnected. total connections=%d", len(self.active_connections))

async def broadcast_scan(self, scan_payload: dict):
"""
Broadcasts a scan payload (dict) to all connected clients.
This awaits send_json on each WebSocket. It removes any dead websockets.
"""
async with self._lock:
connections = list(self.active_connections)

if not connections:
logger.info("Broadcast called but no active connections.")
return

logger.info("Broadcasting scan to %d connection(s).", len(connections))
to_remove = []
for ws in connections:
try:
await ws.send_json(scan_payload)
except Exception as e:
logger.exception("Error sending to websocket; scheduling removal. Error: %s", e)
to_remove.append(ws)

if to_remove:
async with self._lock:
for ws in to_remove:
self.active_connections.discard(ws)
logger.info("Removed %d dead connection(s) after broadcast.", len(to_remove))
11 changes: 11 additions & 0 deletions src/schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# app/schemas.py
from pydantic import BaseModel
from typing import Optional
from datetime import datetime

class TicketScan(BaseModel):
ticket_id: str
event_id: str
scanner_id: Optional[str] = None
timestamp: datetime
meta: Optional[dict] = None
84 changes: 84 additions & 0 deletions src/signer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# src/signer.py
from .key_manager import PRIVATE_KEY, PUBLIC_KEY, _sha256_fingerprint
from cryptography.hazmat.primitives.asymmetric import ed25519, padding
from cryptography.hazmat.primitives import hashes
import base64
import logging
from typing import Optional

logger = logging.getLogger(__name__)

def _b64u_encode(b: bytes) -> str:
return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii")

def _b64u_decode(s: str) -> bytes:
padding_needed = (-len(s)) % 4
s_padded = s + ("=" * padding_needed)
return base64.urlsafe_b64decode(s_padded.encode("ascii"))

def sign(payload: bytes, private_key=None) -> str:
"""Sign payload bytes and return base64url signature string.
Will use PRIVATE_KEY from key_manager by default.
"""
key = private_key or PRIVATE_KEY
if key is None:
raise RuntimeError("No private key available for signing (service misconfigured).")
# Ed25519
if isinstance(key, ed25519.Ed25519PrivateKey):
sig = key.sign(payload)
logger.debug("Signed payload with Ed25519 key (pub fingerprint=%s)",
_safe_pub_fingerprint(key))
return _b64u_encode(sig)

# RSA fallback (PKCS#1 v1.5 + SHA256)
try:
sig = key.sign(
payload,
padding.PKCS1v15(),
hashes.SHA256()
)
logger.debug("Signed payload with RSA key (pub fingerprint=%s)", _safe_pub_fingerprint(key))
return _b64u_encode(sig)
except Exception as e:
logger.exception("Signing failed (sanitized).")
raise

def verify(payload: bytes, signature_b64u: str, public_key=None) -> bool:
key = public_key or PUBLIC_KEY
if key is None:
raise RuntimeError("No public key available for verification (service misconfigured).")
sig = _b64u_decode(signature_b64u)
# Ed25519
if isinstance(key, ed25519.Ed25519PublicKey):
try:
key.verify(sig, payload)
logger.debug("Ed25519 verification success (pub fingerprint=%s)", _safe_pub_fingerprint(key))
return True
except Exception:
logger.debug("Ed25519 verification failed (pub fingerprint=%s)", _safe_pub_fingerprint(key))
return False

# RSA fallback
try:
key.verify(
sig,
payload,
padding.PKCS1v15(),
hashes.SHA256()
)
logger.debug("RSA verification success (pub fingerprint=%s)", _safe_pub_fingerprint(key))
return True
except Exception:
logger.debug("RSA verification failed (pub fingerprint=%s)", _safe_pub_fingerprint(key))
return False

def _safe_pub_fingerprint(key) -> str:
"""Return a short public fingerprint for safe logging. If obtaining bytes fails, return 'unknown'."""
try:
pub_bytes = key.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
)
return _sha256_fingerprint(pub_bytes)
except Exception:
return "unknown"
62 changes: 62 additions & 0 deletions src/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# app/main.py
import logging
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, APIRouter
from fastapi.responses import JSONResponse
from app.manager import TicketScanManager, logger as manager_logger
from app.schemas import TicketScan
from datetime import datetime

# Configure logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s [%(name)s] %(message)s"
)
logger = logging.getLogger("ticket_scans.app")
# Keep manager logger consistent
manager_logger.setLevel(logging.INFO)

app = FastAPI(title="Ticket Scans WebSocket Service")
router = APIRouter()
manager = TicketScanManager()

@router.websocket("/ws/ticket-scans")
async def websocket_ticket_scans(ws: WebSocket):
"""
WebSocket endpoint that keeps connection open and sends scans when they are broadcast.
"""
await manager.connect(ws)
try:
# Keep the connection alive; optionally handle incoming messages if needed.
while True:
# Wait for any message from client; if you don't expect messages you can await ws.receive_text()
# but we will use receive to detect disconnects from client side.
try:
data = await ws.receive_text()
# For now, we simply ignore messages from clients but log them
logger.info("Received message from client (ignored): %s", data)
except Exception:
# The client may close the connection — break to disconnect and cleanup
break
except WebSocketDisconnect:
logger.info("Client disconnected via WebSocketDisconnect.")
except Exception as e:
logger.exception("Unexpected error in websocket loop: %s", e)
finally:
await manager.disconnect(ws)

@router.post("/scans", response_class=JSONResponse)
async def post_scan(scan: TicketScan):
"""
POST endpoint to accept a scan and broadcast it to clients.
In production, scanning devices/services would typically call this API when a ticket is scanned,
or you would call manager.broadcast_scan from inside your event pipeline.
"""
payload = scan.dict()
# Optionally add server-received timestamp
payload.setdefault("server_received_at", datetime.utcnow().isoformat())
# Broadcast but don't block the response when there are many clients (we await because manager.broadcast_scan is async)
await manager.broadcast_scan(payload)
logger.info("Received scan for ticket_id=%s event_id=%s", scan.ticket_id, scan.event_id)
return {"ok": True}

app.include_router(router)
37 changes: 37 additions & 0 deletions tests/test_sign_verify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# tests/test_sign_verify.py
import tempfile
import os
from cryptography.hazmat.primitives.asymmetric import ed25519
from cryptography.hazmat.primitives import serialization
from src.signer import sign, verify
from importlib import reload
import src.key_manager as key_manager

def test_ed25519_sign_verify_ephemeral(monkeypatch):
# generate ephemeral keys
priv = ed25519.Ed25519PrivateKey.generate()
pub = priv.public_key()
priv_pem = priv.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
).decode("utf-8")
pub_pem = pub.public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo
).decode("utf-8")

# monkeypatch environment variables for the config loader
monkeypatch.setenv("PRIVATE_KEY_PEM", priv_pem)
monkeypatch.setenv("PUBLIC_KEY_PEM", pub_pem)

# reload modules to pick up monkeypatched env
reload(key_manager)
from src import signer as signer_module

payload = b"unit test payload"
signature = signer_module.sign(payload)
assert signer_module.verify(payload, signature) is True

# tamper payload -> verify should fail
assert signer_module.verify(payload + b"x", signature) is False
Loading
Loading