diff --git a/contracts/api/openapi.yaml b/contracts/api/openapi.yaml index 99597b5..09d3c15 100644 --- a/contracts/api/openapi.yaml +++ b/contracts/api/openapi.yaml @@ -23,6 +23,39 @@ paths: status: type: string enum: [ok] + /health/live: + get: + summary: Liveness probe + operationId: getHealthLive + responses: + '200': + description: Process is alive + content: + application/json: + schema: + type: object + required: [status] + properties: + status: + type: string + enum: [ok] + /health/ready: + get: + summary: Readiness probe + operationId: getHealthReady + responses: + '200': + description: Service is ready + content: + application/json: + schema: + $ref: '#/components/schemas/HealthReadyResponse' + '503': + description: Service is degraded + content: + application/json: + schema: + $ref: '#/components/schemas/HealthReadyResponse' /metrics: get: summary: Runtime metrics snapshot @@ -107,6 +140,116 @@ paths: application/json: schema: $ref: '#/components/schemas/ErrorResponse' + /v1/moderate/batch: + post: + summary: Moderate a batch of texts + operationId: moderateBatch + security: + - ApiKeyAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/ModerationBatchRequest' + responses: + '200': + description: Batch moderation result + headers: + X-RateLimit-Limit: + description: Maximum requests allowed in the current window. + schema: + type: string + X-RateLimit-Remaining: + description: Requests remaining in the current window. + schema: + type: string + X-RateLimit-Reset: + description: Seconds until the current rate-limit window resets. + schema: + type: string + content: + application/json: + schema: + $ref: '#/components/schemas/ModerationBatchResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '429': + description: Rate limit exceeded + headers: + X-RateLimit-Limit: + description: Maximum requests allowed in the current window. + schema: + type: string + X-RateLimit-Remaining: + description: Always `0` when the request is throttled. + schema: + type: string + X-RateLimit-Reset: + description: Seconds until the current rate-limit window resets. + schema: + type: string + Retry-After: + description: Seconds the client should wait before retrying. + schema: + type: string + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + /v1/appeals: + post: + summary: Submit an appeal + operationId: createPublicAppeal + security: + - ApiKeyAuth: [] + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/PublicAppealCreateRequest' + responses: + '201': + description: Appeal submitted + content: + application/json: + schema: + $ref: '#/components/schemas/PublicAppealCreateResponse' + '400': + description: Invalid request + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '401': + description: Unauthorized + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '429': + description: Rate limit exceeded + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + '500': + description: Internal server error + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' components: securitySchemes: ApiKeyAuth: @@ -114,6 +257,19 @@ components: in: header name: X-API-Key schemas: + HealthReadyResponse: + type: object + additionalProperties: false + required: [status, checks] + properties: + status: + type: string + enum: [ready, degraded] + checks: + type: object + additionalProperties: + type: string + enum: [ok, empty, error] ModerationRequest: type: object additionalProperties: false @@ -125,22 +281,144 @@ components: minLength: 1 maxLength: 5000 context: - type: object - additionalProperties: false - properties: - source: - type: string - maxLength: 100 - locale: - type: string - maxLength: 20 - channel: - type: string - maxLength: 50 + $ref: '#/components/schemas/ModerationContext' request_id: type: string maxLength: 128 pattern: '^[A-Za-z0-9][A-Za-z0-9._:-]{0,127}$' + ModerationBatchItem: + type: object + additionalProperties: false + required: [text] + properties: + text: + type: string + minLength: 1 + maxLength: 5000 + context: + $ref: '#/components/schemas/ModerationContext' + request_id: + type: string + maxLength: 128 + pattern: '^[A-Za-z0-9][A-Za-z0-9._:-]{0,127}$' + ModerationBatchRequest: + type: object + additionalProperties: false + required: [items] + properties: + items: + type: array + minItems: 1 + maxItems: 50 + items: + $ref: '#/components/schemas/ModerationBatchItem' + ModerationBatchItemResult: + type: object + additionalProperties: false + required: [request_id] + properties: + request_id: + type: string + maxLength: 128 + result: + anyOf: + - $ref: '#/components/schemas/ModerationResponse' + - type: 'null' + error: + anyOf: + - $ref: '#/components/schemas/ErrorResponse' + - type: 'null' + ModerationBatchResponse: + type: object + additionalProperties: false + required: [items, total, succeeded, failed] + properties: + items: + type: array + items: + $ref: '#/components/schemas/ModerationBatchItemResult' + total: + type: integer + minimum: 0 + succeeded: + type: integer + minimum: 0 + failed: + type: integer + minimum: 0 + ModerationContext: + type: object + additionalProperties: false + properties: + source: + type: string + maxLength: 100 + locale: + type: string + maxLength: 20 + channel: + type: string + maxLength: 50 + PublicAppealCreateRequest: + type: object + additionalProperties: false + required: + - decision_request_id + - original_action + - original_reason_codes + - original_model_version + - original_lexicon_version + - original_policy_version + - original_pack_versions + properties: + decision_request_id: + type: string + minLength: 1 + maxLength: 128 + original_action: + type: string + enum: [ALLOW, REVIEW, BLOCK] + original_reason_codes: + type: array + minItems: 1 + items: + type: string + pattern: '^R_[A-Z0-9_]+$' + original_model_version: + type: string + minLength: 1 + maxLength: 128 + original_lexicon_version: + type: string + minLength: 1 + maxLength: 128 + original_policy_version: + type: string + minLength: 1 + maxLength: 128 + original_pack_versions: + type: object + minProperties: 1 + additionalProperties: + type: string + reason: + type: string + maxLength: 500 + PublicAppealCreateResponse: + type: object + additionalProperties: false + required: [appeal_id, status, request_id] + properties: + appeal_id: + type: integer + minimum: 1 + status: + type: string + enum: [submitted] + request_id: + type: string + minLength: 1 + maxLength: 128 ModerationResponse: type: object additionalProperties: false diff --git a/migrations/0013_multi_model_embeddings.sql b/migrations/0013_multi_model_embeddings.sql new file mode 100644 index 0000000..5bb1cb1 --- /dev/null +++ b/migrations/0013_multi_model_embeddings.sql @@ -0,0 +1,51 @@ +-- 0013_multi_model_embeddings.sql +-- +-- Adds a multi-model embedding table to support coexisting embedding backends +-- (e.g. 64-dim hash-bow-v1 and 384-dim e5-multilingual-small-v1) while keeping +-- per-model ANN indexes dimension-consistent via partial indexes. +-- +-- This migration is additive: it does not modify or drop lexicon_entry_embeddings (v1). + +CREATE TABLE IF NOT EXISTS lexicon_entry_embeddings_v2 ( + lexicon_entry_id BIGINT NOT NULL + REFERENCES lexicon_entries (id) + ON DELETE CASCADE, + embedding_model TEXT NOT NULL, + embedding_dim INT NOT NULL, + embedding VECTOR NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (lexicon_entry_id, embedding_model), + CONSTRAINT lexicon_entry_embeddings_v2_embedding_dim_check + CHECK (vector_dims(embedding) = embedding_dim) +); + +CREATE INDEX IF NOT EXISTS ix_lex_emb_v2_model +ON lexicon_entry_embeddings_v2 (embedding_model, updated_at DESC); + +-- Partial ANN indexes (IVFFlat requires a fixed vector dimension per index). +CREATE INDEX IF NOT EXISTS ix_lex_emb_v2_hash_bow_v1 +ON lexicon_entry_embeddings_v2 +USING ivfflat (embedding vector_cosine_ops) +WITH (lists = 32) +WHERE embedding_model = 'hash-bow-v1'; + +CREATE INDEX IF NOT EXISTS ix_lex_emb_v2_e5_small_v1 +ON lexicon_entry_embeddings_v2 +USING ivfflat (embedding vector_cosine_ops) +WITH (lists = 32) +WHERE embedding_model = 'e5-multilingual-small-v1'; + +-- Backfill existing hash-bow-v1 vectors from v1 to v2 (idempotent). +INSERT INTO lexicon_entry_embeddings_v2 + (lexicon_entry_id, embedding_model, embedding_dim, embedding, created_at, updated_at) +SELECT + lexicon_entry_id, + 'hash-bow-v1', + 64, + embedding, + created_at, + updated_at +FROM lexicon_entry_embeddings +WHERE embedding_model = 'hash-bow-v1' +ON CONFLICT (lexicon_entry_id, embedding_model) DO NOTHING; diff --git a/pyproject.toml b/pyproject.toml index 31fb89d..d4ab7ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,8 +10,10 @@ dependencies = [ "prometheus-client", "pydantic", "psycopg[binary]", + "psycopg-pool", "redis", "structlog", + "tenacity", "uvicorn", "PyJWT", ] @@ -27,7 +29,6 @@ dev = [ ] ops = [ "alembic", - "tenacity", ] ml = [ "fasttext-wheel", diff --git a/src/sentinel_api/appeals.py b/src/sentinel_api/appeals.py index 2c097b6..5290302 100644 --- a/src/sentinel_api/appeals.py +++ b/src/sentinel_api/appeals.py @@ -5,11 +5,14 @@ import os from dataclasses import dataclass, field from datetime import UTC, datetime +from pathlib import Path from threading import Lock from typing import Any, Literal, cast, get_args +from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field +from sentinel_api.logging import get_logger from sentinel_core.async_state_machine import validate_appeal_transition from sentinel_core.models import Action, ReasonCode @@ -43,6 +46,9 @@ KNOWN_RESOLVED_APPEAL_STATUSES = set(get_args(ResolvedAppealStatus)) KNOWN_ACTIONS = set(get_args(Action)) +logger = get_logger("sentinel.appeals") +TRAINING_DATA_PATH_ENV = "SENTINEL_TRAINING_DATA_PATH" + class AppealNotFoundError(LookupError): pass @@ -351,6 +357,16 @@ def reconstruct(self, *, appeal_id: int) -> AdminAppealReconstructionResponse: class _PostgresAppealsStore: database_url: str + def _connection(self): + from sentinel_api.db_pool import get_pool + + pool = get_pool(self.database_url) + if pool is not None: + return pool.connection() + + psycopg = importlib.import_module("psycopg") + return psycopg.connect(self.database_url) + def _fetch_appeal_record(self, cur, appeal_id: int) -> AdminAppealRecord: cur.execute( """ @@ -388,8 +404,7 @@ def create_appeal( *, submitted_by: str, ) -> AdminAppealRecord: - psycopg = importlib.import_module("psycopg") - with psycopg.connect(self.database_url) as conn: + with self._connection() as conn: with conn.cursor() as cur: cur.execute( """ @@ -458,8 +473,7 @@ def list_appeals( if where_conditions: where_clause = "WHERE " + " AND ".join(where_conditions) - psycopg = importlib.import_module("psycopg") - with psycopg.connect(self.database_url) as conn: + with self._connection() as conn: with conn.cursor() as cur: cur.execute( f"SELECT COUNT(1) FROM appeals {where_clause}", @@ -506,8 +520,7 @@ def transition_appeal( payload: AdminAppealTransitionRequest, actor: str, ) -> AdminAppealRecord: - psycopg = importlib.import_module("psycopg") - with psycopg.connect(self.database_url) as conn: + with self._connection() as conn: with conn.cursor() as cur: current = self._fetch_appeal_record(cur, appeal_id) validate_appeal_transition(current.status, payload.to_status) @@ -557,12 +570,27 @@ def transition_appeal( ), ) updated = self._fetch_appeal_record(cur, appeal_id) + if payload.to_status in REVERSED_OR_MODIFIED_STATUSES: + _auto_create_lexicon_proposal( + cur, + appeal=updated, + resolution_reason_codes=list(resolution_reason_codes or []), + actor=actor, + ) conn.commit() + if updated.status in REVERSED_OR_MODIFIED_STATUSES: + try: + _emit_training_sample(updated) + except Exception as exc: # pragma: no cover - must never fail caller + logger.warning( + "training_sample_emit_crashed", + appeal_id=updated.id, + error=str(exc), + ) return updated def reconstruct(self, *, appeal_id: int) -> AdminAppealReconstructionResponse: - psycopg = importlib.import_module("psycopg") - with psycopg.connect(self.database_url) as conn: + with self._connection() as conn: with conn.cursor() as cur: appeal = self._fetch_appeal_record(cur, appeal_id) cur.execute( @@ -643,6 +671,85 @@ def _build_reconstruction( ) +def _auto_create_lexicon_proposal( + cur, + *, + appeal: AdminAppealRecord, + resolution_reason_codes: list[str], + actor: str, +) -> int: + title = f"Auto-proposed from appeal #{appeal.id}: {appeal.resolution_code}" + evidence = { + "appeal_id": appeal.id, + "original_action": appeal.original_action, + "original_reason_codes": list(appeal.original_reason_codes), + "resolution_reason_codes": list(resolution_reason_codes), + "original_lexicon_version": appeal.original_lexicon_version, + } + policy_impact_summary = f"appeal_reversal request_id={appeal.request_id}" + cur.execute( + """ + INSERT INTO release_proposals + (proposal_type, status, title, evidence, policy_impact_summary, proposed_by, updated_at) + VALUES + ('lexicon', 'draft', %s, %s::jsonb, %s, %s, NOW()) + RETURNING id + """, + ( + title, + json.dumps(evidence, sort_keys=True, ensure_ascii=True), + policy_impact_summary, + actor, + ), + ) + row = cur.fetchone() + if row is None: + raise ValueError("failed to auto-create release proposal") + proposal_id = int(row[0]) + cur.execute( + """ + INSERT INTO release_proposal_audit + (proposal_id, from_status, to_status, actor, details) + VALUES + (%s, %s, 'draft', %s, %s) + """, + (proposal_id, None, actor, f"auto-generated from appeal_id={appeal.id}"), + ) + return proposal_id + + +def _emit_training_sample(appeal: AdminAppealRecord) -> None: + path_value = os.getenv(TRAINING_DATA_PATH_ENV, "").strip() + if not path_value: + return + record = { + "sample_id": uuid4().hex, + "appeal_id": appeal.id, + "original_action": appeal.original_action, + "original_reason_codes": list(appeal.original_reason_codes), + "original_model_version": appeal.original_model_version, + "original_lexicon_version": appeal.original_lexicon_version, + "original_policy_version": appeal.original_policy_version, + "resolution": appeal.status, + "resolution_code": appeal.resolution_code, + "resolution_reason_codes": ( + list(appeal.resolution_reason_codes) + if appeal.resolution_reason_codes is not None + else None + ), + "reviewer_actor": appeal.reviewer_actor, + "resolved_at": appeal.resolved_at.isoformat() if appeal.resolved_at is not None else None, + } + path = Path(path_value) + try: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as handle: + handle.write(json.dumps(record, ensure_ascii=True)) + handle.write("\n") + except OSError as exc: + logger.warning("training_sample_emit_failed", path=str(path), error=str(exc)) + + class AppealsRuntime: def __init__(self) -> None: self._memory_store = _InMemoryAppealsStore() diff --git a/src/sentinel_api/async_worker.py b/src/sentinel_api/async_worker.py index 73c0fb2..90be214 100644 --- a/src/sentinel_api/async_worker.py +++ b/src/sentinel_api/async_worker.py @@ -352,9 +352,13 @@ def process_one( max_retry_attempts: int = DEFAULT_MAX_RETRY_ATTEMPTS, max_error_retry_seconds: int = DEFAULT_MAX_ERROR_RETRY_SECONDS, ) -> WorkerRunReport: + from sentinel_api.db_pool import get_pool + psycopg = _get_psycopg_module() claimed_item: QueueWorkItem | None = None - with psycopg.connect(database_url) as conn: + pool = get_pool(database_url) + conn_ctx = pool.connection() if pool is not None else psycopg.connect(database_url) + with conn_ctx as conn: try: with conn.cursor() as cur: claimed_item = _claim_next_queue_item(cur) diff --git a/src/sentinel_api/audit_events.py b/src/sentinel_api/audit_events.py new file mode 100644 index 0000000..76e1c66 --- /dev/null +++ b/src/sentinel_api/audit_events.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import json +from collections import deque +from dataclasses import asdict, dataclass +from threading import Lock + +AUDIT_RING_BUFFER_SIZE = 1000 + + +@dataclass(frozen=True) +class AuditEvent: + timestamp: str + action: str + labels: list[str] + reason_codes: list[str] + latency_ms: int + deployment_stage: str + lexicon_version: str + policy_version: str + + +_lock = Lock() +_sequence: int = 0 +_ring: deque[tuple[int, AuditEvent]] = deque(maxlen=AUDIT_RING_BUFFER_SIZE) + + +def publish_audit_event(event: AuditEvent) -> None: + global _sequence + with _lock: + _sequence += 1 + _ring.append((_sequence, event)) + + +def events_since(cursor: int) -> tuple[list[AuditEvent], int]: + normalized_cursor = max(0, int(cursor)) + with _lock: + events = [event for seq, event in _ring if seq > normalized_cursor] + return events, _sequence + + +def _format_sse_event(event: AuditEvent) -> str: + return f"data: {json.dumps(asdict(event), ensure_ascii=True)}\n\n" + + +def reset_audit_events_state() -> None: + global _sequence + with _lock: + _sequence = 0 + _ring.clear() diff --git a/src/sentinel_api/db_pool.py b/src/sentinel_api/db_pool.py new file mode 100644 index 0000000..137061c --- /dev/null +++ b/src/sentinel_api/db_pool.py @@ -0,0 +1,11 @@ +"""Compatibility shim for DB pooling. + +The pool singleton lives in `sentinel_db.pool` so non-API packages can use pooling +without importing from `sentinel_api`. +""" + +from __future__ import annotations + +from sentinel_db.pool import close_pool, get_pool, peek_pool + +__all__ = ["get_pool", "peek_pool", "close_pool"] diff --git a/src/sentinel_api/main.py b/src/sentinel_api/main.py index 3352e44..13de3f6 100644 --- a/src/sentinel_api/main.py +++ b/src/sentinel_api/main.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import os import re @@ -7,6 +8,7 @@ import time from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager +from dataclasses import asdict from datetime import UTC, datetime from pathlib import Path as FilePath from typing import Literal @@ -24,7 +26,7 @@ status, ) from fastapi.exceptions import RequestValidationError -from fastapi.responses import JSONResponse, PlainTextResponse +from fastapi.responses import JSONResponse, PlainTextResponse, StreamingResponse from pydantic import BaseModel, ConfigDict, Field from sentinel_api.appeals import ( @@ -38,12 +40,16 @@ get_appeals_runtime, ) from sentinel_api.async_priority import async_queue_metrics +from sentinel_api.audit_events import AuditEvent, events_since, publish_audit_event +from sentinel_api.db_pool import close_pool, get_pool from sentinel_api.logging import get_logger from sentinel_api.metrics import metrics +from sentinel_api.model_artifact_repository import resolve_runtime_model_version from sentinel_api.model_registry import predict_classifier_shadow from sentinel_api.oauth import OAuthPrincipal, require_oauth_scope from sentinel_api.policy import moderate from sentinel_api.rate_limit import build_rate_limiter +from sentinel_api.result_cache import get_cached_result, make_cache_key, set_cached_result from sentinel_api.transparency import ( TransparencyAppealsExportResponse, TransparencyAppealsReportResponse, @@ -53,14 +59,28 @@ Action, ErrorResponse, MetricsResponse, + ModerationBatchItemResult, + ModerationBatchRequest, + ModerationBatchResponse, ModerationRequest, ModerationResponse, + PublicAppealCreateRequest, + PublicAppealCreateResponse, ) -from sentinel_core.policy_config import DeploymentStage, resolve_policy_runtime +from sentinel_core.policy_config import ( + DeploymentStage, + ElectoralPhase, + resolve_policy_runtime, + set_runtime_phase_override, +) +from sentinel_langpack.registry import resolve_pack_versions +from sentinel_lexicon.lexicon import get_lexicon_matcher logger = get_logger("sentinel.api") CLASSIFIER_SHADOW_ENABLED_ENV = "SENTINEL_CLASSIFIER_SHADOW_ENABLED" SHADOW_PREDICTIONS_PATH_ENV = "SENTINEL_SHADOW_PREDICTIONS_PATH" +RESULT_CACHE_ENABLED_ENV = "SENTINEL_RESULT_CACHE_ENABLED" +RESULT_CACHE_TTL_SECONDS_ENV = "SENTINEL_RESULT_CACHE_TTL_SECONDS" _REQUEST_ID_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._:-]{0,127}$") @@ -81,7 +101,13 @@ def _coerce_request_id(value: str | None) -> str | None: async def lifespan(_app: FastAPI) -> AsyncIterator[None]: # Fail fast if electoral phase override is invalid. resolve_policy_runtime() - yield + database_url = os.getenv("SENTINEL_DATABASE_URL", "").strip() + if database_url: + get_pool(database_url) + try: + yield + finally: + close_pool() app = FastAPI(title="Sentinel Moderation API", version="0.1.0", lifespan=lifespan) @@ -107,6 +133,12 @@ class AdminProposalReviewResponse(BaseModel): rationale: str | None = None +class AdminPhaseUpdateRequest(BaseModel): + model_config = ConfigDict(extra="forbid") + + phase: ElectoralPhase | None = None + + @app.middleware("http") async def request_context_middleware(request: Request, call_next): # type: ignore[no-untyped-def] request_id = _coerce_request_id(request.headers.get("X-Request-ID")) or str(uuid4()) @@ -143,8 +175,12 @@ def require_api_key(x_api_key: str | None = Header(default=None)) -> None: def enforce_rate_limit(response: Response, x_api_key: str | None = Header(default=None)) -> None: + _enforce_rate_limit_cost(response, x_api_key=x_api_key, cost=1) + + +def _enforce_rate_limit_cost(response: Response, *, x_api_key: str | None, cost: int) -> None: key = x_api_key or "anonymous" - decision = rate_limiter.check(key) + decision = rate_limiter.check(key, cost=cost) response.headers["X-RateLimit-Limit"] = str(decision.limit) response.headers["X-RateLimit-Remaining"] = str(decision.remaining) response.headers["X-RateLimit-Reset"] = str(decision.reset_after_seconds) @@ -257,6 +293,85 @@ def health() -> dict[str, str]: return {"status": "ok"} +@app.get("/health/live") +def health_live() -> dict[str, str]: + return {"status": "ok"} + + +def _check_lexicon_ready() -> str: + try: + matcher = get_lexicon_matcher() + except Exception: + return "error" + if matcher.entries: + return "ok" + return "empty" + + +def _check_db_ready(database_url: str) -> str: + normalized = database_url.strip() + if not normalized: + return "empty" + try: + from sentinel_api.db_pool import get_pool + + pool = get_pool(normalized) + if pool is not None: + conn_ctx = pool.connection() + else: + import psycopg + + conn_ctx = psycopg.connect(normalized) + with conn_ctx as conn: + with conn.cursor() as cur: + cur.execute("SELECT 1") + cur.fetchone() + except Exception: + return "error" + return "ok" + + +def _check_redis_ready(redis_url: str) -> str: + normalized = redis_url.strip() + if not normalized: + return "empty" + try: + import redis + + client = redis.Redis.from_url( + normalized, + socket_connect_timeout=0.5, + socket_timeout=0.5, + health_check_interval=10, + ) + client.ping() + except Exception: + return "error" + return "ok" + + +@app.get("/health/ready") +def health_ready() -> JSONResponse: + checks: dict[str, str] = {} + checks["lexicon"] = _check_lexicon_ready() + + database_url = os.getenv("SENTINEL_DATABASE_URL", "") + if database_url.strip(): + checks["db"] = _check_db_ready(database_url) + + redis_url = os.getenv("SENTINEL_REDIS_URL", "") + if redis_url.strip(): + checks["redis"] = _check_redis_ready(redis_url) + + degraded = any(value == "error" for value in checks.values()) + status_value = "degraded" if degraded else "ready" + http_status = status.HTTP_503_SERVICE_UNAVAILABLE if degraded else status.HTTP_200_OK + return JSONResponse( + status_code=http_status, + content={"status": status_value, "checks": checks}, + ) + + @app.get("/metrics", response_model=MetricsResponse) def get_metrics() -> MetricsResponse: snapshot = metrics.snapshot() @@ -296,6 +411,44 @@ def get_admin_proposal_permissions( } +@app.post("/admin/policy/phase") +def post_admin_policy_phase( + request: AdminPhaseUpdateRequest, + principal: OAuthPrincipal = Depends(require_oauth_scope("admin:policy:write")), +) -> dict[str, object]: + set_runtime_phase_override(request.phase) + runtime = resolve_policy_runtime() + effective_phase = runtime.effective_phase.value if runtime.effective_phase is not None else None + return { + "effective_phase": effective_phase, + "effective_policy_version": runtime.effective_policy_version, + "actor": principal.client_id, + "limitation": ( + "in-process only; multi-worker and multi-replica deployments require a shared store" + ), + } + + +async def _generate_audit_sse(start_cursor: int) -> AsyncIterator[str]: + cursor = max(0, int(start_cursor)) + while True: + events, cursor = events_since(cursor) + for event in events: + payload = json.dumps(asdict(event), ensure_ascii=True) + yield f"data: {payload}\n\n" + if not events: + await asyncio.sleep(0.5) + + +@app.get("/admin/audit/stream") +def get_admin_audit_stream( + cursor: int = Query(default=0, ge=0), + principal: OAuthPrincipal = Depends(require_oauth_scope("admin:transparency:read")), +) -> StreamingResponse: + _ = principal + return StreamingResponse(_generate_audit_sse(cursor), media_type="text/event-stream") + + @app.post( "/admin/release-proposals/{proposal_id}/review", response_model=AdminProposalReviewResponse, @@ -492,12 +645,66 @@ def moderate_text( ) effective_request_id = request.request_id or http_request.state.request_id runtime = resolve_policy_runtime() - result = moderate(request.text, runtime=runtime) + response.headers["X-Request-ID"] = effective_request_id + + cache_enabled = _is_truthy_env(RESULT_CACHE_ENABLED_ENV) + redis_url = os.getenv("SENTINEL_REDIS_URL", "").strip() + cache_key: str | None = None + if cache_enabled and redis_url: + matcher = get_lexicon_matcher() + cache_key = make_cache_key( + request.text, + policy_version=runtime.effective_policy_version, + lexicon_version=matcher.version, + model_version=resolve_runtime_model_version(runtime.config.model_version), + pack_versions=resolve_pack_versions(runtime.config.pack_versions), + deployment_stage=runtime.effective_deployment_stage.value, + context=request.context, + ) + cached = get_cached_result(cache_key, redis_url) + if cached is not None: + response.headers["X-Cache"] = "HIT" + metrics.record_action(cached.action) + metrics.record_moderation_latency(cached.latency_ms) + publish_audit_event( + AuditEvent( + timestamp=datetime.now(tz=UTC).isoformat(), + action=cached.action, + labels=list(cached.labels), + reason_codes=list(cached.reason_codes), + latency_ms=cached.latency_ms, + deployment_stage=runtime.effective_deployment_stage.value, + lexicon_version=cached.lexicon_version, + policy_version=cached.policy_version, + ) + ) + return cached + response.headers["X-Cache"] = "MISS" + + result = moderate(request.text, context=request.context, runtime=runtime) effective_phase = runtime.effective_phase.value if runtime.effective_phase is not None else None effective_deployment_stage = runtime.effective_deployment_stage.value - response.headers["X-Request-ID"] = effective_request_id metrics.record_action(result.action) metrics.record_moderation_latency(result.latency_ms) + publish_audit_event( + AuditEvent( + timestamp=datetime.now(tz=UTC).isoformat(), + action=result.action, + labels=list(result.labels), + reason_codes=list(result.reason_codes), + latency_ms=result.latency_ms, + deployment_stage=runtime.effective_deployment_stage.value, + lexicon_version=result.lexicon_version, + policy_version=result.policy_version, + ) + ) + if cache_key is not None and redis_url: + ttl_raw = os.getenv(RESULT_CACHE_TTL_SECONDS_ENV, "60").strip() + try: + ttl = int(ttl_raw) + except ValueError: + ttl = 60 + set_cached_result(cache_key, result, redis_url, ttl=ttl) _record_classifier_shadow_prediction( request_id=effective_request_id, text=request.text, @@ -520,6 +727,125 @@ def moderate_text( return result +@app.post( + "/v1/moderate/batch", + response_model=ModerationBatchResponse, + responses={ + 401: {"model": ErrorResponse}, + 429: {"model": ErrorResponse}, + 500: {"model": ErrorResponse}, + }, +) +def moderate_batch( + http_request: Request, + response: Response, + request: ModerationBatchRequest, + _: None = Depends(require_api_key), + x_api_key: str | None = Header(default=None), +) -> ModerationBatchResponse: + effective_request_id = http_request.state.request_id + response.headers["X-Request-ID"] = effective_request_id + + _enforce_rate_limit_cost(response, x_api_key=x_api_key, cost=len(request.items)) + + runtime = resolve_policy_runtime() + items: list[ModerationBatchItemResult] = [] + succeeded = 0 + failed = 0 + + for item in request.items: + item_request_id = item.request_id or str(uuid4()) + if _coerce_request_id(item_request_id) is None: + failed += 1 + items.append( + ModerationBatchItemResult( + request_id=item_request_id, + result=None, + error=ErrorResponse( + error_code="HTTP_400", + message="request_id contains invalid characters", + request_id=item_request_id, + ), + ) + ) + continue + + try: + result = moderate(item.text, context=item.context, runtime=runtime) + except Exception: + failed += 1 + items.append( + ModerationBatchItemResult( + request_id=item_request_id, + result=None, + error=ErrorResponse( + error_code="HTTP_500", + message="Internal server error", + request_id=item_request_id, + ), + ) + ) + continue + + succeeded += 1 + items.append( + ModerationBatchItemResult( + request_id=item_request_id, + result=result, + error=None, + ) + ) + + return ModerationBatchResponse( + items=items, + total=len(items), + succeeded=succeeded, + failed=failed, + ) + + +@app.post( + "/v1/appeals", + response_model=PublicAppealCreateResponse, + status_code=status.HTTP_201_CREATED, + responses={ + 400: {"model": ErrorResponse}, + 401: {"model": ErrorResponse}, + 429: {"model": ErrorResponse}, + 500: {"model": ErrorResponse}, + }, +) +def post_public_appeal( + request: PublicAppealCreateRequest, + _: None = Depends(require_api_key), + __: None = Depends(enforce_rate_limit), +) -> PublicAppealCreateResponse: + if _coerce_request_id(request.decision_request_id) is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="decision_request_id contains invalid characters", + ) + record = appeals_runtime.create_appeal( + AdminAppealCreateRequest( + original_decision_id=request.decision_request_id, + request_id=request.decision_request_id, + original_action=request.original_action, + original_reason_codes=request.original_reason_codes, + original_model_version=request.original_model_version, + original_lexicon_version=request.original_lexicon_version, + original_policy_version=request.original_policy_version, + original_pack_versions=request.original_pack_versions, + rationale=request.reason, + ), + submitted_by="public-api", + ) + return PublicAppealCreateResponse( + appeal_id=record.id, + status="submitted", + request_id=record.request_id, + ) + + @app.exception_handler(HTTPException) async def http_exception_handler(request: Request, exc: HTTPException): # type: ignore[no-untyped-def] request_id = getattr(request.state, "request_id", str(uuid4())) diff --git a/src/sentinel_api/model_registry.py b/src/sentinel_api/model_registry.py index 9ed80c0..314d9e1 100644 --- a/src/sentinel_api/model_registry.py +++ b/src/sentinel_api/model_registry.py @@ -67,6 +67,54 @@ def embed(self, text: str, *, timeout_ms: int) -> list[float] | None: return None +class E5MultilingualSmallEmbeddingProvider: + name = "e5-multilingual-small" + version = "e5-multilingual-small-v1" + dimension = 384 + + @staticmethod + @lru_cache(maxsize=1) + def _load_model(): + try: + from sentence_transformers import SentenceTransformer + + return SentenceTransformer("intfloat/multilingual-e5-small") + except ImportError: + logger.warning("sentence-transformers not installed; e5 provider unavailable") + return None + except Exception as exc: + logger.warning("failed to load e5 model: %s", exc) + return None + + def embed(self, text: str, *, timeout_ms: int) -> list[float] | None: + _ = timeout_ms + model = self._load_model() + if model is None: + return None + try: + embedding = model.encode(f"query: {text}", normalize_embeddings=True) + if hasattr(embedding, "tolist"): + return embedding.tolist() + return list(embedding) + except Exception as exc: + logger.warning("e5 embedding provider failed; falling back: %s", exc) + return None + + def embed_passage(self, text: str, *, timeout_ms: int) -> list[float] | None: + _ = timeout_ms + model = self._load_model() + if model is None: + return None + try: + embedding = model.encode(f"passage: {text}", normalize_embeddings=True) + if hasattr(embedding, "tolist"): + return embedding.tolist() + return list(embedding) + except Exception as exc: + logger.warning("e5 passage embedding failed; falling back: %s", exc) + return None + + class NoopMultiLabelClassifier: name = "none" version = "none-v1" @@ -127,6 +175,7 @@ def score(self, text: str, *, timeout_ms: int) -> tuple[float, ClaimBand] | None EMBEDDING_PROVIDERS: dict[str, EmbeddingProvider] = { DEFAULT_EMBEDDING_PROVIDER_ID: HashBowEmbeddingProvider(), + "e5-multilingual-small-v1": E5MultilingualSmallEmbeddingProvider(), } CLASSIFIERS: dict[str, MultiLabelClassifier] = { DEFAULT_CLASSIFIER_PROVIDER_ID: NoopMultiLabelClassifier(), diff --git a/src/sentinel_api/oauth.py b/src/sentinel_api/oauth.py index ec01601..e89e40b 100644 --- a/src/sentinel_api/oauth.py +++ b/src/sentinel_api/oauth.py @@ -14,6 +14,20 @@ OAUTH_JWT_AUDIENCE_ENV = "SENTINEL_OAUTH_JWT_AUDIENCE" OAUTH_JWT_ISSUER_ENV = "SENTINEL_OAUTH_JWT_ISSUER" +KNOWN_OAUTH_SCOPES = frozenset( + { + "internal:queue:read", + "admin:proposal:read", + "admin:proposal:review", + "admin:appeal:read", + "admin:appeal:write", + "admin:policy:write", + "admin:transparency:read", + "admin:transparency:export", + "admin:transparency:identifiers", + } +) + @dataclass(frozen=True) class OAuthPrincipal: @@ -154,6 +168,9 @@ def authenticate_bearer_token(authorization: str | None) -> OAuthPrincipal: def require_oauth_scope(required_scope: str): + if required_scope not in KNOWN_OAUTH_SCOPES: + raise ValueError(f"Unknown OAuth scope: {required_scope}") + def dependency(authorization: str | None = Header(default=None)) -> OAuthPrincipal: principal = authenticate_bearer_token(authorization) if required_scope not in principal.scopes: diff --git a/src/sentinel_api/partner_connectors.py b/src/sentinel_api/partner_connectors.py index 69719d7..dba14e3 100644 --- a/src/sentinel_api/partner_connectors.py +++ b/src/sentinel_api/partner_connectors.py @@ -10,6 +10,7 @@ from typing import Any, Literal, Protocol from pydantic import BaseModel, ConfigDict, Field +from tenacity import Retrying, stop_after_attempt, wait_exponential from sentinel_api.async_priority import Priority, PrioritySignals, classify_priority, sla_due_at @@ -194,30 +195,44 @@ def fetch_signals( self._circuit_open_until = None retry_delays: list[int] = [] - last_error: str | None = None - - for attempt in range(1, self.max_attempts + 1): - try: - signals = self.connector.fetch_signals(since=since, limit=limit) - self._consecutive_failures = 0 - self._circuit_open_until = None - return ConnectorFetchOutcome( - status="ok", - connector_name=self.connector.name, - signals=signals, - attempts=attempt, - retry_delays_seconds=retry_delays, - ) - except Exception as exc: - last_error = str(exc) - if attempt < self.max_attempts: - delay_seconds = _retry_delay_seconds( - attempt=attempt, - base=self.base_backoff_seconds, - cap=self.max_backoff_seconds, + attempts = 0 + + def _sleep(seconds: float) -> None: + self._sleep_fn(int(seconds)) + + def _before_sleep(retry_state) -> None: # type: ignore[no-untyped-def] + next_action = getattr(retry_state, "next_action", None) + delay = getattr(next_action, "sleep", None) + if delay is None: + return + retry_delays.append(int(delay)) + + try: + for attempt in Retrying( + stop=stop_after_attempt(self.max_attempts), + wait=wait_exponential( + multiplier=self.base_backoff_seconds, + max=self.max_backoff_seconds, + ), + reraise=True, + sleep=_sleep, + before_sleep=_before_sleep, + ): + with attempt: + attempts = attempt.retry_state.attempt_number + signals = self.connector.fetch_signals(since=since, limit=limit) + self._consecutive_failures = 0 + self._circuit_open_until = None + return ConnectorFetchOutcome( + status="ok", + connector_name=self.connector.name, + signals=signals, + attempts=attempts, + retry_delays_seconds=retry_delays, ) - retry_delays.append(delay_seconds) - self._sleep_fn(delay_seconds) + except Exception as exc: + attempts = max(attempts, self.max_attempts) + last_error = str(exc) self._consecutive_failures += 1 if self._consecutive_failures >= self.circuit_failure_threshold: @@ -228,7 +243,7 @@ def fetch_signals( return ConnectorFetchOutcome( status="error", connector_name=self.connector.name, - attempts=self.max_attempts, + attempts=attempts, retry_delays_seconds=retry_delays, error=last_error, ) diff --git a/src/sentinel_api/policy.py b/src/sentinel_api/policy.py index 9aacd8c..d73ddf0 100644 --- a/src/sentinel_api/policy.py +++ b/src/sentinel_api/policy.py @@ -1,27 +1,43 @@ from __future__ import annotations +import os import time from dataclasses import dataclass from typing import cast, get_args +from sentinel_api.logging import get_logger from sentinel_api.model_artifact_repository import resolve_runtime_model_version -from sentinel_api.model_registry import score_claim_with_fallback +from sentinel_api.model_registry import ( + DEFAULT_MODEL_TIMEOUT_MS, + get_model_runtime, + score_claim_with_fallback, +) from sentinel_core.claim_likeness import contains_election_anchor from sentinel_core.model_runtime import ClaimBand -from sentinel_core.models import Action, EvidenceItem, Label, LanguageSpan, ModerationResponse +from sentinel_core.models import ( + Action, + EvidenceItem, + Label, + LanguageSpan, + ModerationContext, + ModerationResponse, +) from sentinel_core.policy_config import ( DeploymentStage, EffectivePolicyRuntime, get_policy_config, resolve_policy_runtime, ) +from sentinel_langpack import get_wave1_pack_matchers from sentinel_langpack.registry import resolve_pack_versions from sentinel_lexicon.hot_triggers import find_hot_trigger_matches from sentinel_lexicon.lexicon import get_lexicon_matcher from sentinel_lexicon.lexicon_repository import LexiconEntry -from sentinel_lexicon.vector_matcher import find_vector_match +from sentinel_lexicon.vector_matcher import DEFAULT_VECTOR_MATCH_THRESHOLD, find_vector_match from sentinel_router.language_router import detect_language_spans +logger = get_logger("sentinel.policy") + @dataclass class Decision: @@ -74,6 +90,38 @@ def _apply_deployment_stage( return decision +def _derive_toxicity( + decision: Decision, + *, + runtime: EffectivePolicyRuntime, +) -> float: + model_scores = [ + item.confidence + for item in decision.evidence + if item.type == "model_span" and item.confidence is not None + ] + base = getattr(runtime.toxicity_by_action, decision.action) + if not model_scores: + return base + model_score = max(model_scores) + blended = 0.6 * base + 0.4 * model_score + return round(min(1.0, max(0.0, blended)), 4) + + +def _finalize_decision(decision: Decision, *, runtime: EffectivePolicyRuntime) -> Decision: + staged = _apply_deployment_stage(decision, runtime=runtime) + toxicity = _derive_toxicity(staged, runtime=runtime) + if toxicity == staged.toxicity: + return staged + return Decision( + action=staged.action, + labels=staged.labels, + reason_codes=staged.reason_codes, + evidence=staged.evidence, + toxicity=toxicity, + ) + + def _band_from_score(score: float, *, medium_threshold: float, high_threshold: float) -> ClaimBand: if score >= high_threshold: return "high" @@ -111,7 +159,55 @@ def _deduplicate_entries(entries: list[LexiconEntry]) -> list[LexiconEntry]: return deduped -def evaluate_text(text: str, matcher=None, config=None, runtime=None) -> Decision: +def _context_threshold_adjustment( + context: ModerationContext | None, + *, + runtime: EffectivePolicyRuntime, +) -> float: + del runtime + if context is None: + return 0.0 + channel = (context.channel or "").strip().lower() + if channel == "forward": + return -0.04 + if channel == "broadcast": + return 0.02 + return 0.0 + + +def _resolved_vector_match_threshold(runtime: EffectivePolicyRuntime) -> float: + if runtime.vector_match_threshold is not None: + return runtime.vector_match_threshold + raw = os.getenv("SENTINEL_VECTOR_MATCH_THRESHOLD") + if raw is None: + return DEFAULT_VECTOR_MATCH_THRESHOLD + try: + value = float(raw) + except ValueError: + return DEFAULT_VECTOR_MATCH_THRESHOLD + if value < 0 or value > 1: + return DEFAULT_VECTOR_MATCH_THRESHOLD + return value + + +def _vector_matching_configured() -> bool: + database_url = os.getenv("SENTINEL_DATABASE_URL", "").strip() + if not database_url: + return False + raw = os.getenv("SENTINEL_VECTOR_MATCH_ENABLED") + if raw is None: + return True + return raw.strip().lower() not in {"0", "false", "no", "off"} + + +def evaluate_text( + text: str, + matcher=None, + config=None, + runtime=None, + *, + context: ModerationContext | None = None, +) -> Decision: runtime = runtime or resolve_policy_runtime(config=config) config = runtime.config matcher = matcher or get_lexicon_matcher() @@ -153,7 +249,7 @@ def evaluate_text(text: str, matcher=None, config=None, runtime=None) -> Decisio evidence=evidence, toxicity=runtime.toxicity_by_action.BLOCK, ) - return _apply_deployment_stage(decision, runtime=runtime) + return _finalize_decision(decision, runtime=runtime) review_matches = [entry for entry in matches if entry.action == "REVIEW"] @@ -177,13 +273,58 @@ def evaluate_text(text: str, matcher=None, config=None, runtime=None) -> Decisio evidence=evidence, toxicity=runtime.toxicity_by_action.REVIEW, ) - return _apply_deployment_stage(decision, runtime=runtime) + return _finalize_decision(decision, runtime=runtime) - vector_match = find_vector_match( - text, - lexicon_version=matcher.version, - min_similarity=runtime.vector_match_threshold, - ) + pack_matchers = get_wave1_pack_matchers() + has_pack_matches = False + for pack_matcher in pack_matchers: + for entry in pack_matcher.match(text): + has_pack_matches = True + labels.append(_as_label(entry.label)) + reason_codes.append(entry.reason_code) + evidence.append( + EvidenceItem( + type="lexicon", + match=entry.term, + severity=entry.severity, + lang=pack_matcher.language, + ) + ) + + if has_pack_matches: + decision = Decision( + action="REVIEW", + labels=sorted(set(labels)), + reason_codes=sorted(set(reason_codes)), + evidence=evidence, + toxicity=runtime.toxicity_by_action.REVIEW, + ) + return _finalize_decision(decision, runtime=runtime) + + threshold_delta = _context_threshold_adjustment(context, runtime=runtime) + if threshold_delta: + logger.debug( + "context_threshold_adjustment", + channel=context.channel if context else None, + delta=threshold_delta, + ) + vector_threshold = _resolved_vector_match_threshold(runtime) + vector_threshold = max(0.0, min(1.0, vector_threshold + threshold_delta)) + + vector_match = None + if _vector_matching_configured(): + model_runtime = get_model_runtime() + embedding_provider = model_runtime.embedding_provider + embedding_model = model_runtime.embedding_provider_id + query_embedding = embedding_provider.embed(text, timeout_ms=DEFAULT_MODEL_TIMEOUT_MS) + if query_embedding is not None: + vector_match = find_vector_match( + text, + lexicon_version=matcher.version, + query_embedding=query_embedding, + embedding_model=embedding_model, + min_similarity=vector_threshold, + ) if vector_match is not None: entry = vector_match.entry # Safety posture: semantic/vector evidence is advisory and cannot directly @@ -206,13 +347,21 @@ def evaluate_text(text: str, matcher=None, config=None, runtime=None) -> Decisio ], toxicity=toxicity, ) - return _apply_deployment_stage(decision, runtime=runtime) + return _finalize_decision(decision, runtime=runtime) claim_score = score_claim_with_fallback(text) if claim_score is not None: claim_score_value, _ = claim_score else: claim_score_value = 0.0 + if context is not None and (context.source or "").strip().lower() == "partner_factcheck": + adjusted = min(claim_score_value * 1.10, 1.0) + logger.debug( + "context_partner_factcheck_claim_multiplier", + base=claim_score_value, + adjusted=adjusted, + ) + claim_score_value = adjusted claim_band = _band_from_score( claim_score_value, medium_threshold=runtime.claim_likeness.medium_threshold, @@ -240,7 +389,7 @@ def evaluate_text(text: str, matcher=None, config=None, runtime=None) -> Decisio ], toxicity=runtime.toxicity_by_action.REVIEW, ) - return _apply_deployment_stage(decision, runtime=runtime) + return _finalize_decision(decision, runtime=runtime) if runtime.no_match_action == "REVIEW": decision = Decision( @@ -256,7 +405,7 @@ def evaluate_text(text: str, matcher=None, config=None, runtime=None) -> Decisio ], toxicity=runtime.toxicity_by_action.REVIEW, ) - return _apply_deployment_stage(decision, runtime=runtime) + return _finalize_decision(decision, runtime=runtime) decision = Decision( action="ALLOW", @@ -271,15 +420,20 @@ def evaluate_text(text: str, matcher=None, config=None, runtime=None) -> Decisio ], toxicity=runtime.toxicity_by_action.ALLOW, ) - return _apply_deployment_stage(decision, runtime=runtime) + return _finalize_decision(decision, runtime=runtime) -def moderate(text: str, *, runtime: EffectivePolicyRuntime | None = None) -> ModerationResponse: +def moderate( + text: str, + *, + context: ModerationContext | None = None, + runtime: EffectivePolicyRuntime | None = None, +) -> ModerationResponse: start = time.perf_counter() runtime = runtime or resolve_policy_runtime() config = runtime.config matcher = get_lexicon_matcher() - decision = evaluate_text(text, matcher=matcher, config=config, runtime=runtime) + decision = evaluate_text(text, matcher=matcher, config=config, runtime=runtime, context=context) latency_ms = int((time.perf_counter() - start) * 1000) pack_versions = resolve_pack_versions(config.pack_versions) effective_model_version = resolve_runtime_model_version(config.model_version) diff --git a/src/sentinel_api/policy_config.py b/src/sentinel_api/policy_config.py index c00c5de..cb64318 100644 --- a/src/sentinel_api/policy_config.py +++ b/src/sentinel_api/policy_config.py @@ -14,6 +14,8 @@ ReasonCode, ToxicityByAction, get_policy_config, + get_runtime_phase_override, reset_policy_config_cache, resolve_policy_runtime, + set_runtime_phase_override, ) diff --git a/src/sentinel_api/rate_limit.py b/src/sentinel_api/rate_limit.py index c284100..54847bb 100644 --- a/src/sentinel_api/rate_limit.py +++ b/src/sentinel_api/rate_limit.py @@ -47,8 +47,9 @@ def _cleanup(self, bucket: deque[float], now: float) -> None: while bucket and now - bucket[0] > self.window_seconds: bucket.popleft() - def check(self, key: str) -> RateLimitDecision: + def check(self, key: str, *, cost: int = 1) -> RateLimitDecision: now = time.time() + normalized_cost = max(1, int(cost)) bucket_key = _rate_limit_bucket_key(key) bucket = self._events[bucket_key] self._cleanup(bucket, now) @@ -58,7 +59,7 @@ def check(self, key: str) -> RateLimitDecision: else: reset_after = max(1, int(self.window_seconds - (now - bucket[0]))) - if len(bucket) >= self.per_minute: + if len(bucket) + normalized_cost > self.per_minute: return RateLimitDecision( allowed=False, limit=self.per_minute, @@ -67,7 +68,8 @@ def check(self, key: str) -> RateLimitDecision: retry_after_seconds=reset_after, ) - bucket.append(now) + for _ in range(normalized_cost): + bucket.append(now) remaining = max(self.per_minute - len(bucket), 0) reset_after = max(1, int(self.window_seconds - (now - bucket[0]))) return RateLimitDecision( @@ -77,8 +79,8 @@ def check(self, key: str) -> RateLimitDecision: reset_after_seconds=reset_after, ) - def allow(self, key: str) -> bool: - return self.check(key).allowed + def allow(self, key: str, *, cost: int = 1) -> bool: + return self.check(key, cost=cost).allowed def reset(self) -> None: self._events.clear() @@ -108,21 +110,27 @@ def __init__(self, per_minute: int, storage_uri: str) -> None: self._limiter = MovingWindowRateLimiter(self._storage) self._rate_limit_item_cls: Any = RateLimitItemPerMinute - def check(self, key: str) -> RateLimitDecision: + def check(self, key: str, *, cost: int = 1) -> RateLimitDecision: # Preserve existing response contract while shifting enforcement to # distributed limits storage (Redis/memcached/etc.). now = time.time() + normalized_cost = max(1, int(cost)) normalized_key = f"{_RATE_LIMIT_KEY_PREFIX}{_rate_limit_bucket_key(key)}" item = self._rate_limit_item_cls(self.per_minute) try: - allowed = bool(self._limiter.hit(item, normalized_key)) + try: + allowed = bool(self._limiter.hit(item, normalized_key, cost=normalized_cost)) + except TypeError: # pragma: no cover - older limits versions + allowed = True + for _ in range(normalized_cost): + allowed = allowed and bool(self._limiter.hit(item, normalized_key)) window = self._limiter.get_window_stats(item, normalized_key) except Exception as exc: # pragma: no cover - network/storage failures logger.warning( "distributed rate limiter unavailable; using in-memory fallback: %s", exc, ) - return super().check(key) + return super().check(key, cost=normalized_cost) reset_after = max(1, int(window.reset_time - now)) remaining = max(0, int(window.remaining)) diff --git a/src/sentinel_api/result_cache.py b/src/sentinel_api/result_cache.py new file mode 100644 index 0000000..78e5929 --- /dev/null +++ b/src/sentinel_api/result_cache.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +import hashlib +import json +import logging +from typing import Any + +from sentinel_core.models import ModerationContext, ModerationResponse + +logger = logging.getLogger(__name__) + +CACHE_KEY_PREFIX = "sentinel:result:" + + +def make_cache_key( + text: str, + *, + policy_version: str, + lexicon_version: str, + model_version: str, + pack_versions: dict[str, str], + deployment_stage: str, + context: ModerationContext | None, +) -> str: + context_payload: dict[str, Any] + if context is None: + context_payload = {} + else: + context_payload = context.model_dump() + canonical = { + "text": text, + "policy_version": policy_version, + "lexicon_version": lexicon_version, + "model_version": model_version, + "pack_versions": dict(pack_versions), + "deployment_stage": deployment_stage, + "context": context_payload, + } + digest = hashlib.sha256( + json.dumps(canonical, sort_keys=True, ensure_ascii=True).encode("utf-8") + ).hexdigest() + return f"{CACHE_KEY_PREFIX}{digest}" + + +def get_cached_result(cache_key: str, redis_url: str) -> ModerationResponse | None: + try: + import redis + + client = redis.Redis.from_url(redis_url, decode_responses=True) + cached = client.get(cache_key) + if not cached: + return None + return ModerationResponse.model_validate_json(cached) + except Exception as exc: + logger.debug("result cache read failed: %s", exc) + return None + + +def set_cached_result( + cache_key: str, + result: ModerationResponse, + redis_url: str, + *, + ttl: int, +) -> None: + normalized_ttl = max(1, int(ttl)) + try: + import redis + + client = redis.Redis.from_url(redis_url, decode_responses=True) + client.set(cache_key, result.model_dump_json(), ex=normalized_ttl) + except Exception as exc: + logger.debug("result cache write failed: %s", exc) diff --git a/src/sentinel_core/claim_likeness.py b/src/sentinel_core/claim_likeness.py index f1cb972..3a9c93e 100644 --- a/src/sentinel_core/claim_likeness.py +++ b/src/sentinel_core/claim_likeness.py @@ -26,6 +26,25 @@ "polling", "constituency", "constituencies", + "uchaguzi", + "kura", + "matokeo", + "urais", + "bunge", + "wabunge", + "ubunge", + "seneti", + "kaunti", + "ugombea", + "wagombea", + "mgombea", + "diwani", + "kiems", + "yiero", + "ombulu", + "ker", + "kuraiyat", + "raundi", } ASSERTIVE_CLAIM_TERMS = { "is", @@ -43,6 +62,14 @@ "fraud", "fraudulent", "fake", + "ni", + "iko", + "ilitokea", + "imefanywa", + "imeonekana", + "kweli", + "hakika", + "imethibitishwa", } DISINFO_NARRATIVE_TERMS = { "rigged", @@ -52,6 +79,13 @@ "fake", "fraud", "fraudulent", + "imeibwa", + "imeporwa", + "imeharibiwa", + "bandia", + "udanganyifu", + "wizi", + "kuchakachua", } HEDGING_TERMS = { "alleged", @@ -66,6 +100,17 @@ "could", "seems", "seem", + "inadaiwa", + "pengine", + "labda", + "inasemekana", + "habari", + "ripoti", + "tetesi", + "madai", + "dai", + "huenda", + "inasadikiwa", } diff --git a/src/sentinel_core/embedding_bakeoff.py b/src/sentinel_core/embedding_bakeoff.py index 5122c79..ac78f6e 100644 --- a/src/sentinel_core/embedding_bakeoff.py +++ b/src/sentinel_core/embedding_bakeoff.py @@ -7,6 +7,7 @@ import time import unicodedata from dataclasses import dataclass +from functools import lru_cache from pathlib import Path from typing import Any, cast @@ -146,7 +147,7 @@ def _build_candidates(*, enable_optional_models: bool) -> list[BakeoffCandidate] BakeoffCandidate( candidate_id="e5-multilingual-small", display_name="multilingual-e5-small", - embedding_dim=VECTOR_DIMENSION, + embedding_dim=384, is_baseline=False, is_substitute=False, unavailable_reason=optional_reason, @@ -154,7 +155,7 @@ def _build_candidates(*, enable_optional_models: bool) -> list[BakeoffCandidate] BakeoffCandidate( candidate_id="labse", display_name="LaBSE", - embedding_dim=VECTOR_DIMENSION, + embedding_dim=384, is_baseline=False, is_substitute=False, unavailable_reason=optional_reason, @@ -176,9 +177,37 @@ def _build_candidates(*, enable_optional_models: bool) -> list[BakeoffCandidate] ] +@lru_cache(maxsize=1) +def _load_e5_small_model(): + from sentence_transformers import SentenceTransformer + + return SentenceTransformer("intfloat/multilingual-e5-small") + + +@lru_cache(maxsize=1) +def _load_labse_model(): + from sentence_transformers import SentenceTransformer + + return SentenceTransformer("sentence-transformers/LaBSE") + + +def _embed_e5_small(text: str) -> list[float]: + embedding = _load_e5_small_model().encode(f"query: {text}", normalize_embeddings=True) + return embedding.tolist() if hasattr(embedding, "tolist") else list(embedding) + + +def _embed_labse(text: str) -> list[float]: + embedding = _load_labse_model().encode(text, normalize_embeddings=True) + return embedding.tolist() if hasattr(embedding, "tolist") else list(embedding) + + def _embed(candidate_id: str, text: str) -> list[float]: if candidate_id == "hash-bow-v1": return embed_hash_bow_v1(text) + if candidate_id == "e5-multilingual-small": + return _embed_e5_small(text) + if candidate_id == "labse": + return _embed_labse(text) if candidate_id == "hash-token-v1": return _embed_hash_token_v1(text) if candidate_id == "hash-chargram-v1": diff --git a/src/sentinel_core/models.py b/src/sentinel_core/models.py index d5fd3d6..4c71432 100644 --- a/src/sentinel_core/models.py +++ b/src/sentinel_core/models.py @@ -34,6 +34,37 @@ class ModerationRequest(BaseModel): request_id: str | None = Field(default=None, max_length=128) +class ModerationBatchItem(BaseModel): + model_config = ConfigDict(extra="forbid") + + text: str = Field(min_length=1, max_length=5000) + context: ModerationContext | None = None + request_id: str | None = Field(default=None, max_length=128) + + +class ModerationBatchRequest(BaseModel): + model_config = ConfigDict(extra="forbid") + + items: list[ModerationBatchItem] = Field(min_length=1, max_length=50) + + +class ModerationBatchItemResult(BaseModel): + model_config = ConfigDict(extra="forbid") + + request_id: str = Field(min_length=1, max_length=128) + result: ModerationResponse | None = None + error: ErrorResponse | None = None + + +class ModerationBatchResponse(BaseModel): + model_config = ConfigDict(extra="forbid") + + items: list[ModerationBatchItemResult] + total: int = Field(ge=0) + succeeded: int = Field(ge=0) + failed: int = Field(ge=0) + + class EvidenceItem(BaseModel): model_config = ConfigDict(extra="forbid") @@ -86,3 +117,24 @@ class MetricsResponse(BaseModel): http_status_counts: dict[str, int] latency_ms_buckets: dict[str, int] validation_error_count: int = Field(ge=0) + + +class PublicAppealCreateRequest(BaseModel): + model_config = ConfigDict(extra="forbid") + + decision_request_id: str = Field(min_length=1, max_length=128) + original_action: Action + original_reason_codes: list[ReasonCode] = Field(min_length=1) + original_model_version: str = Field(min_length=1, max_length=128) + original_lexicon_version: str = Field(min_length=1, max_length=128) + original_policy_version: str = Field(min_length=1, max_length=128) + original_pack_versions: dict[str, str] = Field(min_length=1) + reason: str | None = Field(default=None, max_length=500) + + +class PublicAppealCreateResponse(BaseModel): + model_config = ConfigDict(extra="forbid") + + appeal_id: int = Field(ge=1) + status: Literal["submitted"] = "submitted" + request_id: str = Field(min_length=1, max_length=128) diff --git a/src/sentinel_core/policy_config.py b/src/sentinel_core/policy_config.py index 9bda5e7..6680a08 100644 --- a/src/sentinel_core/policy_config.py +++ b/src/sentinel_core/policy_config.py @@ -2,6 +2,7 @@ import json import os +import threading from enum import StrEnum from functools import lru_cache from pathlib import Path @@ -112,6 +113,21 @@ def reset_policy_config_cache() -> None: get_policy_config.cache_clear() +_runtime_phase_override: ElectoralPhase | None = None +_runtime_phase_override_lock = threading.Lock() + + +def set_runtime_phase_override(phase: ElectoralPhase | None) -> None: + global _runtime_phase_override + with _runtime_phase_override_lock: + _runtime_phase_override = phase + + +def get_runtime_phase_override() -> ElectoralPhase | None: + with _runtime_phase_override_lock: + return _runtime_phase_override + + @lru_cache(maxsize=1) def get_policy_config() -> PolicyConfig: path = Path(os.getenv("SENTINEL_POLICY_CONFIG_PATH", str(_default_config_path()))) @@ -121,6 +137,9 @@ def get_policy_config() -> PolicyConfig: def _resolve_effective_phase(config: PolicyConfig) -> ElectoralPhase | None: + runtime_override = get_runtime_phase_override() + if runtime_override is not None: + return runtime_override env_phase = os.getenv("SENTINEL_ELECTORAL_PHASE") if env_phase is None or not env_phase.strip(): return config.electoral_phase diff --git a/src/sentinel_db/__init__.py b/src/sentinel_db/__init__.py new file mode 100644 index 0000000..9d48db4 --- /dev/null +++ b/src/sentinel_db/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/src/sentinel_db/pool.py b/src/sentinel_db/pool.py new file mode 100644 index 0000000..7f393e3 --- /dev/null +++ b/src/sentinel_db/pool.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import logging +from threading import Lock +from typing import TYPE_CHECKING + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: # pragma: no cover + from psycopg_pool import ConnectionPool +else: # pragma: no cover + ConnectionPool = object # type: ignore[assignment] + +_pool: ConnectionPool | None = None +_pool_lock = Lock() +_pool_import_failed: bool = False + + +def get_pool(database_url: str) -> ConnectionPool | None: + """Return a singleton psycopg connection pool for the given DB URL. + + Pooling is optional: if `psycopg_pool` is not available, this returns None. + """ + global _pool + if _pool is not None: + return _pool + + normalized_url = database_url.strip() + if not normalized_url: + return None + + with _pool_lock: + if _pool is not None: + return _pool + try: + from psycopg_pool import ConnectionPool as _ConnectionPool + except ImportError: + global _pool_import_failed + if not _pool_import_failed: + logger.warning("psycopg_pool is not installed; DB pooling disabled") + _pool_import_failed = True + _pool = None + return None + + _pool = _ConnectionPool( + conninfo=normalized_url, + min_size=2, + max_size=10, + open=True, + ) + return _pool + + +def peek_pool() -> ConnectionPool | None: + """Return the pool if already created, otherwise None.""" + with _pool_lock: + return _pool + + +def close_pool() -> None: + global _pool + with _pool_lock: + pool = _pool + _pool = None + if pool is None: + return + try: + pool.close() + except Exception as exc: # pragma: no cover - defensive shutdown + logger.warning("failed to close DB pool cleanly: %s", exc) diff --git a/src/sentinel_langpack/__init__.py b/src/sentinel_langpack/__init__.py index 40865e8..c9ecd9b 100644 --- a/src/sentinel_langpack/__init__.py +++ b/src/sentinel_langpack/__init__.py @@ -1 +1,12 @@ -"""Language-pack boundary for pack version resolution.""" +"""Language-pack boundary for pack version resolution and hot-path matchers.""" + +from __future__ import annotations + +from sentinel_langpack.hot_path import PackMatcher, get_wave1_pack_matchers +from sentinel_langpack.registry import resolve_pack_versions + +__all__ = [ + "PackMatcher", + "get_wave1_pack_matchers", + "resolve_pack_versions", +] diff --git a/src/sentinel_langpack/hot_path.py b/src/sentinel_langpack/hot_path.py new file mode 100644 index 0000000..fd0a196 --- /dev/null +++ b/src/sentinel_langpack/hot_path.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +import json +import re +import unicodedata +from dataclasses import dataclass +from functools import lru_cache +from pathlib import Path + +from sentinel_langpack.wave1 import ( + PackLexicon, + PackLexiconEntry, + PackNormalization, + Wave1PackManifest, + load_wave1_registry, + wave1_packs_in_priority_order, +) + +WORD_BOUNDARY_CHARS = r"[0-9A-Za-zÀ-ÖØ-öø-ÿ']" +TERM_TOKEN_PATTERN = re.compile(r"[0-9A-Za-zÀ-ÖØ-öø-ÿ']+") +DEFAULT_REGISTRY_PATH = Path("data/langpacks/registry.json") + + +def _resolve_registry_path(path: str | Path) -> Path: + candidate = Path(path) + if candidate.is_absolute(): + return candidate + return (Path.cwd() / candidate).resolve() + + +def _load_json(path: Path) -> dict[str, object]: + return json.loads(path.read_text(encoding="utf-8")) + + +def _resolve_pack_root(registry_path: Path, manifest: Wave1PackManifest) -> Path: + root = registry_path.parent + pack_dir = Path(manifest.directory) + if pack_dir.is_absolute(): + return pack_dir.resolve() + return (root / pack_dir).resolve() + + +def _normalize_text(text: str, replacements: dict[str, str]) -> str: + normalized = unicodedata.normalize("NFKC", text).replace("’", "'").lower() + for source, target in replacements.items(): + source_key = source.strip().lower() + if not source_key: + continue + normalized = normalized.replace(source_key, target.strip().lower()) + return normalized + + +def _compile_term_pattern(term: str) -> re.Pattern[str]: + normalized = unicodedata.normalize("NFKC", term).replace("’", "'").lower().strip() + if not normalized: + return re.compile(r"(?!x)x") + tokens = TERM_TOKEN_PATTERN.findall(normalized) + if not tokens: + return re.compile(re.escape(normalized)) + boundary_start = rf"(? list[PackLexiconEntry]: + normalized = _normalize_text(text, self.normalization) + matches: list[PackLexiconEntry] = [] + for entry, pattern in self.compiled_entries: + if pattern.search(normalized): + matches.append(entry) + return matches + + +def _build_matcher( + manifest: Wave1PackManifest, + *, + registry_path: Path, +) -> PackMatcher: + pack_root = _resolve_pack_root(registry_path, manifest) + normalization_payload = _load_json(pack_root / manifest.artifacts.normalization) + lexicon_payload = _load_json(pack_root / manifest.artifacts.lexicon) + normalization = PackNormalization.model_validate(normalization_payload) + lexicon = PackLexicon.model_validate(lexicon_payload) + compiled_entries = [(entry, _compile_term_pattern(entry.term)) for entry in lexicon.entries] + return PackMatcher( + language=manifest.language.strip().lower(), + pack_version=manifest.pack_version, + compiled_entries=compiled_entries, + normalization=dict(normalization.replacements), + ) + + +@lru_cache(maxsize=1) +def get_wave1_pack_matchers() -> list[PackMatcher]: + """Return compiled matchers for wave1 packs. + + Falls back to `[]` on missing or invalid registry/pack artifacts so the + request path never fails due to optional pack data. + """ + registry_path = _resolve_registry_path(DEFAULT_REGISTRY_PATH) + try: + registry = load_wave1_registry(registry_path) + manifests = wave1_packs_in_priority_order(registry) + return [_build_matcher(manifest, registry_path=registry_path) for manifest in manifests] + except (FileNotFoundError, OSError, ValueError, json.JSONDecodeError): + return [] diff --git a/src/sentinel_lexicon/vector_matcher.py b/src/sentinel_lexicon/vector_matcher.py index d153d4e..19a0575 100644 --- a/src/sentinel_lexicon/vector_matcher.py +++ b/src/sentinel_lexicon/vector_matcher.py @@ -16,6 +16,8 @@ VECTOR_DIMENSION = 64 VECTOR_MODEL = "hash-bow-v1" +E5_SMALL_MODEL = "e5-multilingual-small-v1" +E5_SMALL_DIMENSION = 384 DEFAULT_VECTOR_MATCH_THRESHOLD = 0.82 DEFAULT_STATEMENT_TIMEOUT_MS = 60 TOKEN_PATTERN = re.compile(r"[0-9A-Za-zÀ-ÖØ-öø-ÿ']+") @@ -82,6 +84,25 @@ def embed_text(text: str) -> list[float]: return [value / norm for value in vector] +@lru_cache(maxsize=1) +def _load_e5_small_model(): + from sentence_transformers import SentenceTransformer + + return SentenceTransformer("intfloat/multilingual-e5-small") + + +def _embed_passage(text: str, *, embedding_model: str) -> tuple[int, list[float]] | None: + if embedding_model == VECTOR_MODEL: + return VECTOR_DIMENSION, embed_text(text) + if embedding_model == E5_SMALL_MODEL: + model = _load_e5_small_model() + embedding = model.encode(f"passage: {text}", normalize_embeddings=True) + values = embedding.tolist() if hasattr(embedding, "tolist") else list(embedding) + return E5_SMALL_DIMENSION, values + logger.warning("unsupported embedding_model for lexicon embeddings: %s", embedding_model) + return None + + def _vector_literal(values: list[float]) -> str: return "[" + ",".join(f"{value:.8f}" for value in values) + "]" @@ -130,22 +151,37 @@ def _get_psycopg_module(): return importlib.import_module("psycopg") +def _maybe_get_pool(database_url: str): + try: + from sentinel_db.pool import peek_pool # type: ignore[import-not-found] + except Exception: + return None + del database_url + return peek_pool() + + def _apply_statement_timeout(cur) -> None: timeout_ms = _vector_statement_timeout_ms() cur.execute(f"SET LOCAL statement_timeout = '{timeout_ms}ms'") -@lru_cache(maxsize=16) -def _ensure_embeddings_for_version(database_url: str, lexicon_version: str) -> None: +@lru_cache(maxsize=64) +def _ensure_embeddings_for_version( + database_url: str, + lexicon_version: str, + embedding_model: str, +) -> None: psycopg = _get_psycopg_module() - with psycopg.connect(database_url) as conn: + pool = _maybe_get_pool(database_url) + conn_ctx = pool.connection() if pool is not None else psycopg.connect(database_url) + with conn_ctx as conn: with conn.cursor() as cur: _apply_statement_timeout(cur) cur.execute( """ SELECT le.id, le.term FROM lexicon_entries AS le - LEFT JOIN lexicon_entry_embeddings AS emb + LEFT JOIN lexicon_entry_embeddings_v2 AS emb ON emb.lexicon_entry_id = le.id AND emb.embedding_model = %s WHERE le.status = 'active' @@ -153,27 +189,31 @@ def _ensure_embeddings_for_version(database_url: str, lexicon_version: str) -> N AND emb.lexicon_entry_id IS NULL ORDER BY le.id ASC """, - (VECTOR_MODEL, lexicon_version), + (embedding_model, lexicon_version), ) rows = cur.fetchall() for row in rows: lexicon_entry_id = int(row[0]) term = str(row[1]) - embedding_literal = _vector_literal(embed_text(term)) + embedded = _embed_passage(term, embedding_model=embedding_model) + if embedded is None: + continue + embedding_dim, embedding_values = embedded + embedding_literal = _vector_literal(embedding_values) cur.execute( """ - INSERT INTO lexicon_entry_embeddings - (lexicon_entry_id, embedding, embedding_model, updated_at) + INSERT INTO lexicon_entry_embeddings_v2 + (lexicon_entry_id, embedding, embedding_model, embedding_dim, updated_at) VALUES - (%s, %s::vector, %s, NOW()) - ON CONFLICT (lexicon_entry_id) + (%s, %s::vector, %s, %s, NOW()) + ON CONFLICT (lexicon_entry_id, embedding_model) DO UPDATE SET embedding = EXCLUDED.embedding, - embedding_model = EXCLUDED.embedding_model, + embedding_dim = EXCLUDED.embedding_dim, updated_at = NOW() """, - (lexicon_entry_id, embedding_literal, VECTOR_MODEL), + (lexicon_entry_id, embedding_literal, embedding_model, embedding_dim), ) conn.commit() @@ -186,6 +226,8 @@ def find_vector_match( text: str, *, lexicon_version: str, + query_embedding: list[float], + embedding_model: str, min_similarity: float | None = None, ) -> VectorMatch | None: if not _vector_matching_enabled(): @@ -196,15 +238,15 @@ def find_vector_match( return None try: - _ensure_embeddings_for_version(database_url, lexicon_version) + _ensure_embeddings_for_version(database_url, lexicon_version, embedding_model) except Exception as exc: logger.warning("vector embedding sync failed; falling back: %s", exc) return None - query_vector = embed_text(text) - if not any(query_vector): + del text + if not any(query_embedding): return None - query_vector_literal = _vector_literal(query_vector) + query_vector_literal = _vector_literal(query_embedding) threshold = _vector_match_threshold() if min_similarity is not None: @@ -218,7 +260,9 @@ def find_vector_match( psycopg = _get_psycopg_module() try: - with psycopg.connect(database_url) as conn: + pool = _maybe_get_pool(database_url) + conn_ctx = pool.connection() if pool is not None else psycopg.connect(database_url) + with conn_ctx as conn: with conn.cursor() as cur: _apply_statement_timeout(cur) cur.execute( @@ -233,7 +277,7 @@ def find_vector_match( le.lang, (1 - (emb.embedding <=> %s::vector))::float8 AS similarity FROM lexicon_entries AS le - JOIN lexicon_entry_embeddings AS emb + JOIN lexicon_entry_embeddings_v2 AS emb ON emb.lexicon_entry_id = le.id WHERE le.status = 'active' AND le.lexicon_version = %s @@ -245,7 +289,7 @@ def find_vector_match( ( query_vector_literal, lexicon_version, - VECTOR_MODEL, + embedding_model, query_vector_literal, ), ) diff --git a/tests/test_admin_phase.py b/tests/test_admin_phase.py new file mode 100644 index 0000000..3cf53a9 --- /dev/null +++ b/tests/test_admin_phase.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import json + +import pytest +from fastapi.testclient import TestClient + +from sentinel_api.main import app +from sentinel_core.policy_config import ( + ElectoralPhase, + get_policy_config, + resolve_policy_runtime, + set_runtime_phase_override, +) + +client = TestClient(app) + + +@pytest.fixture(autouse=True) +def reset_runtime_state(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv("SENTINEL_ELECTORAL_PHASE", raising=False) + monkeypatch.delenv("SENTINEL_OAUTH_TOKENS_JSON", raising=False) + set_runtime_phase_override(None) + yield + set_runtime_phase_override(None) + + +def _set_registry(monkeypatch: pytest.MonkeyPatch, payload: dict[str, object]) -> None: + monkeypatch.setenv("SENTINEL_OAUTH_TOKENS_JSON", json.dumps(payload)) + + +def test_runtime_override_takes_priority_over_env(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SENTINEL_ELECTORAL_PHASE", "campaign") + set_runtime_phase_override(ElectoralPhase.VOTING_DAY) + runtime = resolve_policy_runtime() + assert runtime.effective_phase == ElectoralPhase.VOTING_DAY + + +def test_clearing_override_falls_back_to_env(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SENTINEL_ELECTORAL_PHASE", "campaign") + set_runtime_phase_override(ElectoralPhase.VOTING_DAY) + set_runtime_phase_override(None) + runtime = resolve_policy_runtime() + assert runtime.effective_phase == ElectoralPhase.CAMPAIGN + + +def test_runtime_uses_config_when_no_env_and_no_override() -> None: + config = get_policy_config() + runtime = resolve_policy_runtime() + assert runtime.effective_phase == config.electoral_phase + + +def test_admin_phase_endpoint_requires_scope(monkeypatch: pytest.MonkeyPatch) -> None: + _set_registry( + monkeypatch, + { + "token-admin-read-only": { + "client_id": "admin-reader", + "scopes": ["admin:proposal:read"], + } + }, + ) + response = client.post( + "/admin/policy/phase", + headers={"Authorization": "Bearer token-admin-read-only"}, + json={"phase": "voting_day"}, + ) + assert response.status_code == 403 + + +def test_admin_phase_endpoint_sets_override(monkeypatch: pytest.MonkeyPatch) -> None: + _set_registry( + monkeypatch, + { + "token-admin-writer": { + "client_id": "admin-writer", + "scopes": ["admin:policy:write"], + } + }, + ) + response = client.post( + "/admin/policy/phase", + headers={"Authorization": "Bearer token-admin-writer"}, + json={"phase": "voting_day"}, + ) + assert response.status_code == 200 + payload = response.json() + assert payload["effective_phase"] == "voting_day" diff --git a/tests/test_api.py b/tests/test_api.py index f563945..c36ebbf 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -5,6 +5,7 @@ import pytest from fastapi.testclient import TestClient +from sentinel_api.appeals import get_appeals_runtime, reset_appeals_runtime_state from sentinel_api.main import app, rate_limiter from sentinel_api.metrics import metrics from sentinel_api.model_registry import ClassifierShadowResult @@ -18,6 +19,7 @@ def reset_rate_limiter(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("SENTINEL_API_KEY", TEST_API_KEY) rate_limiter.reset() metrics.reset() + reset_appeals_runtime_state() def test_health() -> None: @@ -162,6 +164,128 @@ def test_rate_limit_exceeded() -> None: rate_limiter.per_minute = original +def test_batch_happy_path_two_items() -> None: + response = client.post( + "/v1/moderate/batch", + json={ + "items": [ + {"text": "We should discuss policy peacefully."}, + {"text": "This election is rigged."}, + ] + }, + headers={"X-API-Key": TEST_API_KEY}, + ) + assert response.status_code == 200 + payload = response.json() + assert payload["total"] == 2 + assert payload["succeeded"] == 2 + assert payload["failed"] == 0 + assert len(payload["items"]) == 2 + assert payload["items"][0]["result"] is not None + + +def test_batch_partial_failure(monkeypatch) -> None: + import sentinel_api.policy as policy + + def flaky(text: str, *, context=None, runtime=None): + if text == "boom": + raise RuntimeError("boom") + return policy.moderate(text, context=context, runtime=runtime) + + monkeypatch.setattr("sentinel_api.main.moderate", flaky) + + response = client.post( + "/v1/moderate/batch", + json={"items": [{"text": "boom"}, {"text": "We should discuss policy peacefully."}]}, + headers={"X-API-Key": TEST_API_KEY}, + ) + assert response.status_code == 200 + payload = response.json() + assert payload["total"] == 2 + assert payload["succeeded"] == 1 + assert payload["failed"] == 1 + assert payload["items"][0]["result"] is None + assert payload["items"][0]["error"]["error_code"] == "HTTP_500" + + +def test_batch_oversized_returns_validation_error() -> None: + response = client.post( + "/v1/moderate/batch", + json={"items": [{"text": "hello"} for _ in range(51)]}, + headers={"X-API-Key": TEST_API_KEY}, + ) + assert response.status_code == 400 + + +def test_batch_rate_limit_429(monkeypatch) -> None: + original = rate_limiter.per_minute + rate_limiter.per_minute = 1 + try: + monkeypatch.setattr( + "sentinel_api.main.moderate", + lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("should not run")), + ) + response = client.post( + "/v1/moderate/batch", + json={"items": [{"text": "a"}, {"text": "b"}]}, + headers={"X-API-Key": TEST_API_KEY}, + ) + assert response.status_code == 429 + assert response.headers["X-RateLimit-Limit"] == "1" + finally: + rate_limiter.per_minute = original + + +def test_batch_unauthenticated_401() -> None: + response = client.post( + "/v1/moderate/batch", + json={"items": [{"text": "hello"}]}, + ) + assert response.status_code == 401 + + +def test_moderate_uses_embedding_provider_via_env(monkeypatch) -> None: + monkeypatch.setenv("SENTINEL_EMBEDDING_PROVIDER", "e5-multilingual-small-v1") + + captured: dict[str, object] = {} + + class _Runtime: + embedding_provider_id = "e5-multilingual-small-v1" + + class _Provider: + def embed(self, _text: str, *, timeout_ms: int): # type: ignore[no-untyped-def] + del timeout_ms + return [0.0] * 384 + + embedding_provider = _Provider() + + def _fake_find_vector_match( + _text: str, + *, + lexicon_version: str, + query_embedding: list[float], + embedding_model: str, + min_similarity=None, + ): + del lexicon_version, min_similarity + captured["embedding_model"] = embedding_model + captured["embedding_dim"] = len(query_embedding) + return None + + monkeypatch.setattr("sentinel_api.policy._vector_matching_configured", lambda: True) + monkeypatch.setattr("sentinel_api.policy.get_model_runtime", lambda: _Runtime()) + monkeypatch.setattr("sentinel_api.policy.find_vector_match", _fake_find_vector_match) + + response = client.post( + "/v1/moderate", + json={"text": "peaceful civic dialogue"}, + headers={"X-API-Key": TEST_API_KEY}, + ) + assert response.status_code == 200 + assert captured["embedding_model"] == "e5-multilingual-small-v1" + assert captured["embedding_dim"] == 384 + + def test_moderate_internal_error_returns_structured_500(monkeypatch) -> None: def broken(_text: str, *, runtime=None): del runtime @@ -200,6 +324,87 @@ def _unexpected_shadow_call(text: str): assert response.status_code == 200 +def test_public_appeal_happy_path() -> None: + decision_request_id = "req-abc123" + response = client.post( + "/v1/appeals", + json={ + "decision_request_id": decision_request_id, + "original_action": "REVIEW", + "original_reason_codes": ["R_DOGWHISTLE_CONTEXT_REQUIRED"], + "original_model_version": "sentinel-multi-v2", + "original_lexicon_version": "hatelex-v2.1", + "original_policy_version": "policy-2026.11", + "original_pack_versions": {"en": "pack-en-0.1"}, + "reason": "false positive", + }, + headers={"X-API-Key": TEST_API_KEY}, + ) + assert response.status_code == 201 + payload = response.json() + assert payload["status"] == "submitted" + assert payload["request_id"] == decision_request_id + assert isinstance(payload["appeal_id"], int) + + runtime = get_appeals_runtime() + listed = runtime.list_appeals(status="submitted", request_id=decision_request_id, limit=10) + assert listed.items + appeal = listed.items[0] + assert appeal.submitted_by == "public-api" + assert appeal.request_id == decision_request_id + assert appeal.original_decision_id == decision_request_id + + +def test_public_appeal_missing_provenance_validation_error() -> None: + response = client.post( + "/v1/appeals", + json={ + "decision_request_id": "req-abc123", + "original_action": "REVIEW", + "original_reason_codes": ["R_DOGWHISTLE_CONTEXT_REQUIRED"], + "original_lexicon_version": "hatelex-v2.1", + "original_policy_version": "policy-2026.11", + "original_pack_versions": {"en": "pack-en-0.1"}, + }, + headers={"X-API-Key": TEST_API_KEY}, + ) + assert response.status_code == 400 + assert response.json()["error_code"] == "HTTP_400" + + +def test_public_appeal_empty_original_model_version_validation_error() -> None: + response = client.post( + "/v1/appeals", + json={ + "decision_request_id": "req-abc123", + "original_action": "REVIEW", + "original_reason_codes": ["R_DOGWHISTLE_CONTEXT_REQUIRED"], + "original_model_version": "", + "original_lexicon_version": "hatelex-v2.1", + "original_policy_version": "policy-2026.11", + "original_pack_versions": {"en": "pack-en-0.1"}, + }, + headers={"X-API-Key": TEST_API_KEY}, + ) + assert response.status_code == 400 + + +def test_public_appeal_unauthenticated_401() -> None: + response = client.post( + "/v1/appeals", + json={ + "decision_request_id": "req-abc123", + "original_action": "REVIEW", + "original_reason_codes": ["R_DOGWHISTLE_CONTEXT_REQUIRED"], + "original_model_version": "sentinel-multi-v2", + "original_lexicon_version": "hatelex-v2.1", + "original_policy_version": "policy-2026.11", + "original_pack_versions": {"en": "pack-en-0.1"}, + }, + ) + assert response.status_code == 401 + + def test_classifier_shadow_records_metrics_and_persistence(monkeypatch, tmp_path) -> None: shadow_path = tmp_path / "shadow_predictions.jsonl" monkeypatch.setenv("SENTINEL_DEPLOYMENT_STAGE", "advisory") diff --git a/tests/test_appeals_postgres_integration.py b/tests/test_appeals_postgres_integration.py index f87e48e..4d49540 100644 --- a/tests/test_appeals_postgres_integration.py +++ b/tests/test_appeals_postgres_integration.py @@ -1,6 +1,9 @@ from __future__ import annotations +import importlib +import json import os +from pathlib import Path from uuid import uuid4 import pytest @@ -90,3 +93,159 @@ def test_postgres_appeal_flow_round_trip(monkeypatch: pytest.MonkeyPatch) -> Non "in_review", "resolved_modified", ] + + +@pytest.mark.skipif( + not _integration_db_url(), + reason="SENTINEL_INTEGRATION_DB_URL is required for postgres integration tests", +) +def test_reversed_appeal_creates_draft_proposal(monkeypatch: pytest.MonkeyPatch) -> None: + db_url = _integration_db_url() + assert db_url is not None + monkeypatch.setenv("SENTINEL_DATABASE_URL", db_url) + runtime = get_appeals_runtime() + + suffix = uuid4().hex[:10] + created = runtime.create_appeal( + AdminAppealCreateRequest( + original_decision_id=f"decision-{suffix}", + request_id=f"request-{suffix}", + original_action="REVIEW", + original_reason_codes=["R_DISINFO_NARRATIVE_SIMILARITY"], + original_model_version="sentinel-multi-v2", + original_lexicon_version="hatelex-v2.1", + original_policy_version="policy-2026.11", + original_pack_versions={"en": "pack-en-0.1"}, + rationale="proposal integration appeal", + ), + submitted_by="integration-suite", + ) + resolved = runtime.transition_appeal( + appeal_id=created.id, + payload=AdminAppealTransitionRequest( + to_status="resolved_reversed", + rationale="reversed", + resolution_code="APPEAL_REVERSED", + resolution_reason_codes=["R_ALLOW_NO_POLICY_MATCH"], + ), + actor="integration-reviewer", + ) + assert resolved.status == "resolved_reversed" + + psycopg = importlib.import_module("psycopg") + with psycopg.connect(db_url) as conn: + with conn.cursor() as cur: + cur.execute( + """ + SELECT title, status, proposal_type, evidence + FROM release_proposals + WHERE proposed_by = %s + ORDER BY id DESC + LIMIT 5 + """, + ("integration-reviewer",), + ) + rows = cur.fetchall() + assert any( + row[1] == "draft" + and row[2] == "lexicon" + and str(resolved.id) in str(row[0]) + and str(resolved.resolution_code) in str(row[0]) + for row in rows + ) + + +@pytest.mark.skipif( + not _integration_db_url(), + reason="SENTINEL_INTEGRATION_DB_URL is required for postgres integration tests", +) +def test_upheld_appeal_no_proposal(monkeypatch: pytest.MonkeyPatch) -> None: + db_url = _integration_db_url() + assert db_url is not None + monkeypatch.setenv("SENTINEL_DATABASE_URL", db_url) + runtime = get_appeals_runtime() + + suffix = uuid4().hex[:10] + created = runtime.create_appeal( + AdminAppealCreateRequest( + original_decision_id=f"decision-{suffix}", + request_id=f"request-{suffix}", + original_action="REVIEW", + original_reason_codes=["R_DISINFO_NARRATIVE_SIMILARITY"], + original_model_version="sentinel-multi-v2", + original_lexicon_version="hatelex-v2.1", + original_policy_version="policy-2026.11", + original_pack_versions={"en": "pack-en-0.1"}, + rationale="upheld integration appeal", + ), + submitted_by="integration-suite", + ) + resolved = runtime.transition_appeal( + appeal_id=created.id, + payload=AdminAppealTransitionRequest( + to_status="resolved_upheld", + rationale="upheld", + resolution_code="APPEAL_UPHELD", + ), + actor="integration-reviewer", + ) + assert resolved.status == "resolved_upheld" + + psycopg = importlib.import_module("psycopg") + with psycopg.connect(db_url) as conn: + with conn.cursor() as cur: + cur.execute( + "SELECT COUNT(1) FROM release_proposals WHERE title LIKE %s", + (f"%appeal #{created.id}:%",), + ) + row = cur.fetchone() + assert row is not None + assert int(row[0]) == 0 + + +@pytest.mark.skipif( + not _integration_db_url(), + reason="SENTINEL_INTEGRATION_DB_URL is required for postgres integration tests", +) +def test_training_sample_written_on_reversal( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + db_url = _integration_db_url() + assert db_url is not None + monkeypatch.setenv("SENTINEL_DATABASE_URL", db_url) + out_path = tmp_path / "train.jsonl" + monkeypatch.setenv("SENTINEL_TRAINING_DATA_PATH", str(out_path)) + runtime = get_appeals_runtime() + + suffix = uuid4().hex[:10] + created = runtime.create_appeal( + AdminAppealCreateRequest( + original_decision_id=f"decision-{suffix}", + request_id=f"request-{suffix}", + original_action="REVIEW", + original_reason_codes=["R_DISINFO_NARRATIVE_SIMILARITY"], + original_model_version="sentinel-multi-v2", + original_lexicon_version="hatelex-v2.1", + original_policy_version="policy-2026.11", + original_pack_versions={"en": "pack-en-0.1"}, + rationale="training integration appeal", + ), + submitted_by="integration-suite", + ) + resolved = runtime.transition_appeal( + appeal_id=created.id, + payload=AdminAppealTransitionRequest( + to_status="resolved_modified", + rationale="modified", + resolution_code="APPEAL_MODIFIED", + resolution_reason_codes=["R_ALLOW_NO_POLICY_MATCH"], + ), + actor="integration-reviewer", + ) + assert resolved.status == "resolved_modified" + assert out_path.exists() + lines = [line for line in out_path.read_text(encoding="utf-8").splitlines() if line.strip()] + assert len(lines) == 1 + record = json.loads(lines[0]) + assert record["appeal_id"] == created.id + assert "text" not in record diff --git a/tests/test_audit_events.py b/tests/test_audit_events.py new file mode 100644 index 0000000..7c37285 --- /dev/null +++ b/tests/test_audit_events.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import json + +import pytest + +from sentinel_api.audit_events import ( + AUDIT_RING_BUFFER_SIZE, + AuditEvent, + events_since, + publish_audit_event, + reset_audit_events_state, +) +from sentinel_api.main import _generate_audit_sse + + +def setup_function() -> None: + reset_audit_events_state() + + +def teardown_function() -> None: + reset_audit_events_state() + + +def test_events_since_respects_cursor() -> None: + publish_audit_event( + AuditEvent( + timestamp="2026-01-01T00:00:00Z", + action="ALLOW", + labels=["BENIGN_POLITICAL_SPEECH"], + reason_codes=["R_ALLOW_NO_POLICY_MATCH"], + latency_ms=12, + deployment_stage="supervised", + lexicon_version="lex-0", + policy_version="policy-0", + ) + ) + publish_audit_event( + AuditEvent( + timestamp="2026-01-01T00:00:01Z", + action="REVIEW", + labels=["ELECTION_INTERFERENCE"], + reason_codes=["R_ELECTION_CLAIM_MATCH"], + latency_ms=25, + deployment_stage="supervised", + lexicon_version="lex-0", + policy_version="policy-0", + ) + ) + + events, cursor = events_since(0) + assert cursor == 2 + assert [event.action for event in events] == ["ALLOW", "REVIEW"] + + events, cursor = events_since(1) + assert cursor == 2 + assert [event.action for event in events] == ["REVIEW"] + + events, cursor = events_since(2) + assert cursor == 2 + assert events == [] + + +def test_ring_buffer_drops_oldest_events() -> None: + for index in range(AUDIT_RING_BUFFER_SIZE + 5): + publish_audit_event( + AuditEvent( + timestamp=f"2026-01-01T00:00:{index:02d}Z", + action="ALLOW", + labels=[], + reason_codes=[], + latency_ms=1, + deployment_stage="supervised", + lexicon_version="lex-0", + policy_version=f"policy-{index}", + ) + ) + + events, cursor = events_since(0) + assert cursor == AUDIT_RING_BUFFER_SIZE + 5 + assert len(events) == AUDIT_RING_BUFFER_SIZE + assert events[0].policy_version == "policy-5" + + +@pytest.mark.anyio +async def test_generate_audit_sse_emits_data_lines() -> None: + publish_audit_event( + AuditEvent( + timestamp="2026-02-20T00:00:00+00:00", + action="ALLOW", + labels=["BENIGN_POLITICAL_SPEECH"], + reason_codes=["R_ALLOW_NO_POLICY_MATCH"], + latency_ms=5, + deployment_stage="supervised", + lexicon_version="lex-0", + policy_version="policy-0", + ) + ) + + generator = _generate_audit_sse(0) + chunk = await anext(generator) + assert chunk.startswith("data: ") + payload = json.loads(chunk.removeprefix("data: ").strip()) + assert payload["action"] == "ALLOW" + await generator.aclose() + + +@pytest.mark.anyio +async def test_generate_audit_sse_awaits_sleep_when_empty(monkeypatch: pytest.MonkeyPatch) -> None: + sleep_called = False + + async def _fake_sleep(_seconds: float) -> None: + nonlocal sleep_called + sleep_called = True + publish_audit_event( + AuditEvent( + timestamp="2026-02-20T00:00:00+00:00", + action="ALLOW", + labels=[], + reason_codes=[], + latency_ms=1, + deployment_stage="supervised", + lexicon_version="lex-0", + policy_version="policy-0", + ) + ) + + monkeypatch.setattr("sentinel_api.main.asyncio.sleep", _fake_sleep) + + generator = _generate_audit_sse(0) + _chunk = await anext(generator) + assert sleep_called is True + await generator.aclose() diff --git a/tests/test_claim_likeness.py b/tests/test_claim_likeness.py index aad4479..fa83573 100644 --- a/tests/test_claim_likeness.py +++ b/tests/test_claim_likeness.py @@ -1,6 +1,6 @@ from __future__ import annotations -from sentinel_core.claim_likeness import assess_claim_likeness +from sentinel_core.claim_likeness import assess_claim_likeness, contains_election_anchor def test_claim_likeness_high_band_for_assertive_election_claim() -> None: @@ -41,3 +41,33 @@ def test_claim_likeness_respects_threshold_overrides() -> None: high_threshold=0.95, ) assert assessment.band == "low" + + +def test_swahili_election_anchor_detected() -> None: + assert contains_election_anchor("Uchaguzi uliibwa.") is True + assert contains_election_anchor("yiero mar siasa") is True + assert contains_election_anchor("kuraiyat") is True + + +def test_swahili_disinfo_narrative_scores_high() -> None: + assessment = assess_claim_likeness( + "matokeo ni bandia", + medium_threshold=0.4, + high_threshold=0.7, + ) + assert assessment.has_election_anchor is True + assert assessment.band in {"medium", "high"} + + +def test_swahili_hedging_reduces_score() -> None: + direct = assess_claim_likeness( + "matokeo ni bandia", + medium_threshold=0.4, + high_threshold=0.7, + ) + hedged = assess_claim_likeness( + "tetesi inadaiwa matokeo ni bandia", + medium_threshold=0.4, + high_threshold=0.7, + ) + assert hedged.score < direct.score diff --git a/tests/test_embedding_bakeoff.py b/tests/test_embedding_bakeoff.py index 29e9b7e..d4d5f3c 100644 --- a/tests/test_embedding_bakeoff.py +++ b/tests/test_embedding_bakeoff.py @@ -1,10 +1,13 @@ from __future__ import annotations import json +import sys +import types from pathlib import Path import pytest +import sentinel_core.embedding_bakeoff as bakeoff from sentinel_core.embedding_bakeoff import run_embedding_bakeoff @@ -104,3 +107,54 @@ def test_invalid_similarity_threshold_raises(tmp_path: Path) -> None: similarity_threshold=1.5, enable_optional_models=False, ) + + +def test_embedding_dim_is_384_for_e5_and_labse() -> None: + candidates = bakeoff._build_candidates(enable_optional_models=True) + dim_map = {candidate.candidate_id: candidate.embedding_dim for candidate in candidates} + assert dim_map["e5-multilingual-small"] == 384 + assert dim_map["labse"] == 384 + + +def test_e5_small_returns_384_floats(monkeypatch: pytest.MonkeyPatch) -> None: + class _FakeArray(list): + def tolist(self): # type: ignore[override] + return list(self) + + class _FakeSentenceTransformer: + def __init__(self, name: str) -> None: + assert name == "intfloat/multilingual-e5-small" + + def encode(self, text: str, *, normalize_embeddings: bool): + assert normalize_embeddings is True + assert text.startswith("query: ") + return _FakeArray([0.0] * 384) + + fake_module = types.SimpleNamespace(SentenceTransformer=_FakeSentenceTransformer) + monkeypatch.setitem(sys.modules, "sentence_transformers", fake_module) + bakeoff._load_e5_small_model.cache_clear() + + embedding = bakeoff._embed("e5-multilingual-small", "test") + assert len(embedding) == 384 + + +def test_labse_returns_384_floats(monkeypatch: pytest.MonkeyPatch) -> None: + class _FakeArray(list): + def tolist(self): # type: ignore[override] + return list(self) + + class _FakeSentenceTransformer: + def __init__(self, name: str) -> None: + assert name == "sentence-transformers/LaBSE" + + def encode(self, text: str, *, normalize_embeddings: bool): + assert normalize_embeddings is True + assert text == "test" + return _FakeArray([0.0] * 384) + + fake_module = types.SimpleNamespace(SentenceTransformer=_FakeSentenceTransformer) + monkeypatch.setitem(sys.modules, "sentence_transformers", fake_module) + bakeoff._load_labse_model.cache_clear() + + embedding = bakeoff._embed("labse", "test") + assert len(embedding) == 384 diff --git a/tests/test_health.py b/tests/test_health.py new file mode 100644 index 0000000..09a10be --- /dev/null +++ b/tests/test_health.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +from sentinel_api.main import app + +client = TestClient(app) + + +def test_live_always_200() -> None: + response = client.get("/health/live") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} + + +def test_ready_200_with_lexicon(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("SENTINEL_DATABASE_URL", raising=False) + monkeypatch.delenv("SENTINEL_REDIS_URL", raising=False) + response = client.get("/health/ready") + assert response.status_code == 200 + payload = response.json() + assert payload["status"] == "ready" + assert payload["checks"]["lexicon"] in {"ok", "empty"} + + +def test_ready_503_when_db_unreachable(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SENTINEL_DATABASE_URL", "postgresql://invalid") + monkeypatch.setattr("sentinel_api.main._check_db_ready", lambda _url: "error") + response = client.get("/health/ready") + assert response.status_code == 503 + payload = response.json() + assert payload["status"] == "degraded" + assert payload["checks"]["db"] == "error" + + +def test_ready_200_no_db_configured(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("SENTINEL_DATABASE_URL", raising=False) + monkeypatch.delenv("SENTINEL_REDIS_URL", raising=False) + response = client.get("/health/ready") + assert response.status_code == 200 + payload = response.json() + assert payload["status"] == "ready" + assert "db" not in payload["checks"] + + +def test_existing_health_unchanged() -> None: + response = client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok"} diff --git a/tests/test_internal_admin_oauth.py b/tests/test_internal_admin_oauth.py index f01917e..a212057 100644 --- a/tests/test_internal_admin_oauth.py +++ b/tests/test_internal_admin_oauth.py @@ -7,8 +7,10 @@ import pytest from fastapi.testclient import TestClient +from sentinel_api.audit_events import reset_audit_events_state from sentinel_api.main import app from sentinel_api.metrics import metrics +from sentinel_core.policy_config import set_runtime_phase_override client = TestClient(app) @@ -16,7 +18,10 @@ @pytest.fixture(autouse=True) def reset_metrics(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("SENTINEL_DATABASE_URL", raising=False) + monkeypatch.delenv("SENTINEL_ELECTORAL_PHASE", raising=False) metrics.reset() + set_runtime_phase_override(None) + reset_audit_events_state() def _set_registry(monkeypatch: pytest.MonkeyPatch, payload: dict[str, object]) -> None: @@ -198,6 +203,79 @@ def test_transparency_export_allows_export_scope(monkeypatch: pytest.MonkeyPatch assert "records" in payload +def test_admin_policy_phase_requires_scope(monkeypatch: pytest.MonkeyPatch) -> None: + _set_registry( + monkeypatch, + { + "token-proposal-reader": { + "client_id": "proposal-reader", + "scopes": ["admin:proposal:read"], + } + }, + ) + response = client.post( + "/admin/policy/phase", + headers={"Authorization": "Bearer token-proposal-reader"}, + json={"phase": "voting_day"}, + ) + assert response.status_code == 403 + payload = response.json() + assert payload["error_code"] == "HTTP_403" + assert "admin:policy:write" in payload["message"] + + +def test_admin_policy_phase_update_sets_override(monkeypatch: pytest.MonkeyPatch) -> None: + _set_registry( + monkeypatch, + { + "token-policy-writer": { + "client_id": "policy-writer", + "scopes": ["admin:policy:write"], + } + }, + ) + monkeypatch.setenv("SENTINEL_ELECTORAL_PHASE", "campaign") + + response = client.post( + "/admin/policy/phase", + headers={"Authorization": "Bearer token-policy-writer"}, + json={"phase": "voting_day"}, + ) + assert response.status_code == 200 + payload = response.json() + assert payload["effective_phase"] == "voting_day" + assert payload["actor"] == "policy-writer" + + response = client.post( + "/admin/policy/phase", + headers={"Authorization": "Bearer token-policy-writer"}, + json={"phase": None}, + ) + assert response.status_code == 200 + payload = response.json() + assert payload["effective_phase"] == "campaign" + + +def test_admin_audit_stream_requires_scope(monkeypatch: pytest.MonkeyPatch) -> None: + _set_registry( + monkeypatch, + { + "token-appeal-reader": { + "client_id": "appeal-reader", + "scopes": ["admin:appeal:read"], + } + }, + ) + response = client.get( + "/admin/audit/stream", + headers={"Authorization": "Bearer token-appeal-reader"}, + ) + assert response.status_code == 403 + payload = response.json() + assert payload["error_code"] == "HTTP_403" + assert "admin:transparency:read" in payload["message"] + + def test_internal_queue_metrics_accepts_valid_jwt(monkeypatch: pytest.MonkeyPatch) -> None: secret = "test-secret-which-is-long-enough-32+" monkeypatch.setenv("SENTINEL_OAUTH_JWT_SECRET", secret) diff --git a/tests/test_langpack_hot_path.py b/tests/test_langpack_hot_path.py new file mode 100644 index 0000000..15c797b --- /dev/null +++ b/tests/test_langpack_hot_path.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from pathlib import Path + +import sentinel_api.policy as policy +from sentinel_api.policy_config import get_policy_config, reset_policy_config_cache +from sentinel_langpack import get_wave1_pack_matchers + + +class _Matcher: + version = "hatelex-v2.1" + entries = [] + + def match(self, _text: str): # type: ignore[no-untyped-def] + return [] + + +def setup_function() -> None: + reset_policy_config_cache() + + +def teardown_function() -> None: + reset_policy_config_cache() + + +def test_luo_term_routes_to_review(monkeypatch) -> None: + monkeypatch.setattr(policy, "find_hot_trigger_matches", lambda *_args, **_kwargs: []) + monkeypatch.setattr(policy, "find_vector_match", lambda *_args, **_kwargs: None) + monkeypatch.setattr(policy, "score_claim_with_fallback", lambda *_args, **_kwargs: None) + + decision = policy.evaluate_text( + "this contains chok-ruok", + matcher=_Matcher(), + config=get_policy_config(), + ) + assert decision.action == "REVIEW" + assert decision.evidence + assert decision.evidence[0].type == "lexicon" + assert decision.evidence[0].lang == "luo" + assert decision.reason_codes == ["R_LUO_INCITE_LEXICON"] + + +def test_kalenjin_term_routes_to_review(monkeypatch) -> None: + monkeypatch.setattr(policy, "find_hot_trigger_matches", lambda *_args, **_kwargs: []) + monkeypatch.setattr(policy, "find_vector_match", lambda *_args, **_kwargs: None) + monkeypatch.setattr(policy, "score_claim_with_fallback", lambda *_args, **_kwargs: None) + + decision = policy.evaluate_text( + "met-incite should match after normalization", + matcher=_Matcher(), + config=get_policy_config(), + ) + assert decision.action == "REVIEW" + assert decision.evidence + assert decision.evidence[0].type == "lexicon" + assert decision.evidence[0].lang == "kalenjin" + assert decision.reason_codes == ["R_KLN_INCITE_LEXICON"] + + +def test_pack_match_never_produces_block(monkeypatch) -> None: + monkeypatch.setattr(policy, "find_hot_trigger_matches", lambda *_args, **_kwargs: []) + monkeypatch.setattr(policy, "find_vector_match", lambda *_args, **_kwargs: None) + monkeypatch.setattr(policy, "score_claim_with_fallback", lambda *_args, **_kwargs: None) + + decision = policy.evaluate_text( + "jodak-slur appears here", + matcher=_Matcher(), + config=get_policy_config(), + ) + assert decision.action != "BLOCK" + + +def test_missing_registry_returns_empty_matchers(monkeypatch, tmp_path: Path) -> None: + import sentinel_langpack.hot_path as hot_path + + hot_path.get_wave1_pack_matchers.cache_clear() + monkeypatch.setattr(hot_path, "DEFAULT_REGISTRY_PATH", tmp_path / "missing.json") + hot_path.get_wave1_pack_matchers.cache_clear() + assert get_wave1_pack_matchers() == [] diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index 67ed641..7655d74 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -1,6 +1,8 @@ from __future__ import annotations import logging +import sys +import types import sentinel_api.model_registry as model_registry @@ -64,6 +66,57 @@ def _raise_embed_text(_text: str): assert "embedding provider failed" in caplog.text +def test_e5_provider_selected_via_env(monkeypatch) -> None: + monkeypatch.setenv(model_registry.EMBEDDING_PROVIDER_ENV, "e5-multilingual-small-v1") + model_registry.reset_model_runtime_cache() + runtime = model_registry.get_model_runtime() + assert runtime.embedding_provider_id == "e5-multilingual-small-v1" + + +def test_e5_embed_returns_384_floats(monkeypatch) -> None: + class _FakeArray(list): + def tolist(self): # type: ignore[override] + return list(self) + + class _FakeModel: + def encode(self, text: str, *, normalize_embeddings: bool): + assert normalize_embeddings is True + assert text.startswith("query: ") + return _FakeArray([0.0] * 384) + + class _FakeSentenceTransformer: + def __init__(self, name: str) -> None: + assert name == "intfloat/multilingual-e5-small" + + def encode(self, text: str, *, normalize_embeddings: bool): + return _FakeModel().encode(text, normalize_embeddings=normalize_embeddings) + + fake_module = types.SimpleNamespace(SentenceTransformer=_FakeSentenceTransformer) + monkeypatch.setitem(sys.modules, "sentence_transformers", fake_module) + model_registry.E5MultilingualSmallEmbeddingProvider._load_model.cache_clear() + + monkeypatch.setenv(model_registry.EMBEDDING_PROVIDER_ENV, "e5-multilingual-small-v1") + model_registry.reset_model_runtime_cache() + runtime = model_registry.get_model_runtime() + provider = runtime.embedding_provider + embedding = provider.embed("test", timeout_ms=50) + assert embedding is not None + assert len(embedding) == 384 + + +def test_e5_graceful_when_sentence_transformers_missing(monkeypatch, caplog) -> None: + monkeypatch.setenv(model_registry.EMBEDDING_PROVIDER_ENV, "e5-multilingual-small-v1") + model_registry.reset_model_runtime_cache() + model_registry.E5MultilingualSmallEmbeddingProvider._load_model.cache_clear() + monkeypatch.delitem(sys.modules, "sentence_transformers", raising=False) + + provider = model_registry.E5MultilingualSmallEmbeddingProvider() + with caplog.at_level(logging.WARNING): + result = provider.embed("sample", timeout_ms=10) + assert result is None + assert "sentence-transformers not installed" in caplog.text + + def test_predict_classifier_shadow_drops_unknown_and_low_scores(monkeypatch) -> None: class _Classifier: name = "test" diff --git a/tests/test_policy_context_aware.py b/tests/test_policy_context_aware.py new file mode 100644 index 0000000..00f3b45 --- /dev/null +++ b/tests/test_policy_context_aware.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +import sentinel_api.policy as policy +from sentinel_api.lexicon_repository import LexiconEntry +from sentinel_api.policy_config import reset_policy_config_cache, resolve_policy_runtime +from sentinel_api.vector_matcher import VectorMatch +from sentinel_core.models import ModerationContext + + +class _Matcher: + version = "hatelex-v2.1" + entries = [] + + def match(self, _text: str): # type: ignore[no-untyped-def] + return [] + + +def setup_function() -> None: + reset_policy_config_cache() + + +def teardown_function() -> None: + reset_policy_config_cache() + + +def test_forward_channel_lowers_threshold(monkeypatch) -> None: + vector_entry = LexiconEntry( + term="rigged", + action="REVIEW", + label="DISINFO_RISK", + reason_code="R_DISINFO_NARRATIVE_SIMILARITY", + severity=1, + lang="en", + ) + + def _vector_stub( + _text: str, + *, + lexicon_version: str, + query_embedding: list[float], + embedding_model: str, + min_similarity: float | None = None, + ): + assert lexicon_version == "hatelex-v2.1" + assert embedding_model + assert query_embedding + if min_similarity is None: + return None + if min_similarity <= 0.80: + return VectorMatch(entry=vector_entry, similarity=0.80, match_id="55") + return None + + runtime = resolve_policy_runtime().model_copy(update={"vector_match_threshold": 0.82}) + + monkeypatch.setenv("SENTINEL_DATABASE_URL", "postgresql://example") + monkeypatch.setattr(policy, "find_hot_trigger_matches", lambda *_args, **_kwargs: []) + monkeypatch.setattr(policy, "get_wave1_pack_matchers", lambda: []) + monkeypatch.setattr(policy, "find_vector_match", _vector_stub) + monkeypatch.setattr(policy, "score_claim_with_fallback", lambda *_args, **_kwargs: None) + + forward_context = ModerationContext(channel="forward") + decision = policy.evaluate_text( + "they manipulated election tallies", + matcher=_Matcher(), + runtime=runtime, + context=forward_context, + ) + assert decision.action == "REVIEW" + assert decision.evidence[0].type == "vector_match" + + neutral_decision = policy.evaluate_text( + "they manipulated election tallies", + matcher=_Matcher(), + runtime=runtime, + context=None, + ) + assert neutral_decision.action == "ALLOW" + + +def test_broadcast_channel_raises_threshold(monkeypatch) -> None: + vector_entry = LexiconEntry( + term="rigged", + action="REVIEW", + label="DISINFO_RISK", + reason_code="R_DISINFO_NARRATIVE_SIMILARITY", + severity=1, + lang="en", + ) + + def _vector_stub( + _text: str, + *, + lexicon_version: str, + query_embedding: list[float], + embedding_model: str, + min_similarity: float | None = None, + ): + del lexicon_version + assert embedding_model + assert query_embedding + if min_similarity is None: + return None + if min_similarity <= 0.80: + return VectorMatch(entry=vector_entry, similarity=0.80, match_id="55") + return None + + runtime = resolve_policy_runtime().model_copy(update={"vector_match_threshold": 0.82}) + + monkeypatch.setenv("SENTINEL_DATABASE_URL", "postgresql://example") + monkeypatch.setattr(policy, "find_hot_trigger_matches", lambda *_args, **_kwargs: []) + monkeypatch.setattr(policy, "get_wave1_pack_matchers", lambda: []) + monkeypatch.setattr(policy, "find_vector_match", _vector_stub) + monkeypatch.setattr(policy, "score_claim_with_fallback", lambda *_args, **_kwargs: None) + + broadcast_context = ModerationContext(channel="broadcast") + decision = policy.evaluate_text( + "they manipulated election tallies", + matcher=_Matcher(), + runtime=runtime, + context=broadcast_context, + ) + assert decision.action == "ALLOW" + + +def test_partner_factcheck_amplifies_claim_score(monkeypatch) -> None: + runtime = resolve_policy_runtime() + + monkeypatch.setenv("SENTINEL_DATABASE_URL", "postgresql://example") + monkeypatch.setattr(policy, "find_hot_trigger_matches", lambda *_args, **_kwargs: []) + monkeypatch.setattr(policy, "get_wave1_pack_matchers", lambda: []) + monkeypatch.setattr(policy, "find_vector_match", lambda *_args, **_kwargs: None) + monkeypatch.setattr(policy, "score_claim_with_fallback", lambda *_args, **_kwargs: (0.44, "x")) + + baseline = policy.evaluate_text( + "election results were manipulated", + matcher=_Matcher(), + runtime=runtime, + context=ModerationContext(source="other"), + ) + assert baseline.action == "ALLOW" + + amplified = policy.evaluate_text( + "election results were manipulated", + matcher=_Matcher(), + runtime=runtime, + context=ModerationContext(source="partner_factcheck"), + ) + assert amplified.action == "REVIEW" + assert amplified.reason_codes == ["R_DISINFO_CLAIM_LIKENESS_MEDIUM"] + + +def test_null_context_no_change(monkeypatch) -> None: + runtime = resolve_policy_runtime().model_copy(update={"vector_match_threshold": 0.82}) + + called: list[float] = [] + + def _vector_stub( + _text: str, + *, + lexicon_version: str, + query_embedding: list[float], + embedding_model: str, + min_similarity: float | None = None, + ): + del lexicon_version, query_embedding, embedding_model + called.append(min_similarity if min_similarity is not None else -1.0) + return None + + monkeypatch.setattr(policy, "find_hot_trigger_matches", lambda *_args, **_kwargs: []) + monkeypatch.setattr(policy, "get_wave1_pack_matchers", lambda: []) + monkeypatch.setattr(policy, "find_vector_match", _vector_stub) + monkeypatch.setattr(policy, "score_claim_with_fallback", lambda *_args, **_kwargs: None) + monkeypatch.setenv("SENTINEL_DATABASE_URL", "postgresql://example") + + policy.evaluate_text( + "peaceful discussion", + matcher=_Matcher(), + runtime=runtime, + context=None, + ) + policy.evaluate_text( + "peaceful discussion", + matcher=_Matcher(), + runtime=runtime, + context=ModerationContext(channel="unknown"), + ) + assert called == [0.82, 0.82] diff --git a/tests/test_policy_phase_modes.py b/tests/test_policy_phase_modes.py index 84442fb..35a867d 100644 --- a/tests/test_policy_phase_modes.py +++ b/tests/test_policy_phase_modes.py @@ -60,12 +60,32 @@ def test_silence_period_escalates_no_match_to_review(monkeypatch) -> None: def test_phase_override_passes_vector_threshold_to_matcher(monkeypatch) -> None: captured: dict[str, float | None] = {"threshold": None} - def _fake_find_vector_match(_text: str, *, lexicon_version: str, min_similarity=None): + class _Runtime: + embedding_provider_id = "hash-bow-v1" + + class _Provider: + def embed(self, _text: str, *, timeout_ms: int): # type: ignore[no-untyped-def] + del timeout_ms + return [0.1] * 64 + + embedding_provider = _Provider() + + def _fake_find_vector_match( + _text: str, + *, + lexicon_version: str, + query_embedding: list[float], + embedding_model: str, + min_similarity=None, + ): del lexicon_version + del query_embedding, embedding_model captured["threshold"] = min_similarity return None monkeypatch.setenv("SENTINEL_ELECTORAL_PHASE", "voting_day") + monkeypatch.setenv("SENTINEL_DATABASE_URL", "postgresql://example") + monkeypatch.setattr(policy, "get_model_runtime", lambda: _Runtime()) monkeypatch.setattr(policy, "find_vector_match", _fake_find_vector_match) decision = policy.evaluate_text("peaceful discussion") assert decision.action == "REVIEW" diff --git a/tests/test_policy_runtime_config.py b/tests/test_policy_runtime_config.py index a01ce41..a375fc7 100644 --- a/tests/test_policy_runtime_config.py +++ b/tests/test_policy_runtime_config.py @@ -26,7 +26,7 @@ def test_moderation_uses_external_policy_config(tmp_path, monkeypatch) -> None: result = moderate("peaceful discussion") assert result.policy_version == "policy-2099.01" assert result.model_version == "sentinel-multi-custom" - assert result.toxicity == 0.01 + assert result.toxicity == 0.366 assert result.pack_versions["en"] == "pack-en-9.9" reset_policy_config_cache() diff --git a/tests/test_policy_toxicity.py b/tests/test_policy_toxicity.py new file mode 100644 index 0000000..21f4af2 --- /dev/null +++ b/tests/test_policy_toxicity.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from sentinel_api.policy import Decision, _derive_toxicity +from sentinel_core.models import EvidenceItem + + +def test_toxicity_blended_with_model_confidence() -> None: + runtime = SimpleNamespace( + toxicity_by_action=SimpleNamespace(BLOCK=0.9, REVIEW=0.45, ALLOW=0.05) + ) + decision = Decision( + action="REVIEW", + labels=[], + reason_codes=[], + evidence=[EvidenceItem(type="model_span", span="...", confidence=0.9)], + toxicity=0.45, + ) + assert _derive_toxicity(decision, runtime=runtime) == 0.63 + + +def test_toxicity_unchanged_without_model_evidence() -> None: + runtime = SimpleNamespace( + toxicity_by_action=SimpleNamespace(BLOCK=0.9, REVIEW=0.45, ALLOW=0.05) + ) + decision = Decision( + action="REVIEW", + labels=[], + reason_codes=[], + evidence=[EvidenceItem(type="lexicon", match="foo", severity=2, lang="en")], + toxicity=0.45, + ) + assert _derive_toxicity(decision, runtime=runtime) == 0.45 + + +def test_toxicity_uses_max_confidence_when_multiple_model_spans() -> None: + runtime = SimpleNamespace( + toxicity_by_action=SimpleNamespace(BLOCK=0.9, REVIEW=0.45, ALLOW=0.05) + ) + decision = Decision( + action="REVIEW", + labels=[], + reason_codes=[], + evidence=[ + EvidenceItem(type="model_span", span="...", confidence=0.2), + EvidenceItem(type="model_span", span="...", confidence=0.8), + ], + toxicity=0.45, + ) + assert _derive_toxicity(decision, runtime=runtime) == 0.59 + + +def test_toxicity_stays_in_0_1_range() -> None: + runtime = SimpleNamespace(toxicity_by_action=SimpleNamespace(BLOCK=1.0, REVIEW=1.0, ALLOW=0.0)) + decision = Decision( + action="BLOCK", + labels=[], + reason_codes=[], + evidence=[EvidenceItem(type="model_span", span="...", confidence=1.0)], + toxicity=1.0, + ) + assert _derive_toxicity(decision, runtime=runtime) == 1.0 diff --git a/tests/test_policy_vector_match.py b/tests/test_policy_vector_match.py index 34d55a1..3d183f2 100644 --- a/tests/test_policy_vector_match.py +++ b/tests/test_policy_vector_match.py @@ -38,14 +38,23 @@ def match(self, _text: str) -> list[LexiconEntry]: "find_hot_trigger_matches", lambda *_args, **_kwargs: [], ) + monkeypatch.setenv("SENTINEL_DATABASE_URL", "postgresql://example") + + class _Runtime: + embedding_provider_id = "hash-bow-v1" + + class _Provider: + def embed(self, _text: str, *, timeout_ms: int): # type: ignore[no-untyped-def] + del timeout_ms + return [0.1] * 64 + + embedding_provider = _Provider() + + monkeypatch.setattr(policy, "get_model_runtime", lambda: _Runtime()) monkeypatch.setattr( policy, "find_vector_match", - lambda *_args, **_kwargs: VectorMatch( - entry=vector_entry, - similarity=0.88, - match_id="101", - ), + lambda *_args, **_kwargs: VectorMatch(entry=vector_entry, similarity=0.88, match_id="101"), ) decision = policy.evaluate_text( @@ -84,11 +93,9 @@ def match(self, _text: str) -> list[LexiconEntry]: "find_hot_trigger_matches", lambda *_args, **_kwargs: [], ) - monkeypatch.setattr( - policy, - "find_vector_match", - lambda *_args, **_kwargs: None, - ) + monkeypatch.setenv("SENTINEL_DATABASE_URL", "postgresql://example") + monkeypatch.setattr(policy, "get_model_runtime", lambda: object()) + monkeypatch.setattr(policy, "find_vector_match", lambda *_args, **_kwargs: None) decision = policy.evaluate_text( "we should deal with them politically", @@ -122,14 +129,23 @@ def match(self, _text: str) -> list[LexiconEntry]: "find_hot_trigger_matches", lambda *_args, **_kwargs: [], ) + monkeypatch.setenv("SENTINEL_DATABASE_URL", "postgresql://example") + + class _Runtime: + embedding_provider_id = "hash-bow-v1" + + class _Provider: + def embed(self, _text: str, *, timeout_ms: int): # type: ignore[no-untyped-def] + del timeout_ms + return [0.1] * 64 + + embedding_provider = _Provider() + + monkeypatch.setattr(policy, "get_model_runtime", lambda: _Runtime()) monkeypatch.setattr( policy, "find_vector_match", - lambda *_args, **_kwargs: VectorMatch( - entry=vector_entry, - similarity=0.99, - match_id="7", - ), + lambda *_args, **_kwargs: VectorMatch(entry=vector_entry, similarity=0.99, match_id="7"), ) decision = policy.evaluate_text( diff --git a/tests/test_result_cache.py b/tests/test_result_cache.py new file mode 100644 index 0000000..011f331 --- /dev/null +++ b/tests/test_result_cache.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +import sentinel_api.main as main +from sentinel_api.result_cache import make_cache_key +from sentinel_core.models import ModerationContext, ModerationResponse + +client = TestClient(main.app) + + +def test_cache_key_includes_all_provenance_fields() -> None: + base = make_cache_key( + "hello", + policy_version="p1", + lexicon_version="l1", + model_version="m1", + pack_versions={"en": "pack-en-0.1"}, + deployment_stage="supervised", + context=None, + ) + different_policy = make_cache_key( + "hello", + policy_version="p2", + lexicon_version="l1", + model_version="m1", + pack_versions={"en": "pack-en-0.1"}, + deployment_stage="supervised", + context=None, + ) + assert base != different_policy + + +def test_different_context_produces_different_key() -> None: + none_key = make_cache_key( + "hello", + policy_version="p1", + lexicon_version="l1", + model_version="m1", + pack_versions={"en": "pack-en-0.1"}, + deployment_stage="supervised", + context=None, + ) + forward_key = make_cache_key( + "hello", + policy_version="p1", + lexicon_version="l1", + model_version="m1", + pack_versions={"en": "pack-en-0.1"}, + deployment_stage="supervised", + context=ModerationContext(channel="forward"), + ) + assert none_key != forward_key + + +def test_cache_disabled_no_redis_call(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv("SENTINEL_RESULT_CACHE_ENABLED", raising=False) + monkeypatch.setenv("SENTINEL_API_KEY", "k") + monkeypatch.setattr( + main, + "get_cached_result", + lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("should not call redis")), + ) + response = client.post( + "/v1/moderate", json={"text": "peaceful debate"}, headers={"X-API-Key": "k"} + ) + assert response.status_code == 200 + assert "X-Cache" not in response.headers + + +def test_cache_hit_returns_cached_result(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SENTINEL_RESULT_CACHE_ENABLED", "true") + monkeypatch.setenv("SENTINEL_REDIS_URL", "redis://unused") + monkeypatch.setenv("SENTINEL_API_KEY", "k") + + cached = ModerationResponse.model_validate( + { + "toxicity": 0.0, + "labels": ["BENIGN_POLITICAL_SPEECH"], + "action": "ALLOW", + "reason_codes": ["R_ALLOW_NO_POLICY_MATCH"], + "evidence": [{"type": "model_span", "span": "x", "confidence": 0.9}], + "language_spans": [{"start": 0, "end": 1, "lang": "en"}], + "model_version": "sentinel-multi-v2", + "lexicon_version": "hatelex-v2.1", + "pack_versions": {"en": "pack-en-0.1"}, + "policy_version": "policy-2026.01", + "latency_ms": 1, + } + ) + + monkeypatch.setattr(main, "get_cached_result", lambda *_args, **_kwargs: cached) + monkeypatch.setattr( + main, + "set_cached_result", + lambda *_args, **_kwargs: (_ for _ in ()).throw(AssertionError("should not write on hit")), + ) + + response = client.post("/v1/moderate", json={"text": "hello"}, headers={"X-API-Key": "k"}) + assert response.status_code == 200 + assert response.headers["X-Cache"] == "HIT" + assert response.json()["policy_version"] == "policy-2026.01" + + +def test_cache_miss_sets_cached_result(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("SENTINEL_RESULT_CACHE_ENABLED", "true") + monkeypatch.setenv("SENTINEL_REDIS_URL", "redis://unused") + monkeypatch.setenv("SENTINEL_API_KEY", "k") + monkeypatch.setattr(main, "get_cached_result", lambda *_args, **_kwargs: None) + captured: dict[str, object] = {} + + def _capture_set( + cache_key: str, result: ModerationResponse, redis_url: str, *, ttl: int + ) -> None: + captured["cache_key"] = cache_key + captured["redis_url"] = redis_url + captured["ttl"] = ttl + captured["action"] = result.action + + monkeypatch.setattr(main, "set_cached_result", _capture_set) + + response = client.post( + "/v1/moderate", + json={"text": "We should discuss policy peacefully."}, + headers={"X-API-Key": "k"}, + ) + assert response.status_code == 200 + assert response.headers["X-Cache"] == "MISS" + assert captured["redis_url"] == "redis://unused" + assert captured["ttl"] == 60 + assert captured["action"] in {"ALLOW", "REVIEW", "BLOCK"} diff --git a/tests/test_vector_match_postgres_integration.py b/tests/test_vector_match_postgres_integration.py index 42e51fa..aad6c2d 100644 --- a/tests/test_vector_match_postgres_integration.py +++ b/tests/test_vector_match_postgres_integration.py @@ -5,7 +5,12 @@ import pytest -from sentinel_api.vector_matcher import find_vector_match, reset_vector_match_cache +from sentinel_api.vector_matcher import ( + VECTOR_MODEL, + embed_text, + find_vector_match, + reset_vector_match_cache, +) def _integration_db_url() -> str | None: @@ -35,6 +40,8 @@ def test_pgvector_match_populates_embedding_table_and_returns_candidate( match = find_vector_match( "rigged", lexicon_version="hatelex-v2.1", + query_embedding=embed_text("rigged"), + embedding_model=VECTOR_MODEL, ) assert match is not None assert match.match_id @@ -45,7 +52,10 @@ def test_pgvector_match_populates_embedding_table_and_returns_candidate( psycopg = importlib.import_module("psycopg") with psycopg.connect(db_url) as conn: with conn.cursor() as cur: - cur.execute("SELECT COUNT(1) FROM lexicon_entry_embeddings") + cur.execute( + "SELECT COUNT(1) FROM lexicon_entry_embeddings_v2 WHERE embedding_model = %s", + (VECTOR_MODEL,), + ) row = cur.fetchone() assert row is not None assert int(row[0]) > 0 diff --git a/tests/test_vector_matcher.py b/tests/test_vector_matcher.py index 0c139f3..6a832df 100644 --- a/tests/test_vector_matcher.py +++ b/tests/test_vector_matcher.py @@ -27,6 +27,8 @@ def test_find_vector_match_returns_none_without_database(monkeypatch) -> None: result = vector_matcher.find_vector_match( "election manipulation narrative", lexicon_version="hatelex-v2.1", + query_embedding=vector_matcher.embed_text("election manipulation narrative"), + embedding_model=vector_matcher.VECTOR_MODEL, ) assert result is None @@ -37,6 +39,8 @@ def test_find_vector_match_returns_none_when_disabled(monkeypatch) -> None: result = vector_matcher.find_vector_match( "election manipulation narrative", lexicon_version="hatelex-v2.1", + query_embedding=vector_matcher.embed_text("election manipulation narrative"), + embedding_model=vector_matcher.VECTOR_MODEL, ) assert result is None @@ -52,13 +56,13 @@ def __init__(self) -> None: self._fetchone_result = None def execute(self, query: str, params=None) -> None: - if "LEFT JOIN lexicon_entry_embeddings" in query: + if "LEFT JOIN lexicon_entry_embeddings_v2" in query: self._fetchall_result = [(7, "rigged")] return - if "INSERT INTO lexicon_entry_embeddings" in query: + if "INSERT INTO lexicon_entry_embeddings_v2" in query: state["upserts"] += 1 return - if "JOIN lexicon_entry_embeddings AS emb" in query: + if "JOIN lexicon_entry_embeddings_v2 AS emb" in query: self._fetchone_result = ( 7, "rigged", @@ -106,6 +110,8 @@ def connect(self, _database_url: str) -> _Connection: match = vector_matcher.find_vector_match( "they manipulated election results", lexicon_version="hatelex-v2.1", + query_embedding=vector_matcher.embed_text("they manipulated election results"), + embedding_model=vector_matcher.VECTOR_MODEL, ) assert match is not None @@ -122,10 +128,10 @@ def __init__(self) -> None: self._fetchone_result = None def execute(self, query: str, params=None) -> None: - if "LEFT JOIN lexicon_entry_embeddings" in query: + if "LEFT JOIN lexicon_entry_embeddings_v2" in query: self._fetchall_result = [] return - if "JOIN lexicon_entry_embeddings AS emb" in query: + if "JOIN lexicon_entry_embeddings_v2 AS emb" in query: self._fetchone_result = ( 13, "rigged", @@ -172,6 +178,8 @@ def connect(self, _database_url: str) -> _Connection: result = vector_matcher.find_vector_match( "they manipulated election results", lexicon_version="hatelex-v2.1", + query_embedding=vector_matcher.embed_text("they manipulated election results"), + embedding_model=vector_matcher.VECTOR_MODEL, ) assert result is None assert "vector similarity was non-finite" in caplog.text