diff --git a/.gitignore b/.gitignore index df559c58..c5b3239e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ yarn-debug.log* yarn-error.log* lerna-debug.log* .pnpm-debug.log* +/planning # Diagnostic reports (https://nodejs.org/api/report.html) report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json diff --git a/DEPLOY.md b/DEPLOY.md index 139c9f5b..c74ca383 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -298,7 +298,7 @@ AZURE_OPENAI_GPT4O_MINI_ENDPOINT=your-gpt4o-mini-endpoint # Storage AZURE_STORAGE_CONNECTION_STRING=your-connection-string -AZURE_STORAGE_CONTAINER_NAME=can-sr-storage +STORAGE_CONTAINER_NAME=can-sr-storage # Authentication SECRET_KEY=your-secret-key-change-in-production @@ -306,7 +306,11 @@ ACCESS_TOKEN_EXPIRE_MINUTES=10080 # Databases (configured in docker-compose.yml) MONGODB_URI=mongodb://sr-mongodb-service:27017/mongodb-sr -POSTGRES_URI=postgres://admin:password@cit-pgdb-service:5432/postgres-cits +POSTGRES_MODE=docker +POSTGRES_HOST=pgdb-service +POSTGRES_DATABASE=postgres +POSTGRES_USER=admin +POSTGRES_PASSWORD=password # Databricks (for database search) DATABRICKS_INSTANCE=your-databricks-instance diff --git a/README.md b/README.md index 11a6f505..7273c941 100644 --- a/README.md +++ b/README.md @@ -221,11 +221,53 @@ AZURE_OPENAI_ENDPOINT=your-endpoint AZURE_OPENAI_DEPLOYMENT_NAME=gpt-4o # Storage -AZURE_STORAGE_CONNECTION_STRING=your-connection-string +# STORAGE_MODE is strict: local | azure | entra +STORAGE_MODE=local + +# Storage container name +# - local: folder name under LOCAL_STORAGE_BASE_PATH +# - azure/entra: blob container name +STORAGE_CONTAINER_NAME=can-sr-storage + +# local storage +LOCAL_STORAGE_BASE_PATH=uploads + +# azure storage (account name + key) +# STORAGE_MODE=azure +AZURE_STORAGE_ACCOUNT_NAME=youraccount +AZURE_STORAGE_ACCOUNT_KEY=your-key + +# entra storage (Managed Identity / DefaultAzureCredential) +# STORAGE_MODE=entra +AZURE_STORAGE_ACCOUNT_NAME=youraccount # Databases -MONGODB_URI=mongodb://localhost:27017/mongodb-sr -POSTGRES_URI=postgres://admin:password@localhost:5432/postgres-cits +MONGODB_URI=mongodb://sr-mongodb-service:27017/mongodb-sr + +# Postgres configuration +POSTGRES_MODE=docker # docker | local | azure + +# Canonical Postgres connection settings (single set) +# - docker/local: POSTGRES_PASSWORD is required +# - azure: POSTGRES_PASSWORD is ignored (Entra token auth via DefaultAzureCredential) +POSTGRES_HOST=pgdb-service +POSTGRES_DATABASE=postgres +POSTGRES_USER=admin +POSTGRES_PASSWORD=password + +# Local Postgres (developer machine) +# POSTGRES_MODE=local +# POSTGRES_HOST=localhost +# POSTGRES_DATABASE=grep +# POSTGRES_USER=postgres +# POSTGRES_PASSWORD=123 + +# Azure Database for PostgreSQL (Entra auth) +# POSTGRES_MODE=azure +# POSTGRES_HOST=...postgres.database.azure.com +# POSTGRES_DATABASE=grep +# POSTGRES_USER= +# POSTGRES_PASSWORD= # not used in azure mode # Databricks (for database search) DATABRICKS_INSTANCE=your-instance diff --git a/backend/Dockerfile b/backend/Dockerfile index 6a01e65f..8dd10676 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -9,7 +9,7 @@ WORKDIR /app RUN apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ apt-get update && \ - apt-get install -y --no-install-recommends \ + apt-get install -y \ gcc \ g++ \ git \ @@ -17,6 +17,8 @@ RUN apt-get clean && \ curl \ make \ libc6-dev \ + dialog \ + openssh-server \ && apt-get clean && \ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* @@ -26,7 +28,7 @@ ENV PIP_DEFAULT_TIMEOUT=300 # Install Python dependencies COPY requirements.txt . 
-RUN pip install -r requirements.txt
+RUN pip install -r requirements.txt --no-cache-dir

 # Copy application code
 COPY . .

@@ -41,10 +43,7 @@ RUN useradd -m -u 1001 appuser && chown -R appuser:appuser /app
 COPY sshd_config /etc/ssh/
 COPY entrypoint.sh /entrypoint.sh

-RUN apt-get update \
-    && apt-get install -y --no-install-recommends dialog \
-    && apt-get install -y --no-install-recommends openssh-server \
-    && echo "root:Docker!" | chpasswd \
+RUN echo "root:Docker!" | chpasswd \
     && chmod u+x /entrypoint.sh

 USER root
diff --git a/backend/README.md b/backend/README.md
index 1af57f14..7a88eca9 100644
--- a/backend/README.md
+++ b/backend/README.md
@@ -15,7 +15,7 @@ CAN-SR Backend provides a production-ready REST API for managing systematic revi
 - **PDF Processing** - Full-text extraction using GROBID
 - **Azure OpenAI Integration** - GPT-4o, GPT-4o-mini, GPT-3.5-turbo for AI features
 - **JWT Authentication** - Secure user authentication
-- **Azure Blob Storage** - Scalable document storage
+- **Storage** - Local filesystem or Azure Blob Storage (account key or Entra)

 ## Architecture

@@ -51,12 +51,54 @@ AZURE_OPENAI_API_KEY=your-azure-openai-api-key
 AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com
 AZURE_OPENAI_DEPLOYMENT_NAME=gpt-4o

-# Azure Storage (Required)
-AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=...
+# Storage
+# STORAGE_MODE is strict: local | azure | entra
+STORAGE_MODE=local
+
+# Storage container name
+# - local: folder name under LOCAL_STORAGE_BASE_PATH
+# - azure/entra: blob container name
+STORAGE_CONTAINER_NAME=can-sr-storage
+
+# local storage
+LOCAL_STORAGE_BASE_PATH=uploads
+
+# azure storage (account name + key)
+# STORAGE_MODE=azure
+AZURE_STORAGE_ACCOUNT_NAME=youraccount
+AZURE_STORAGE_ACCOUNT_KEY=your-key
+
+# entra storage (Managed Identity / DefaultAzureCredential)
+# STORAGE_MODE=entra
+AZURE_STORAGE_ACCOUNT_NAME=youraccount

 # Databases (Docker defaults - change for production)
-MONGODB_URI=mongodb://sr-mongodb-service:27017/mongodb-sr
-POSTGRES_URI=postgres://admin:password@cit-pgdb-service:5432/postgres-cits
+
+
+# Postgres configuration
+POSTGRES_MODE=docker # docker | local | azure
+
+# Canonical Postgres connection settings (single set)
+# - docker/local: POSTGRES_PASSWORD is required
+# - azure: POSTGRES_PASSWORD is ignored (Entra token auth via DefaultAzureCredential)
+POSTGRES_HOST=pgdb-service
+POSTGRES_DATABASE=postgres
+POSTGRES_USER=admin
+POSTGRES_PASSWORD=password
+
+# Local Postgres (developer machine)
+# POSTGRES_MODE=local
+# POSTGRES_HOST=localhost
+# POSTGRES_DATABASE=grep
+# POSTGRES_USER=postgres
+# POSTGRES_PASSWORD=123
+
+# Azure Database for PostgreSQL (Entra auth)
+# POSTGRES_MODE=azure
+# POSTGRES_HOST=
+# POSTGRES_DATABASE=
+# POSTGRES_USER=
+# POSTGRES_PASSWORD= # not used in azure mode

 # GROBID Service
 GROBID_SERVICE_URL=http://grobid-service:8070
@@ -215,14 +257,21 @@ docker compose restart api
 | `AZURE_OPENAI_API_KEY` | Azure OpenAI API key | `abc123...` |
 | `AZURE_OPENAI_ENDPOINT` | Azure OpenAI endpoint URL | `https://your-resource.openai.azure.com` |
 | `AZURE_OPENAI_DEPLOYMENT_NAME` | Model deployment name | `gpt-4o` |
-| `AZURE_STORAGE_CONNECTION_STRING` | Azure Blob Storage connection | `DefaultEndpointsProtocol=https;...` |
+| `STORAGE_MODE` | Storage backend selector | `local` |
+| `LOCAL_STORAGE_BASE_PATH` | Local storage base path (when local) | `uploads` |
+| `AZURE_STORAGE_ACCOUNT_NAME` | Azure storage account (when STORAGE_MODE=azure or entra) | `youraccount` |
+| `AZURE_STORAGE_ACCOUNT_KEY` | Azure storage account key (when STORAGE_MODE=azure) | `your-key` |
 | `SECRET_KEY` | JWT token signing key | `your-secure-secret-key` |

 ### Optional Variables

 | Variable | Description | Default |
 |----------|-------------|---------|
 | `MONGODB_URI` | MongoDB connection string | `mongodb://sr-mongodb-service:27017/mongodb-sr` |
-| `POSTGRES_URI` | PostgreSQL connection string | `postgres://admin:password@cit-pgdb-service:5432/postgres-cits` |
+| `POSTGRES_MODE` | Postgres connection mode: `docker` \| `local` \| `azure` | `docker` |
+| `POSTGRES_HOST` | Postgres host (docker: service name; local: localhost; azure: FQDN) | `pgdb-service` |
+| `POSTGRES_DATABASE` | Postgres database name | `postgres` |
+| `POSTGRES_USER` | Postgres user (azure: Entra UPN or role) | `admin` |
+| `POSTGRES_PASSWORD` | Postgres password (ignored when POSTGRES_MODE=azure) | `password` |
 | `GROBID_SERVICE_URL` | GROBID service URL | `http://grobid-service:8070` |
 | `DATABRICKS_INSTANCE` | Databricks workspace URL | - |
 | `DATABRICKS_TOKEN` | Databricks access token | - |
diff --git a/backend/api/citations/router.py b/backend/api/citations/router.py
index 247a2859..9c1e0ff9 100644
--- a/backend/api/citations/router.py
+++ b/backend/api/citations/router.py
@@ -6,7 +6,7 @@
 - All endpoints require a Systematic Review (sr_id) that the user is a member of.
 - Uploading a CSV will:
   - Parse the CSV
-  - Create a new Postgres *table* in the shared database (POSTGRES_URI)
+  - Create a new Postgres *table* in the shared database
   - Insert citation rows from the CSV into that table
   - Save the table name + connection string into the Systematic Review record
@@ -37,32 +37,25 @@ router = APIRouter()

-def _get_db_conn_str() -> Optional[str]:
+def _is_postgres_configured() -> bool:
     """
-    Get database connection string for PostgreSQL.
-
-    If POSTGRES_URI is set, returns it directly (local development).
-    If Entra ID env variables are configured (POSTGRES_HOST, POSTGRES_DATABASE, POSTGRES_USER),
-    returns None to signal that connect_postgres() should use Entra ID authentication.
+    Check if PostgreSQL is configured via the POSTGRES_MODE profile.
     """
-    if settings.POSTGRES_URI:
-        return settings.POSTGRES_URI
-
-    # If Entra ID config is available, return None to let connect_postgres use token auth
-    if settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER:
-        return None
-
-    # No configuration available - return None, let downstream handle the error
-    return None
+    try:
+        prof = settings.postgres_profile()
+    except Exception:
+        return False

+    if not (prof.get("database") and prof.get("user")):
+        return False

-def _is_postgres_configured() -> bool:
-    """
-    Check if PostgreSQL is configured via Entra ID env vars or connection string.
- """ - has_entra_config = settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER - has_uri_config = settings.POSTGRES_URI - return bool(has_entra_config or has_uri_config) + if prof.get("mode") in ("local", "docker") and not prof.get("password"): + return False + + if prof.get("mode") == "azure" and not prof.get("host"): + return False + + return True class UploadResult(BaseModel): @@ -84,8 +77,8 @@ def _parse_dsn(dsn: str) -> Dict[str, str]: return parse_dsn(dsn) except Exception: return {} -def _create_table_and_insert_sync(db_conn_str: str, table_name: str, columns: List[str], rows: List[Dict[str, Any]]) -> int: - return cits_dp_service.create_table_and_insert_sync(db_conn_str, table_name, columns, rows) +def _create_table_and_insert_sync(table_name: str, columns: List[str], rows: List[Dict[str, Any]]) -> int: + return cits_dp_service.create_table_and_insert_sync(table_name, columns, rows) @router.post("/{sr_id}/upload-csv", response_model=UploadResult) @@ -98,25 +91,23 @@ async def upload_screening_csv( Upload a CSV of citations for title/abstract screening and create a dedicated Postgres table. Requirements: - - POSTGRES_URI must be configured. + - Postgres must be configured via POSTGRES_MODE and POSTGRES_* env vars. - The SR must exist and the user must be a member of the SR (or owner). """ - db_conn_str = _get_db_conn_str() try: - sr, screening, _ = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") - # Check admin DSN (use centralized settings) - need either Entra ID config or POSTGRES_URI + # Check admin config (use centralized settings) if not _is_postgres_configured(): raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Postgres not configured. Set POSTGRES_HOST/DATABASE/USER for Entra ID auth, or POSTGRES_URI for local dev.", + detail="Postgres not configured. 
Set POSTGRES_MODE and POSTGRES_* env vars.", ) - admin_dsn = _get_db_conn_str() # Read CSV content include_columns = None @@ -141,14 +132,14 @@ async def upload_screening_csv( try: old = (sr.get("screening_db") or {}).get("table_name") if old: - await run_in_threadpool(cits_dp_service.drop_table, admin_dsn, old) + await run_in_threadpool(cits_dp_service.drop_table, old) except Exception: # best-effort only pass # Create table and insert rows in threadpool try: - inserted = await run_in_threadpool(_create_table_and_insert_sync, admin_dsn, table_name, include_columns, normalized_rows) + inserted = await run_in_threadpool(_create_table_and_insert_sync, table_name, include_columns, normalized_rows) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -158,7 +149,6 @@ async def upload_screening_csv( try: screening_info = { "screening_db": { - "connection_string": admin_dsn, "table_name": table_name, "created_at": datetime.utcnow().isoformat(), "rows": inserted, @@ -168,7 +158,6 @@ async def upload_screening_csv( # Update SR document with screening DB info using PostgreSQL await run_in_threadpool( srdb_service.update_screening_db_info, - _get_db_conn_str(), sr_id, screening_info["screening_db"] ) @@ -198,9 +187,8 @@ async def list_citation_ids( Returns a simple list of integers (the 'id' primary key from the citations table). """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -212,7 +200,7 @@ async def list_citation_ids( table_name = (screening or {}).get("table_name") or "citations" try: - ids = await run_in_threadpool(cits_dp_service.list_citation_ids, db_conn, filter_step, table_name) + ids = await run_in_threadpool(cits_dp_service.list_citation_ids, filter_step, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -237,9 +225,8 @@ async def get_citation_by_id( Returns: a JSON object representing the citation row (keys are DB column names). """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -248,7 +235,7 @@ async def get_citation_by_id( table_name = (screening or {}).get("table_name") or "citations" try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -290,9 +277,8 @@ async def build_combined_citation( the format ": \\n" for each included column, in the order provided. 
""" - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -304,7 +290,7 @@ async def build_combined_citation( table_name = (screening or {}).get("table_name") or "citations" try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -345,9 +331,8 @@ async def upload_citation_fulltext( to the storage path (container/blob). """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -368,7 +353,7 @@ async def upload_citation_fulltext( table_name = (screening or {}).get("table_name") or "citations" try: - existing_row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + existing_row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -403,7 +388,7 @@ async def upload_citation_fulltext( # Update citation row in Postgres try: - updated = await run_in_threadpool(cits_dp_service.attach_fulltext, db_conn, citation_id, storage_path, content, table_name) + updated = await run_in_threadpool(cits_dp_service.attach_fulltext, citation_id, storage_path, content, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -441,9 +426,8 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An - Caller must be the SR owner. 
""" - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -456,15 +440,13 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An if not screening: return {"status": "no_screening_db", "message": "No screening table configured for this SR", "deleted_table": False, "deleted_files": 0} - db_conn = None - table_name = screening.get("table_name") if not table_name: return {"status": "no_screening_db", "message": "Incomplete screening DB metadata", "deleted_table": False, "deleted_files": 0} # 1) collect fulltext URLs from the screening DB try: - urls = await run_in_threadpool(cits_dp_service.list_fulltext_urls, db_conn, table_name) + urls = await run_in_threadpool(cits_dp_service.list_fulltext_urls, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -519,12 +501,13 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An parsed_ok = False if not parsed_ok: - # fallback: try direct deletion via blob client (best-effort) + # fallback: delete by storage path (works for both azure/local) if storage_service: try: - blob_client = storage_service.blob_service_client.get_blob_client(container=container, blob=blob) - blob_client.delete_blob() + await storage_service.delete_by_path(f"{container}/{blob}") deleted_files += 1 + except FileNotFoundError: + failed_files += 1 except Exception: failed_files += 1 else: @@ -535,7 +518,7 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An # 3) drop the screening table try: - await run_in_threadpool(cits_dp_service.drop_table, db_conn, table_name) + await run_in_threadpool(cits_dp_service.drop_table, table_name) table_dropped = True except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) @@ -546,7 +529,6 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An try: await run_in_threadpool( srdb_service.clear_screening_db_info, - _get_db_conn_str(), sr_id ) except Exception: @@ -582,10 +564,9 @@ async def export_citations_csv( Content-Disposition. 
""" - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check( - sr_id, current_user, db_conn_str, srdb_service + sr, screening = await load_sr_and_check( + sr_id, current_user, srdb_service ) except HTTPException: raise @@ -595,7 +576,7 @@ async def export_citations_csv( detail=f"Failed to load systematic review: {e}", ) - if not screening or not db_conn: + if not screening: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="No screening database configured for this systematic review", @@ -604,7 +585,7 @@ async def export_citations_csv( table_name = (screening or {}).get("table_name") or "citations" try: - csv_bytes = await run_in_threadpool(cits_dp_service.dump_citations_csv, db_conn, table_name) + csv_bytes = await run_in_threadpool(cits_dp_service.dump_citations_csv, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: diff --git a/backend/api/core/cit_utils.py b/backend/api/core/cit_utils.py index 12442fa8..1e33ce13 100644 --- a/backend/api/core/cit_utils.py +++ b/backend/api/core/cit_utils.py @@ -16,19 +16,33 @@ from .config import settings -def _is_postgres_configured(db_conn_str: Optional[str] = None) -> bool: +def _is_postgres_configured() -> bool: """ - Check if PostgreSQL is configured via Entra ID env vars or connection string. + Check if PostgreSQL is configured via the POSTGRES_MODE profile. """ - has_entra_config = settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER - has_uri_config = db_conn_str or settings.POSTGRES_URI - return bool(has_entra_config or has_uri_config) + try: + prof = settings.postgres_profile() + except Exception: + return False + + # minimal requirements + if not (prof.get("database") and prof.get("user")): + return False + + # password required for local/docker + if prof.get("mode") in ("local", "docker") and not prof.get("password"): + return False + + # azure requires host + if prof.get("mode") == "azure" and not prof.get("host"): + return False + + return True async def load_sr_and_check( sr_id: str, current_user: Dict[str, Any], - db_conn_str: Optional[str], srdb_service, require_screening: bool = True, require_visible: bool = True, @@ -39,20 +53,19 @@ async def load_sr_and_check( Args: sr_id: SR id string current_user: current user dict (must contain "id" and "email") - db_conn_str: PostgreSQL connection string (can be None if using Entra ID auth) srdb_service: SR DB service instance (must implement get_systematic_review and user_has_sr_permission) require_screening: if True, also ensure the SR has a configured screening_db and return its connection string require_visible: if True, require the SR 'visible' flag to be True; set False for endpoints like hard-delete Returns: - (sr_doc, screening_obj or None, db_conn_string or None) + (sr_doc, screening_obj or None) Raises HTTPException with appropriate status codes on failure so routers can just propagate. 
""" # fetch SR try: - sr = await run_in_threadpool(srdb_service.get_systematic_review, db_conn_str, sr_id, not require_visible) + sr = await run_in_threadpool(srdb_service.get_systematic_review, sr_id, not require_visible) except HTTPException: raise except Exception as e: @@ -67,7 +80,7 @@ async def load_sr_and_check( # permission check (user must be member or owner) user_id = current_user.get("email") try: - has_perm = await run_in_threadpool(srdb_service.user_has_sr_permission, db_conn_str, sr_id, user_id) + has_perm = await run_in_threadpool(srdb_service.user_has_sr_permission, sr_id, user_id) except HTTPException: raise except Exception as e: @@ -77,10 +90,9 @@ async def load_sr_and_check( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to view/modify this systematic review") screening = sr.get("screening_db") if isinstance(sr, dict) else None - db_conn = None if require_screening: if not screening: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No screening database configured for this systematic review") - db_conn = None - return sr, screening, db_conn + + return sr, screening diff --git a/backend/api/core/config.py b/backend/api/core/config.py index 3fb7d32a..37796fb1 100644 --- a/backend/api/core/config.py +++ b/backend/api/core/config.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from typing import List, Optional from pydantic import Field, field_validator from pydantic_settings import BaseSettings @@ -21,22 +22,30 @@ class Settings(BaseSettings): DESCRIPTION: str = os.getenv( "DESCRIPTION", "AI-powered systematic review platform for Government of Canada" ) - IS_DEPLOYED: bool = os.getenv("IS_DEPLOYED") + IS_DEPLOYED: bool = os.getenv("IS_DEPLOYED", "false").lower() == "true" # CORS CORS_ORIGINS: str = os.getenv("CORS_ORIGINS", "*") # Storage settings - STORAGE_TYPE: str = os.getenv("STORAGE_TYPE", "azure") - AZURE_STORAGE_ACCOUNT_NAME: Optional[str] = os.getenv( - "AZURE_STORAGE_ACCOUNT_NAME" - ) - AZURE_STORAGE_CONNECTION_STRING: Optional[str] = os.getenv( - "AZURE_STORAGE_CONNECTION_STRING" - ) - AZURE_STORAGE_CONTAINER_NAME: str = os.getenv( - "AZURE_STORAGE_CONTAINER_NAME", "can-sr-storage" - ) + # Storage selection (strict): local | azure | entra + STORAGE_MODE: str = os.getenv("STORAGE_MODE", "azure").lower().strip() + # Canonical storage container name used across all storage types. 
+ # - local: folder name under LOCAL_STORAGE_BASE_PATH + # - azure/entra: blob container name + STORAGE_CONTAINER_NAME: str = os.getenv("STORAGE_CONTAINER_NAME", "can-sr-storage") + # Azure Storage + # - STORAGE_MODE=azure requires account name + account key + # - STORAGE_MODE=entra requires only account name (uses DefaultAzureCredential) + AZURE_STORAGE_ACCOUNT_NAME: Optional[str] = os.getenv("AZURE_STORAGE_ACCOUNT_NAME") + AZURE_STORAGE_ACCOUNT_KEY: Optional[str] = os.getenv("AZURE_STORAGE_ACCOUNT_KEY") + + # Local storage settings (used when STORAGE_MODE=local) + # In docker, default path is backed by the compose volume: ./uploads:/app/uploads + # Default to a relative directory so it works both locally and in docker: + # - locally: /backend/uploads + # - in docker: /app/uploads + LOCAL_STORAGE_BASE_PATH: str = os.getenv("LOCAL_STORAGE_BASE_PATH", "uploads") # File upload settings MAX_FILE_SIZE: int = Field(default=52428800) # 50MB in bytes @@ -60,40 +69,11 @@ def convert_max_file_size(cls, v): # Azure OpenAI settings (Primary - GPT-4o) AZURE_OPENAI_API_KEY: Optional[str] = os.getenv("AZURE_OPENAI_API_KEY") AZURE_OPENAI_ENDPOINT: Optional[str] = os.getenv("AZURE_OPENAI_ENDPOINT") - AZURE_OPENAI_API_VERSION: str = os.getenv( - "AZURE_OPENAI_API_VERSION", "2025-01-01-preview" - ) - AZURE_OPENAI_DEPLOYMENT_NAME: str = os.getenv( - "AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-4o" - ) - - # GPT-4.1-mini configuration - AZURE_OPENAI_GPT41_MINI_API_KEY: Optional[str] = os.getenv( - "AZURE_OPENAI_GPT41_MINI_API_KEY" - ) - AZURE_OPENAI_GPT41_MINI_ENDPOINT: Optional[str] = os.getenv( - "AZURE_OPENAI_GPT41_MINI_ENDPOINT" - ) - AZURE_OPENAI_GPT41_MINI_DEPLOYMENT: str = os.getenv( - "AZURE_OPENAI_GPT41_MINI_DEPLOYMENT", "gpt-4.1-mini" - ) - AZURE_OPENAI_GPT41_MINI_API_VERSION: str = os.getenv( - "AZURE_OPENAI_GPT41_MINI_API_VERSION", "2025-01-01-preview" - ) - - # GPT-5-mini configuration - AZURE_OPENAI_GPT5_MINI_API_KEY: Optional[str] = os.getenv( - "AZURE_OPENAI_GPT5_MINI_API_KEY" - ) - AZURE_OPENAI_GPT5_MINI_ENDPOINT: Optional[str] = os.getenv( - "AZURE_OPENAI_GPT5_MINI_ENDPOINT" - ) - AZURE_OPENAI_GPT5_MINI_DEPLOYMENT: str = os.getenv( - "AZURE_OPENAI_GPT5_MINI_DEPLOYMENT", "gpt-5-mini" - ) - AZURE_OPENAI_GPT5_MINI_API_VERSION: str = os.getenv( - "AZURE_OPENAI_GPT5_MINI_API_VERSION", "2025-08-07" - ) + # Azure OpenAI auth selection + # key -> uses AZURE_OPENAI_ENDPOINT + AZURE_OPENAI_API_KEY + # entra -> uses AZURE_OPENAI_ENDPOINT + DefaultAzureCredential + # Backwards/alternate env var: OPENAI_TYPE + AZURE_OPENAI_MODE: str = os.getenv("AZURE_OPENAI_MODE", "key").lower().strip() # Default model to use DEFAULT_CHAT_MODEL: str = os.getenv("DEFAULT_CHAT_MODEL", "gpt-5-mini") @@ -110,37 +90,80 @@ def convert_max_file_size(cls, v): ) DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true" - # Database and external system environment variables - # Postgres settings for Entra ID authentication + # ------------------------------------------------------------------------- + # Postgres configuration + # ------------------------------------------------------------------------- + # Select primary Postgres mode. + # docker/local use password auth; azure uses Entra token auth. 
+ POSTGRES_MODE: str = os.getenv("POSTGRES_MODE", "docker").lower().strip() # docker|local|azure + + # Canonical Postgres connection settings (single profile; values vary by environment) POSTGRES_HOST: Optional[str] = os.getenv("POSTGRES_HOST") POSTGRES_DATABASE: Optional[str] = os.getenv("POSTGRES_DATABASE") - POSTGRES_USER: Optional[str] = os.getenv("POSTGRES_USER") # Entra ID user (e.g., user@tenant.onmicrosoft.com) - POSTGRES_PORT: int = int(os.getenv("POSTGRES_PORT", "5432")) - POSTGRES_SSL_MODE: Optional[str] = os.getenv("POSTGRES_SSL_MODE") + POSTGRES_USER: Optional[str] = os.getenv("POSTGRES_USER") POSTGRES_PASSWORD: Optional[str] = os.getenv("POSTGRES_PASSWORD") - AZURE_DB: bool = os.getenv("AZURE_DB", "false").lower() == "true" - # Legacy: Postgres DSN used for systematic reviews and screening databases (fallback) + + # Optional overrides + POSTGRES_PORT: int = int(os.getenv("POSTGRES_PORT", "5432")) + POSTGRES_SSL_MODE: str = os.getenv("POSTGRES_SSL_MODE", "require") + + # Deprecated (will be removed): legacy Postgres DSN POSTGRES_URI: Optional[str] = os.getenv("POSTGRES_URI") + def postgres_profile(self, mode: Optional[str] = None) -> dict: + """Return resolved Postgres connection settings for a specific mode. + + The application uses a single set of environment variables: + POSTGRES_HOST, POSTGRES_DATABASE, POSTGRES_USER, POSTGRES_PASSWORD. + + POSTGRES_MODE controls *how* authentication is performed: + - docker/local: password auth (POSTGRES_PASSWORD required) + - azure: Entra token auth (password ignored) + sslmode=require + """ + m = (mode or self.POSTGRES_MODE or "").lower().strip() + if m not in {"docker", "local", "azure"}: + raise ValueError("POSTGRES_MODE must be one of: docker, local, azure") + + # Provide sensible defaults for host depending on mode. + default_host = "pgdb-service" if m == "docker" else "localhost" if m == "local" else None + + prof = { + "mode": m, + "host": self.POSTGRES_HOST or default_host, + "database": self.POSTGRES_DATABASE, + "user": self.POSTGRES_USER, + # For azure we intentionally ignore password (token auth) + "password": None if m == "azure" else self.POSTGRES_PASSWORD, + "port": self.POSTGRES_PORT, + "sslmode": (self.POSTGRES_SSL_MODE or "require") if m == "azure" else None, + } + + return prof + + def has_local_fallback(self) -> bool: + # Deprecated: legacy behavior. Kept for compatibility with older code paths. 
+ return False + # Databricks settings - DATABRICKS_INSTANCE: str = os.getenv("DATABRICKS_INSTANCE") - DATABRICKS_TOKEN: str = os.getenv("DATABRICKS_TOKEN") - JOB_ID_EUROPEPMC: str = os.getenv("JOB_ID_EUROPEPMC") - JOB_ID_PUBMED: str = os.getenv("JOB_ID_PUBMED") - JOB_ID_SCOPUS: str = os.getenv("JOB_ID_SCOPUS") + DATABRICKS_INSTANCE: str = os.getenv("DATABRICKS_INSTANCE", "") + DATABRICKS_TOKEN: str = os.getenv("DATABRICKS_TOKEN", "") + JOB_ID_EUROPEPMC: str = os.getenv("JOB_ID_EUROPEPMC", "") + JOB_ID_PUBMED: str = os.getenv("JOB_ID_PUBMED", "") + JOB_ID_SCOPUS: str = os.getenv("JOB_ID_SCOPUS", "") # OAuth - OAUTH_CLIENT_ID: str = os.getenv("OAUTH_CLIENT_ID") - OAUTH_CLIENT_SECRET: str = os.getenv("OAUTH_CLIENT_SECRET") - REDIRECT_URI: str = os.getenv("REDIRECT_URI") - SSO_LOGIN_URL: str = os.getenv("SSO_LOGIN_URL") + OAUTH_CLIENT_ID: str = os.getenv("OAUTH_CLIENT_ID", "") + OAUTH_CLIENT_SECRET: str = os.getenv("OAUTH_CLIENT_SECRET", "") + REDIRECT_URI: str = os.getenv("REDIRECT_URI", "") + SSO_LOGIN_URL: str = os.getenv("SSO_LOGIN_URL", "") # Entra USE_ENTRA_AUTH: bool = os.getenv("USE_ENTRA_AUTH", "false").lower() == "true" class Config: case_sensitive = True - env_file = ".env" + # Resolve to backend/.env regardless of current working directory. + env_file = str(Path(__file__).resolve().parents[2] / ".env") extra = "ignore" # Allow extra environment variables diff --git a/backend/api/core/security.py b/backend/api/core/security.py index 6b584dc2..508cfcb0 100644 --- a/backend/api/core/security.py +++ b/backend/api/core/security.py @@ -60,7 +60,7 @@ async def create_user(user_data: UserCreate) -> Optional[UserRead]: if not user_db_service: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="User registration is not available without Azure Storage configuration", + detail="User registration is not available without configured storage", ) return await user_db_service.create_user(user_data) @@ -71,7 +71,7 @@ async def update_user(user_id: str, user_in: UserUpdate) -> Optional[UserRead]: if not user_db_service: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="User operations are not available without Azure Storage configuration", + detail="User operations are not available without configured storage", ) update_data = user_in.model_dump(exclude_unset=True) diff --git a/backend/api/extract/router.py b/backend/api/extract/router.py index a2b1ff91..7b5ea93a 100644 --- a/backend/api/extract/router.py +++ b/backend/api/extract/router.py @@ -83,9 +83,8 @@ async def extract_parameter_endpoint( derived from the parameter name (prefixed with 'llm_param_'). 
""" - db_conn_str = settings.POSTGRES_URI try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -98,7 +97,7 @@ async def extract_parameter_endpoint( row = None if not fulltext: try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -229,7 +228,7 @@ def _extract_json_object(text: str) -> Optional[str]: col_name = snake_case_param(payload.parameter_name) try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, col_name, stored, table_name) + updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, stored, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -256,9 +255,8 @@ async def human_extract_parameter( and does not call any LLM. """ - db_conn_str = settings.POSTGRES_URI try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -268,7 +266,7 @@ async def human_extract_parameter( # Ensure citation exists (we won't require full_text for human input but check row presence) try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -323,7 +321,7 @@ async def human_extract_parameter( col_name = f"human_param_{core}" if core else "human_param_param" try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, col_name, stored, table_name) + updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, stored, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -347,9 +345,8 @@ async def extract_fulltext_from_storage( under column "fulltext", and return the generated fulltext_str. 
""" - db_conn_str = settings.POSTGRES_URI try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -359,7 +356,7 @@ async def extract_fulltext_from_storage( # fetch citation row try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -373,16 +370,14 @@ async def extract_fulltext_from_storage( if not storage_path: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No fulltext storage path found on citation row") - if "/" not in storage_path: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognized storage path format") - - container, blob = storage_path.split("/", 1) - try: - blob_client = storage_service.blob_service_client.get_blob_client(container=container, blob=blob) - content = blob_client.download_blob().readall() + content, _filename = await storage_service.get_bytes_by_path(storage_path) + except FileNotFoundError: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Fulltext file not found in storage") + except ValueError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognized storage path format") except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to download blob: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to download from storage: {e}") # If the citation row already contains an extracted full text in the "fulltext" column, # only use it if the stored md5 matches the pdf we just downloaded. 
@@ -430,10 +425,10 @@ async def extract_fulltext_from_storage( # persist full_text_str and coordinates/pages into citation row try: - updated1 = await run_in_threadpool(cits_dp_service.update_text_column, db_conn, citation_id, "fulltext", full_text_str, table_name) - updated2 = await run_in_threadpool(cits_dp_service.update_text_column, db_conn, citation_id, "fulltext_md5", current_md5, table_name) - updated3 = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, "fulltext_coords", annotations, table_name) - updated4 = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, "fulltext_pages", pages, table_name) + updated1 = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "fulltext", full_text_str, table_name) + updated2 = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "fulltext_md5", current_md5, table_name) + updated3 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_coords", annotations, table_name) + updated4 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_pages", pages, table_name) updated = updated1 or updated2 or updated3 or updated4 except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) diff --git a/backend/api/files/router.py b/backend/api/files/router.py index 107d565e..b3c769fe 100644 --- a/backend/api/files/router.py +++ b/backend/api/files/router.py @@ -5,7 +5,6 @@ import os import logging from datetime import datetime, timezone -from azure.core.exceptions import ResourceNotFoundError from fastapi import ( APIRouter, @@ -206,9 +205,10 @@ async def download_by_path( path: str, current_user: Dict[str, Any] = Depends(get_current_active_user), ): - """ - Stream a blob directly by storage path in the form 'container/blob_path' - Example storage_path: "container/users/123/docid_filename.pdf" + """Download directly by storage path in the form 'container/blob_path'. + + This works for both Azure and local storage backends. 
+ Example: "users/users/123/documents/_.pdf" """ try: if not storage_service: @@ -217,29 +217,14 @@ async def download_by_path( detail="Storage service not available", ) - # Basic validation: expect "container/blob" - if not path or "/" not in path: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid storage path" - ) - - container, blob = path.split("/", 1) - logger.info(f"Downloading blob: container={container}, blob={blob}") - try: - blob_client = storage_service.blob_service_client.get_blob_client( - container=container, blob=blob - ) - stream = blob_client.download_blob() - content = stream.readall() - except ResourceNotFoundError: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Blob not found" - ) - - filename = os.path.basename(blob) or "download" + content, filename = await storage_service.get_bytes_by_path(path) + except FileNotFoundError: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + except ValueError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid storage path") + except Exception as e: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to download: {e}") def gen(): yield content diff --git a/backend/api/screen/router.py b/backend/api/screen/router.py index 20fc9986..76291f7e 100644 --- a/backend/api/screen/router.py +++ b/backend/api/screen/router.py @@ -79,9 +79,8 @@ async def classify_citation( The 'selected' field in the returned JSON is validated against the provided `options`. """ - db_conn_str = settings.POSTGRES_URI try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -91,7 +90,7 @@ async def classify_citation( # Load citation row (needed for l2 fulltext and for building citation_text) try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -196,7 +195,7 @@ async def classify_citation( col_name = snake_case_column(payload.question) try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, col_name, classification_json, table_name) + updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, classification_json, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -205,7 +204,7 @@ async def classify_citation( if not updated: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update") - await update_inclusion_decision(sr, citation_id, payload.screening_step, "llm", db_conn) + await update_inclusion_decision(sr, citation_id, payload.screening_step, "llm") return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "classification": classification_json} @@ -223,9 +222,8 @@ async def human_classify_citation( The column name is prefixed with 'human_' to distinguish from automated classifications. 
""" - db_conn_str = settings.POSTGRES_URI try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -235,7 +233,7 @@ async def human_classify_citation( # Ensure citation exists and optionally build combined citation text try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -263,7 +261,7 @@ async def human_classify_citation( col_name = f"human_{col_core}" if col_core else f"human_col" try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, col_name, classification_json, table_name) + updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, classification_json, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -272,7 +270,7 @@ async def human_classify_citation( if not updated: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update") - await update_inclusion_decision(sr, citation_id, payload.screening_step, "human", db_conn) + await update_inclusion_decision(sr, citation_id, payload.screening_step, "human") return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "classification": classification_json} @@ -281,13 +279,11 @@ async def update_inclusion_decision( citation_id: int, screening_step: str, decision_maker: str, - db_conn: str, - ): table_name = (sr.get("screening_db") or {}).get("table_name") or "citations" try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -312,7 +308,7 @@ async def update_inclusion_decision( col_name = f"{decision_maker}_{screening_step}_decision" try: - updated = await run_in_threadpool(cits_dp_service.update_text_column, db_conn, citation_id, col_name, decision, table_name) + updated = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, col_name, decision, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: diff --git a/backend/api/services/azure_openai_client.py b/backend/api/services/azure_openai_client.py index dd89e685..e6393448 100644 --- a/backend/api/services/azure_openai_client.py +++ b/backend/api/services/azure_openai_client.py @@ -1,12 +1,41 @@ -"""Azure OpenAI client service for chat completions""" +"""backend.api.services.azure_openai_client +Azure OpenAI client service for chat completions. 
+ +Supports: +* API key auth (AZURE_OPENAI_MODE=key) +* Entra/managed identity auth (AZURE_OPENAI_MODE=entra) + +Model catalog: +Loaded from YAML at runtime: + /app/configs/models.yaml + +The YAML is a mapping of UI/display keys to {deployment, api_version}, e.g.: + +GPT-5-Mini: + deployment: gpt-5-mini + api_version: 2025-04-01-preview + +DEFAULT_CHAT_MODEL must be one of those keys. +""" + +from __future__ import annotations + +import json +import logging import time -from typing import Dict, List, Any, Optional +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import yaml + from azure.identity import DefaultAzureCredential, get_bearer_token_provider from openai import AzureOpenAI from ..core.config import settings +logger = logging.getLogger(__name__) + # Token cache TTL in seconds (9 minutes) TOKEN_CACHE_TTL = 9 * 60 @@ -32,66 +61,153 @@ class AzureOpenAIClient: """Client for Azure OpenAI chat completions""" def __init__(self): + self._config_error: Optional[str] = None self.default_model = settings.DEFAULT_CHAT_MODEL - # Create token provider for Azure OpenAI using DefaultAzureCredential - # Wrapped with caching to avoid fetching a new token on every request - if not settings.AZURE_OPENAI_API_KEY and not settings.USE_ENTRA_AUTH: - raise ValueError("Azure OpenAI API key or Entra auth must be configured") + self._auth_type = self._resolve_auth_type() + self._endpoint = self._resolve_endpoint() + self._api_key = settings.AZURE_OPENAI_API_KEY - if settings.USE_ENTRA_AUTH: - self._credential = DefaultAzureCredential() - self._token_provider = CachedTokenProvider( - get_bearer_token_provider( - self._credential, "https://cognitiveservices.azure.com/.default" + self._token_provider: Optional[CachedTokenProvider] = None + if self._auth_type == "entra": + if not DefaultAzureCredential or not get_bearer_token_provider: + self._config_error = ( + "AZURE_OPENAI_MODE=entra requires azure-identity to be installed" ) - ) + else: + # Create token provider for Azure OpenAI using DefaultAzureCredential + # Wrapped with caching to avoid fetching a new token on every request + credential = DefaultAzureCredential() + self._token_provider = CachedTokenProvider( + get_bearer_token_provider( + credential, "https://cognitiveservices.azure.com/.default" + ) + ) + + self.model_configs = self._load_model_configs() + self.default_model = self._resolve_default_model(self.default_model) + + # Cache official clients by (endpoint, api_version, auth_type) + self._official_clients: Dict[Tuple[str, str, str], AzureOpenAI] = {} + + + # --------------------------------------------------------------------- + # Configuration + # --------------------------------------------------------------------- + + + @staticmethod + def _resolve_auth_type() -> str: + """Return key|entra. + + New config: AZURE_OPENAI_MODE + Legacy config: USE_ENTRA_AUTH + """ + t = (getattr(settings, "AZURE_OPENAI_MODE", None) or "").lower().strip() + if t in {"key", "entra"}: + return t + # Legacy fallback + if getattr(settings, "USE_ENTRA_AUTH", False): + return "entra" + return "key" + + @staticmethod + def _resolve_endpoint() -> Optional[str]: + return settings.AZURE_OPENAI_ENDPOINT + + def _strip_outer_quotes(self, s: str) -> str: + s = s.strip() + if len(s) >= 2 and ((s[0] == s[-1] == '"') or (s[0] == s[-1] == "'")): + return s[1:-1] + return s + + def _load_models_yaml(self) -> Dict[str, Any]: + """Load model catalog from /app/configs/models.yaml. 
+ + This file is expected to be mounted in docker-compose so changes can be + applied without rebuilding the image. + """ + path = Path("configs/models.yaml") + if not path.exists(): + logger.warning("Azure OpenAI model catalog not found at %s", path) + return {} + try: + data = yaml.safe_load(path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + logger.warning("Invalid models.yaml format (expected mapping): %s", type(data)) + return {} + return data + except Exception as e: + logger.exception("Failed to load Azure OpenAI model catalog from %s: %s", path, e) + return {} + + def _load_model_configs(self) -> Dict[str, Dict[str, str]]: + """Build model configs keyed by UI/display name.""" + models = self._load_models_yaml() + cfg: Dict[str, Dict[str, str]] = {} + for display_name, meta in models.items(): + if not isinstance(meta, dict): + continue + deployment = meta.get("deployment") + api_version = meta.get("api_version") + if not deployment or not api_version: + continue + cfg[str(display_name)] = { + "endpoint": self._endpoint or "", + "deployment": str(deployment), + "api_version": str(api_version), + } + + if cfg: + return cfg - self.model_configs = { - "gpt-4.1-mini": { - "endpoint": settings.AZURE_OPENAI_GPT41_MINI_ENDPOINT - or settings.AZURE_OPENAI_ENDPOINT, - "deployment": settings.AZURE_OPENAI_GPT41_MINI_DEPLOYMENT, - "api_version": settings.AZURE_OPENAI_GPT41_MINI_API_VERSION, - }, - "gpt-5-mini": { - "endpoint": settings.AZURE_OPENAI_GPT5_MINI_ENDPOINT - or settings.AZURE_OPENAI_ENDPOINT, - "deployment": settings.AZURE_OPENAI_GPT5_MINI_DEPLOYMENT, - "api_version": settings.AZURE_OPENAI_GPT5_MINI_API_VERSION, - }, - } - - self._official_clients: Dict[str, AzureOpenAI] = {} + return {} + + def _resolve_default_model(self, desired: str) -> str: + if desired in self.model_configs: + return desired + # If configured default doesn't exist, fall back to first configured model + for k in self.model_configs.keys(): + return k + return desired def _get_model_config(self, model: str) -> Dict[str, str]: """Get configuration for a specific model""" if model in self.model_configs: return self.model_configs[model] - return self.model_configs["gpt-5-mini"] + # fallback to first configured model + for _, cfg in self.model_configs.items(): + return cfg + raise ValueError("No Azure OpenAI models are configured") def _get_official_client(self, model: str) -> AzureOpenAI: """Get official Azure OpenAI client instance""" - if model not in self._official_clients: - config = self._get_model_config(model) - if not config.get("endpoint"): - raise ValueError( - f"Azure OpenAI endpoint not configured for model {model}" - ) - - azure_openai_kwargs = { - "azure_endpoint": config["endpoint"], - "api_version": config["api_version"], + config = self._get_model_config(model) + endpoint = config.get("endpoint") + api_version = config.get("api_version") + if not endpoint or not api_version: + raise ValueError(f"Azure OpenAI endpoint/api_version not configured for model {model}") + + cache_key = (endpoint, api_version, self._auth_type) + if cache_key not in self._official_clients: + azure_openai_kwargs: Dict[str, Any] = { + "azure_endpoint": endpoint, + "api_version": api_version, } - if settings.USE_ENTRA_AUTH: + + if self._auth_type == "entra": + if not self._token_provider: + raise ValueError(self._config_error or "Azure AD token provider not configured") azure_openai_kwargs["azure_ad_token_provider"] = self._token_provider - - if settings.AZURE_OPENAI_API_KEY: - azure_openai_kwargs["api_key"] = 
settings.AZURE_OPENAI_API_KEY + else: + # key auth + if not self._api_key: + raise ValueError("AZURE_OPENAI_MODE=key requires AZURE_OPENAI_API_KEY") + azure_openai_kwargs["api_key"] = self._api_key - self._official_clients[model] = AzureOpenAI(**azure_openai_kwargs) + self._official_clients[cache_key] = AzureOpenAI(**azure_openai_kwargs) - return self._official_clients[model] + return self._official_clients[cache_key] def _build_messages( self, user_message: str, system_prompt: Optional[str] = None @@ -145,7 +261,9 @@ async def chat_completion( "stream": stream, } - if model != "gpt-5-mini": + # gpt-5 deployments may reject temperature/max_tokens in some previews. + # We gate this by the *deployment* name because the UI key can differ. + if deployment != "gpt-5-mini": request_kwargs["max_tokens"] = max_tokens request_kwargs["temperature"] = temperature @@ -219,16 +337,19 @@ async def streaming_chat( deployment = self._get_model_config(model)["deployment"] client = self._get_official_client(model) - response = client.chat.completions.create( - stream=True, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - frequency_penalty=0.0, - presence_penalty=0.0, - model=deployment, - ) + request_kwargs: Dict[str, Any] = { + "stream": True, + "messages": messages, + "top_p": top_p, + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "model": deployment, + } + if deployment != "gpt-5-mini": + request_kwargs["max_tokens"] = max_tokens + request_kwargs["temperature"] = temperature + + response = client.chat.completions.create(**request_kwargs) for update in response: if update.choices: @@ -334,16 +455,40 @@ async def chat_with_context( def get_available_models(self) -> List[str]: """Get list of available models that are properly configured""" - return [ - model - for model, config in self.model_configs.items() - if config.get("endpoint") - ] + out: List[str] = [] + for model, config in self.model_configs.items(): + if not config.get("endpoint") or not config.get("deployment") or not config.get("api_version"): + continue + out.append(model) + return out def is_configured(self) -> bool: """Check if Azure OpenAI is properly configured""" - return len(self.get_available_models()) > 0 + if self._config_error: + return False + if not self.get_available_models(): + return False + + if self._auth_type == "key": + return bool(self._endpoint and self._api_key) + if self._auth_type == "entra": + return bool(self._endpoint and self._token_provider) + return False # Global Azure OpenAI client instance -azure_openai_client = AzureOpenAIClient() +# NOTE: This is used by routers. We intentionally avoid raising during import +# so the API can start up and report configuration issues as 503s. +try: + azure_openai_client = AzureOpenAIClient() +except Exception as e: # pragma: no cover + logger.exception("Failed to initialize AzureOpenAIClient: %s", e) + # Provide a stub that reports not-configured. 
+ class _DisabledAzureOpenAIClient: # type: ignore + def is_configured(self) -> bool: + return False + + def get_available_models(self) -> List[str]: + return [] + + azure_openai_client = _DisabledAzureOpenAIClient() # type: ignore diff --git a/backend/api/services/cit_db_service.py b/backend/api/services/cit_db_service.py index b130b688..141e6e0e 100644 --- a/backend/api/services/cit_db_service.py +++ b/backend/api/services/cit_db_service.py @@ -138,7 +138,7 @@ def __init__(self): # ----------------------- # Generic column ops # ----------------------- - def create_column(self, db_conn_str: str, col: str, col_type: str, table_name: str = "citations") -> None: + def create_column(self, col: str, col_type: str, table_name: str = "citations") -> None: """ Create column on citations table if it doesn't already exist. col should be the exact column name to use (caller may pass snake_case(col)). @@ -165,7 +165,6 @@ def create_column(self, db_conn_str: str, col: str, col_type: str, table_name: s def update_jsonb_column( self, - db_conn_str: str, citation_id: int, col: str, data: Any, @@ -197,7 +196,6 @@ def update_jsonb_column( def update_text_column( self, - db_conn_str: str, citation_id: int, col: str, text_value: str, @@ -230,7 +228,7 @@ def update_text_column( # ----------------------- # Citation row helpers # ----------------------- - def dump_citations_csv(self, db_conn_str: str, table_name: str = "citations") -> bytes: + def dump_citations_csv(self, table_name: str = "citations") -> bytes: """Dump the entire `citations` table as CSV bytes. Intended to be called from async FastAPI routes via @@ -259,7 +257,7 @@ def dump_citations_csv(self, db_conn_str: str, table_name: str = "citations") -> if conn: pass - def get_citation_by_id(self, db_conn_str: str, citation_id: int, table_name: str = "citations") -> Optional[Dict[str, Any]]: + def get_citation_by_id(self, citation_id: int, table_name: str = "citations") -> Optional[Dict[str, Any]]: """ Return a dict mapping column -> value for the citation row, or None. """ @@ -286,7 +284,7 @@ def get_citation_by_id(self, db_conn_str: str, citation_id: int, table_name: str if conn: pass - def list_citation_ids(self, db_conn_str: str, filter_step=None, table_name: str = "citations") -> List[int]: + def list_citation_ids(self, filter_step=None, table_name: str = "citations") -> List[int]: """ Return list of integer primary keys (id) from citations table ordered by id. """ @@ -333,7 +331,7 @@ def list_citation_ids(self, db_conn_str: str, filter_step=None, table_name: str if conn: pass - def list_fulltext_urls(self, db_conn_str: str, table_name: str = "citations") -> List[str]: + def list_fulltext_urls(self, table_name: str = "citations") -> List[str]: """ Return list of fulltext_url values (non-null) from citations table. """ @@ -350,18 +348,17 @@ def list_fulltext_urls(self, db_conn_str: str, table_name: str = "citations") -> if conn: pass - def update_citation_fulltext(self, db_conn_str: str, citation_id: int, fulltext_path: str) -> int: + def update_citation_fulltext(self, citation_id: int, fulltext_path: str) -> int: """ Backwards-compatible helper used by some routers. Sets `fulltext_url`. 
""" - return self.update_text_column(db_conn_str, citation_id, "fulltext_url", fulltext_path) + return self.update_text_column(citation_id, "fulltext_url", fulltext_path) # ----------------------- # Upload fulltext and compute md5 # ----------------------- def attach_fulltext( self, - db_conn_str: str, citation_id: int, azure_path: str, file_bytes: bytes, @@ -373,7 +370,7 @@ def attach_fulltext( """ table_name = _validate_ident(table_name, kind="table_name") # create columns if missing - self.create_column(db_conn_str, "fulltext_url", "TEXT", table_name=table_name) + self.create_column("fulltext_url", "TEXT", table_name=table_name) # compute md5 md5 = hashlib.md5(file_bytes).hexdigest() if file_bytes is not None else "" @@ -383,15 +380,13 @@ def attach_fulltext( cur.execute(f'UPDATE "{table_name}" SET "fulltext_url" = %s WHERE id = %s', (azure_path, int(citation_id))) rows = cur.rowcount conn.commit() - - return rows # ----------------------- # Column get/set helpers # ----------------------- - def get_column_value(self, db_conn_str: str, citation_id: int, column: str, table_name: str = "citations") -> Any: + def get_column_value(self, citation_id: int, column: str, table_name: str = "citations") -> Any: """ Return the value stored in `column` for the citation row (or None). """ @@ -418,18 +413,18 @@ def get_column_value(self, db_conn_str: str, citation_id: int, column: str, tabl if conn: pass - def set_column_value(self, db_conn_str: str, citation_id: int, column: str, value: Any, table_name: str = "citations") -> int: + def set_column_value(self, citation_id: int, column: str, value: Any, table_name: str = "citations") -> int: """ Generic setter for a citation row column. Will create a TEXT column if it doesn't exist. """ # For simplicity, create a TEXT column. Callers that need JSONB should use update_jsonb_column. - self.create_column(db_conn_str, column, "TEXT", table_name=table_name) - return self.update_text_column(db_conn_str, citation_id, column, value if value is not None else None, table_name=table_name) + self.create_column(column, "TEXT", table_name=table_name) + return self.update_text_column(citation_id, column, value if value is not None else None, table_name=table_name) # ----------------------- # Per-upload table lifecycle helpers # ----------------------- - def drop_table(self, db_conn_str: str, table_name: str, cascade: bool = True) -> None: + def drop_table(self, table_name: str, cascade: bool = True) -> None: """Drop a screening table in the shared database.""" table_name = _validate_ident(table_name, kind="table_name") conn = None @@ -445,7 +440,6 @@ def drop_table(self, db_conn_str: str, table_name: str, cascade: bool = True) -> def create_table_and_insert_sync( self, - db_conn_str: str, table_name: str, columns: List[str], rows: List[Dict[str, Any]], diff --git a/backend/api/services/postgres_auth.py b/backend/api/services/postgres_auth.py index 9334c520..dc676cbd 100644 --- a/backend/api/services/postgres_auth.py +++ b/backend/api/services/postgres_auth.py @@ -1,17 +1,38 @@ -""" -PostgreSQL authentication helper using Azure Entra ID (DefaultAzureCredential). +"""backend.api.services.postgres_auth + +PostgreSQL connection helper supporting three runtime modes. 
+ +Configuration model: +* POSTGRES_MODE selects behavior: docker | local | azure +* Connection settings are provided via a single set of env vars: + - POSTGRES_HOST + - POSTGRES_DATABASE + - POSTGRES_USER + - POSTGRES_PASSWORD + +Auth behavior: +* docker/local: password auth (POSTGRES_PASSWORD required) +* azure: Entra token auth via DefaultAzureCredential (password ignored) + +Behavior: +* Only try the configured POSTGRES_MODE (no fallback). -This module provides a centralized way to connect to Azure Database for PostgreSQL -using Entra ID authentication, with fallback to connection string for local development. +Notes: +* POSTGRES_URI is deprecated and intentionally not used. """ -from typing import Optional +import os +from typing import Optional, Dict, Any import psycopg2 from ..core.config import settings import logging import datetime -from azure.identity import DefaultAzureCredential + +try: + from azure.identity import DefaultAzureCredential +except Exception: # pragma: no cover + DefaultAzureCredential = None # type: ignore logger = logging.getLogger(__name__) @@ -24,7 +45,7 @@ class PostgresServer: def __init__(self): self._verify_config() - self._credential = DefaultAzureCredential() if settings.AZURE_DB else None + self._credential = DefaultAzureCredential() if self._mode() == "azure" else None self._token: Optional[str] = None self._token_expiration: int = 0 self._conn = None @@ -33,14 +54,11 @@ def __init__(self): def conn(self): """Return an open connection, reconnecting only when necessary.""" if self._conn is None or self._conn.closed: - print("local database") self._conn = self._connect() - elif settings.AZURE_DB and self._is_token_expired(): + elif self._mode() == "azure" and self._is_token_expired(): logger.info("Azure token expired — reconnecting to PostgreSQL") - print("cloud database") self.close() self._conn = self._connect() - print(self._conn) return self._conn def close(self): @@ -55,11 +73,22 @@ def close(self): @staticmethod def _verify_config(): """Validate that all required PostgreSQL settings are present.""" - required = [settings.POSTGRES_HOST, settings.POSTGRES_DATABASE, settings.POSTGRES_USER] + mode = (settings.POSTGRES_MODE or "").lower().strip() + if mode not in {"local", "docker", "azure"}: + raise RuntimeError("POSTGRES_MODE must be one of: local, docker, azure") + + # Validate selected profile minimally; the rest is validated when building kwargs. + try: + prof = settings.postgres_profile(mode) + except Exception as e: + raise RuntimeError(str(e)) + + required = [prof.get("host"), prof.get("database"), prof.get("user")] if not all(required): - raise RuntimeError("POSTGRES_HOST, POSTGRES_DATABASE, and POSTGRES_USER are required") - if not settings.AZURE_DB and not settings.POSTGRES_PASSWORD: - raise RuntimeError("POSTGRES_PASSWORD is required when AZURE_DB is False") + raise RuntimeError(f"{mode} profile requires host, database and user") + + if mode in {"docker", "local"} and not prof.get("password"): + raise RuntimeError(f"{mode} mode requires POSTGRES_PASSWORD") def _is_token_expired(self) -> bool: """Check whether the cached Azure token needs refreshing.""" @@ -70,36 +99,76 @@ def _refresh_azure_token(self) -> str: """Return a valid Azure token, fetching a new one only if expired.""" if self._is_token_expired(): logger.info("Fetching fresh Azure PostgreSQL token") + if not self._credential: + raise RuntimeError( + "Azure credential is not configured. 
Ensure POSTGRES_MODE=azure and that " + "DefaultAzureCredential can authenticate in this environment." + ) token = self._credential.get_token(self._AZURE_POSTGRES_SCOPE) self._token = token.token self._token_expiration = token.expires_on - self._TOKEN_REFRESH_BUFFER_SECONDS return self._token - def _build_connect_kwargs(self) -> dict: - """Assemble psycopg2.connect() keyword arguments from settings.""" - kwargs = { - "host": settings.POSTGRES_HOST, - "database": settings.POSTGRES_DATABASE, - "user": settings.POSTGRES_USER, - "port": settings.POSTGRES_PORT, + @staticmethod + def _mode() -> str: + return (settings.POSTGRES_MODE or "docker").lower().strip() + + @staticmethod + def _has_local_fallback() -> bool: + return False + + def _candidate_kwargs(self, mode: str) -> Dict[str, Any]: + """Build connect kwargs for a given mode based on POSTGRES_* env vars.""" + prof = settings.postgres_profile(mode) + + kwargs: Dict[str, Any] = { + "host": prof.get("host"), + "database": prof.get("database"), + "user": prof.get("user"), + "port": int(prof.get("port") or 5432), + # Fail fast so connection errors surface quickly + "connect_timeout": int(os.getenv("POSTGRES_CONNECT_TIMEOUT", "3")), } - if settings.POSTGRES_SSL_MODE: - kwargs["sslmode"] = settings.POSTGRES_SSL_MODE - if settings.AZURE_DB: + + sslmode = prof.get("sslmode") + if sslmode: + kwargs["sslmode"] = sslmode + + if prof.get("mode") == "azure": kwargs["password"] = self._refresh_azure_token() - elif settings.POSTGRES_PASSWORD: - kwargs["password"] = settings.POSTGRES_PASSWORD + else: + if not prof.get("password"): + raise RuntimeError(f"{mode} profile requires password") + kwargs["password"] = prof.get("password") + + # Sanity checks + required = [kwargs.get("host"), kwargs.get("database"), kwargs.get("user"), kwargs.get("port")] + if not all(required): + raise RuntimeError(f"Incomplete Postgres config for mode={mode}") + return kwargs + def _connect_with_mode(self, mode: str): + kwargs = self._candidate_kwargs(mode) + safe_kwargs = {k: ("***" if k == "password" else v) for k, v in kwargs.items()} + logger.info("Connecting to Postgres (mode=%s) %s", mode, safe_kwargs) + return psycopg2.connect(**kwargs) + def _connect(self): """Create a new psycopg2 connection.""" - return psycopg2.connect(**self._build_connect_kwargs()) + primary_mode = self._mode() + try: + return self._connect_with_mode(primary_mode) + except Exception as e: + logger.error("Postgres connect failed (mode=%s): %s", primary_mode, e, exc_info=True) + raise psycopg2.OperationalError( + f"Could not connect to Postgres for mode={primary_mode}" + ) def __repr__(self) -> str: status = "open" if self._conn and not self._conn.closed else "closed" return ( - f"" + f"" ) postgres_server = PostgresServer() \ No newline at end of file diff --git a/backend/api/services/sr_db_service.py b/backend/api/services/sr_db_service.py index 1dfaedce..4b4ccf14 100644 --- a/backend/api/services/sr_db_service.py +++ b/backend/api/services/sr_db_service.py @@ -27,7 +27,7 @@ def __init__(self): # Service is stateless; connection strings passed per-call pass - def ensure_table_exists(self, db_conn_str: str) -> None: + def ensure_table_exists(self) -> None: """ Ensure the systematic_reviews table exists in PostgreSQL. Creates the table if it doesn't exist. 
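# --- Illustrative usage sketch (editorial note, not part of the patch) ----
# Shows how a caller might use the refactored PostgresServer singleton above.
# Assumptions: the absolute import path below mirrors the file location
# backend/api/services/postgres_auth.py, and the environment provides
# POSTGRES_MODE plus POSTGRES_HOST / POSTGRES_DATABASE / POSTGRES_USER
# (and POSTGRES_PASSWORD for docker/local; azure mode swaps in an Entra token).

from api.services.postgres_auth import postgres_server  # assumed import path


def count_systematic_reviews() -> int:
    # .conn transparently reconnects when the connection is closed or, in
    # azure mode, when the cached Entra access token has expired.
    with postgres_server.conn.cursor() as cur:
        cur.execute("SELECT COUNT(*) FROM systematic_reviews")
        return int(cur.fetchone()[0])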
@@ -57,8 +57,6 @@ def ensure_table_exists(self, db_conn_str: str) -> None: """ cur.execute(create_table_sql) conn.commit() - - logger.info("Ensured systematic_reviews table exists") except Exception as e: @@ -157,7 +155,6 @@ def build_criteria_parsed(self, criteria_obj: Optional[Dict[str, Any]]) -> Dict[ def create_systematic_review( self, - db_conn_str: str, name: str, description: Optional[str], criteria_str: Optional[str], @@ -236,7 +233,7 @@ def create_systematic_review( if conn: pass - def add_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]: + def add_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]: """ Add a user id to the SR's users list. Enforces that the SR exists and is visible; requester must be a member or owner. @@ -244,12 +241,12 @@ def add_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_ """ - sr = self.get_systematic_review(db_conn_str, sr_id) + sr = self.get_systematic_review(sr_id) if not sr or not sr.get("visible", True): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") # Check permission - has_perm = self.user_has_sr_permission(db_conn_str, sr_id, requester_id) + has_perm = self.user_has_sr_permission(sr_id, requester_id) if not has_perm: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review") @@ -294,19 +291,19 @@ def add_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_ if conn: pass - def remove_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]: + def remove_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]: """ Remove a user id from the SR's users list. Owner cannot be removed. Enforces requester permissions (must be a member or owner). """ - sr = self.get_systematic_review(db_conn_str, sr_id) + sr = self.get_systematic_review(sr_id) if not sr or not sr.get("visible", True): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") # Check permission - has_perm = self.user_has_sr_permission(db_conn_str, sr_id, requester_id) + has_perm = self.user_has_sr_permission(sr_id, requester_id) if not has_perm: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review") @@ -354,7 +351,7 @@ def remove_user(self, db_conn_str: str, sr_id: str, target_user_id: str, request if conn: pass - def user_has_sr_permission(self, db_conn_str: str, sr_id: str, user_id: str) -> bool: + def user_has_sr_permission(self, sr_id: str, user_id: str) -> bool: """ Check whether the given user_id is a member (in 'users') or the owner of the SR. Returns True if the SR exists and the user is present in the SR's users list or is the owner. 
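# --- Illustrative sketch (editorial note, not part of the patch) ----------
# With db_conn_str removed from the SR database service methods, callers pass
# only domain arguments. A minimal permission-gate helper in the same style as
# the add_user / remove_user / update_criteria checks above; the function name
# is hypothetical.

from fastapi import HTTPException, status


def require_sr_member(srdb_service, sr_id: str, user_id: str) -> dict:
    sr = srdb_service.get_systematic_review(sr_id)
    if not sr or not sr.get("visible", True):
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
                            detail="Systematic review not found")
    if not srdb_service.user_has_sr_permission(sr_id, user_id):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
                            detail="Not authorized to modify this systematic review")
    return sr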
@@ -363,7 +360,7 @@ def user_has_sr_permission(self, db_conn_str: str, sr_id: str, user_id: str) -> """ - doc = self.get_systematic_review(db_conn_str, sr_id, ignore_visibility=True) + doc = self.get_systematic_review(sr_id, ignore_visibility=True) if not doc: return False @@ -372,7 +369,7 @@ def user_has_sr_permission(self, db_conn_str: str, sr_id: str, user_id: str) -> return True return False - def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str, Any], criteria_str: str, requester_id: str) -> Dict[str, Any]: + def update_criteria(self, sr_id: str, criteria_obj: Dict[str, Any], criteria_str: str, requester_id: str) -> Dict[str, Any]: """ Update the criteria fields (criteria, criteria_yaml, criteria_parsed, updated_at). The requester must be a member or owner. @@ -380,12 +377,12 @@ def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str, """ - sr = self.get_systematic_review(db_conn_str, sr_id) + sr = self.get_systematic_review(sr_id) if not sr or not sr.get("visible", True): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") # Check permission - has_perm = self.user_has_sr_permission(db_conn_str, sr_id, requester_id) + has_perm = self.user_has_sr_permission(sr_id, requester_id) if not has_perm: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review") @@ -419,7 +416,7 @@ def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str, # Return fresh doc - doc = self.get_systematic_review(db_conn_str, sr_id) + doc = self.get_systematic_review(sr_id) return doc except HTTPException: @@ -431,7 +428,7 @@ def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str, if conn: pass - def list_systematic_reviews_for_user(self, db_conn_str: str, user_email: str) -> List[Dict[str, Any]]: + def list_systematic_reviews_for_user(self, user_email: str) -> List[Dict[str, Any]]: """ Return all SR documents where the user is a member (regardless of visible flag). """ @@ -484,7 +481,7 @@ def list_systematic_reviews_for_user(self, db_conn_str: str, user_email: str) -> if conn: pass - def get_systematic_review(self, db_conn_str: str, sr_id: str, ignore_visibility: bool = False) -> Optional[Dict[str, Any]]: + def get_systematic_review(self, sr_id: str, ignore_visibility: bool = False) -> Optional[Dict[str, Any]]: """ Return SR document by id. Returns None if not found. If ignore_visibility is False, only returns visible SRs. @@ -535,14 +532,14 @@ def get_systematic_review(self, db_conn_str: str, sr_id: str, ignore_visibility: if conn: pass - def set_visibility(self, db_conn_str: str, sr_id: str, visible: bool, requester_id: str) -> Dict[str, Any]: + def set_visibility(self, sr_id: str, visible: bool, requester_id: str) -> Dict[str, Any]: """ Set the visible flag on the SR. Only owner is allowed to change visibility. Returns update metadata. 
""" - sr = self.get_systematic_review(db_conn_str, sr_id, ignore_visibility=True) + sr = self.get_systematic_review(sr_id, ignore_visibility=True) if not sr: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") @@ -573,26 +570,26 @@ def set_visibility(self, db_conn_str: str, sr_id: str, visible: bool, requester_ if conn: pass - def soft_delete_systematic_review(self, db_conn_str: str, sr_id: str, requester_id: str) -> Dict[str, Any]: + def soft_delete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]: """ Soft-delete (set visible=False). Only owner may delete. """ - return self.set_visibility(db_conn_str, sr_id, False, requester_id) + return self.set_visibility(sr_id, False, requester_id) - def undelete_systematic_review(self, db_conn_str: str, sr_id: str, requester_id: str) -> Dict[str, Any]: + def undelete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]: """ Undelete (set visible=True). Only owner may undelete. """ - return self.set_visibility(db_conn_str, sr_id, True, requester_id) + return self.set_visibility(sr_id, True, requester_id) - def hard_delete_systematic_review(self, db_conn_str: str, sr_id: str, requester_id: str) -> Dict[str, Any]: + def hard_delete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]: """ Permanently remove the SR document. Only owner may hard delete. Returns deletion metadata (deleted_count). """ - sr = self.get_systematic_review(db_conn_str, sr_id, ignore_visibility=True) + sr = self.get_systematic_review(sr_id, ignore_visibility=True) if not sr: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") @@ -618,7 +615,7 @@ def hard_delete_systematic_review(self, db_conn_str: str, sr_id: str, requester_ pass - def update_screening_db_info(self, db_conn_str: str, sr_id: str, screening_db: Dict[str, Any]) -> None: + def update_screening_db_info(self, sr_id: str, screening_db: Dict[str, Any]) -> None: """ Update the screening_db field in the SR document with screening database metadata. """ @@ -645,7 +642,7 @@ def update_screening_db_info(self, db_conn_str: str, sr_id: str, screening_db: D if conn: pass - def clear_screening_db_info(self, db_conn_str: str, sr_id: str) -> None: + def clear_screening_db_info(self, sr_id: str) -> None: """ Remove the screening_db field from the SR document. """ diff --git a/backend/api/services/storage.py b/backend/api/services/storage.py index b7595416..b5d4f2b5 100644 --- a/backend/api/services/storage.py +++ b/backend/api/services/storage.py @@ -1,53 +1,96 @@ -"""Azure Blob Storage service for user data management""" +"""backend.api.services.storage + +Storage abstraction for CAN-SR. + +Supported backends (selected via STORAGE_MODE): +* local - Local filesystem storage (backed by docker compose volume) +* azure - Azure Blob Storage via **account name + key** (strict) +* entra - Azure Blob Storage via **DefaultAzureCredential** (Entra/Managed Identity) (strict) + +Routers should not access Azure SDK objects directly. 
+""" + +from __future__ import annotations import json import logging +import os import uuid from datetime import datetime, timezone -from typing import Dict, List, Optional, Any - -from azure.identity import DefaultAzureCredential -from azure.storage.blob import BlobServiceClient -from azure.core.exceptions import ResourceNotFoundError +from pathlib import Path +from typing import Any, Dict, List, Optional, Protocol, Tuple + +try: + from azure.core.exceptions import ResourceNotFoundError + from azure.identity import DefaultAzureCredential + from azure.storage.blob import BlobServiceClient +except Exception: # pragma: no cover + # Allow local-storage deployments/environments to import without azure packages. + ResourceNotFoundError = Exception # type: ignore + DefaultAzureCredential = None # type: ignore + BlobServiceClient = None # type: ignore from ..core.config import settings -from ..utils.file_hash import calculate_file_hash, create_file_metadata +from ..utils.file_hash import create_file_metadata logger = logging.getLogger(__name__) +class StorageService(Protocol): + """Common API that both Azure and local storage must implement.""" + + container_name: str + + async def create_user_directory(self, user_id: str) -> bool: ... + async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool: ... + async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]: ... + async def upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]: ... + async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]: ... + async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: ... + async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool: ... + async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool: ... + async def get_bytes_by_path(self, path: str) -> Tuple[bytes, str]: ... + async def delete_by_path(self, path: str) -> bool: ... + + +# ============================================================================= +# Azure Blob Storage +# ============================================================================= + + class AzureStorageService: - """Service for managing user data in Azure Blob Storage""" + """Service for managing user data in Azure Blob Storage.""" - def __init__(self): - if not settings.AZURE_STORAGE_ACCOUNT_NAME and not settings.AZURE_STORAGE_CONNECTION_STRING: - raise ValueError("AZURE_STORAGE_ACCOUNT_NAME or AZURE_STORAGE_CONNECTION_STRING must be configured") + def __init__(self, *, account_url: str | None = None, connection_string: str | None = None, container_name: str): + if not BlobServiceClient: + raise RuntimeError( + "Azure storage libraries are not installed. Install azure-identity and azure-storage-blob, or use STORAGE_MODE=local." + ) - if settings.AZURE_STORAGE_ACCOUNT_NAME: - account_url = f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net" + if bool(account_url) == bool(connection_string): + raise ValueError("Exactly one of account_url or connection_string must be provided") + + if connection_string: + self.blob_service_client = BlobServiceClient.from_connection_string(connection_string) + else: + if not DefaultAzureCredential: + raise RuntimeError( + "azure-identity is not installed. Install azure-identity, or use STORAGE_MODE=azure (connection string) or local." 
+ ) credential = DefaultAzureCredential() - self.blob_service_client = BlobServiceClient( - account_url=account_url, credential=credential - ) - elif settings.AZURE_STORAGE_CONNECTION_STRING: - self.blob_service_client = BlobServiceClient.from_connection_string( - settings.AZURE_STORAGE_CONNECTION_STRING - ) - self.container_name = settings.AZURE_STORAGE_CONTAINER_NAME + self.blob_service_client = BlobServiceClient(account_url=account_url, credential=credential) + self.container_name = container_name self._ensure_container_exists() def _ensure_container_exists(self): - """Ensure the storage container exists""" try: self.blob_service_client.create_container(self.container_name) except Exception: pass async def create_user_directory(self, user_id: str) -> bool: - """Create directory structure for a new user""" try: - # Create user profile profile_data = { "user_id": user_id, "created_at": datetime.now(timezone.utc).isoformat(), @@ -59,70 +102,43 @@ async def create_user_directory(self, user_id: str) -> bool: await self.save_user_profile(user_id, profile_data) # Create placeholder file to establish directory structure - directories = [f"users/{user_id}/documents/"] - - for directory in directories: - blob_name = f"{directory}.placeholder" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - blob_client.upload_blob("", overwrite=True) - + blob_name = f"users/{user_id}/documents/.placeholder" + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client.upload_blob(b"", overwrite=True) return True except Exception as e: - logger.error(f"Error creating user directory for {user_id}: {e}") + logger.error("Error creating user directory for %s: %s", user_id, e) return False - async def save_user_profile( - self, user_id: str, profile_data: Dict[str, Any] - ) -> bool: - """Save user profile data""" + async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool: try: blob_name = f"users/{user_id}/profile.json" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - - profile_json = json.dumps(profile_data, indent=2) - blob_client.upload_blob(profile_json, overwrite=True) + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client.upload_blob(json.dumps(profile_data, indent=2), overwrite=True) return True except Exception as e: - logger.error(f"Error saving user profile for {user_id}: {e}") + logger.error("Error saving user profile for %s: %s", user_id, e) return False async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]: - """Get user profile data""" try: blob_name = f"users/{user_id}/profile.json" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) blob_data = blob_client.download_blob().readall() return json.loads(blob_data.decode("utf-8")) except ResourceNotFoundError: return None except Exception as e: - logger.error(f"Error getting user profile for {user_id}: {e}") + logger.error("Error getting user profile for %s: %s", user_id, e) return None - async def upload_user_document( - self, user_id: str, filename: str, file_content: bytes - ) -> Optional[str]: - """Upload a document for a user with hash metadata for duplicate detection""" + async def 
upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]: try: - # Generate unique document ID doc_id = str(uuid.uuid4()) blob_name = f"users/{user_id}/documents/{doc_id}_{filename}" - - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - - # Upload the file + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) blob_client.upload_blob(file_content, overwrite=True) - # Create and save file metadata with hash file_metadata = create_file_metadata( filename, file_content, @@ -132,235 +148,447 @@ async def upload_user_document( "upload_date": datetime.now(timezone.utc).isoformat(), }, ) - await self.save_file_hash_metadata(user_id, doc_id, file_metadata) - # Update user profile profile = await self.get_user_profile(user_id) if profile: - profile["document_count"] += 1 - profile["storage_used"] += len(file_content) + profile["document_count"] = int(profile.get("document_count", 0)) + 1 + profile["storage_used"] = int(profile.get("storage_used", 0)) + len(file_content) profile["last_updated"] = datetime.now(timezone.utc).isoformat() await self.save_user_profile(user_id, profile) return doc_id except Exception as e: - logger.error(f"Error uploading document {filename} for user {user_id}: {e}") + logger.error("Error uploading document %s for user %s: %s", filename, user_id, e) return None - async def get_user_document( - self, user_id: str, doc_id: str, filename: str - ) -> Optional[bytes]: - """Get a user's document""" + async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]: try: blob_name = f"users/{user_id}/documents/{doc_id}_{filename}" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) return blob_client.download_blob().readall() except ResourceNotFoundError: return None except Exception as e: - logger.error(f"Error getting document {doc_id} for user {user_id}: {e}") + logger.error("Error getting document %s for user %s: %s", doc_id, user_id, e) return None async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: - """List all documents for a user - SIMPLIFIED""" try: prefix = f"users/{user_id}/documents/" - blobs = self.blob_service_client.get_container_client( - self.container_name - ).list_blobs(name_starts_with=prefix) + blobs = self.blob_service_client.get_container_client(self.container_name).list_blobs(name_starts_with=prefix) - documents = [] + documents: List[Dict[str, Any]] = [] for blob in blobs: - if not blob.name.endswith(".placeholder"): - # Extract doc_id and filename from blob name - blob_filename = blob.name.replace(prefix, "") - if "_" in blob_filename: - doc_id, filename = blob_filename.split("_", 1) - - # Get hash metadata if available - hash_metadata = await self.get_file_hash_metadata( - user_id, doc_id - ) - - document_info = { - "document_id": doc_id, - "filename": filename, - "file_size": blob.size, - "upload_date": blob.last_modified.isoformat(), - "last_modified": blob.last_modified.isoformat(), - } - - # Add hash information if available - if hash_metadata: - document_info["file_hash"] = hash_metadata.get("file_hash") - document_info["signature"] = hash_metadata.get("signature") - - documents.append(document_info) + if blob.name.endswith(".placeholder"): + continue + blob_filename = 
blob.name.replace(prefix, "") + if "_" not in blob_filename: + continue + doc_id, filename = blob_filename.split("_", 1) + + hash_metadata = await self.get_file_hash_metadata(user_id, doc_id) + document_info: Dict[str, Any] = { + "document_id": doc_id, + "filename": filename, + "file_size": blob.size, + "upload_date": blob.last_modified.isoformat(), + "last_modified": blob.last_modified.isoformat(), + } + if hash_metadata: + document_info["file_hash"] = hash_metadata.get("file_hash") + document_info["signature"] = hash_metadata.get("signature") + documents.append(document_info) return documents except Exception as e: - logger.error(f"Error listing user documents for {user_id}: {e}") + logger.error("Error listing user documents for %s: %s", user_id, e) return [] - async def delete_user_document( - self, user_id: str, doc_id: str, filename: str - ) -> bool: - """Delete a user's document""" + async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool: try: - # Get document size before deletion for profile update doc_blob_name = f"users/{user_id}/documents/{doc_id}_{filename}" - doc_blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=doc_blob_name - ) + doc_blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=doc_blob_name) - # Get blob properties to determine size try: - blob_properties = doc_blob_client.get_blob_properties() - doc_size = blob_properties.size + doc_size = doc_blob_client.get_blob_properties().size except ResourceNotFoundError: doc_size = 0 - # Delete the document doc_blob_client.delete_blob() - - # Delete the associated hash metadata await self.delete_file_hash_metadata(user_id, doc_id) - # Update user profile profile = await self.get_user_profile(user_id) if profile: - profile["document_count"] = max(0, profile["document_count"] - 1) - profile["storage_used"] = max(0, profile["storage_used"] - doc_size) + profile["document_count"] = max(0, int(profile.get("document_count", 0)) - 1) + profile["storage_used"] = max(0, int(profile.get("storage_used", 0)) - int(doc_size)) profile["last_updated"] = datetime.now(timezone.utc).isoformat() await self.save_user_profile(user_id, profile) return True except Exception as e: - logger.error(f"Error deleting document {doc_id} for user {user_id}: {e}") + logger.error("Error deleting document %s for user %s: %s", doc_id, user_id, e) return False - async def calculate_user_storage_usage(self, user_id: str) -> int: - """Calculate actual storage usage for a user""" + async def save_file_hash_metadata(self, user_id: str, document_id: str, file_metadata: Dict[str, Any]) -> bool: try: - prefix = f"users/{user_id}/documents/" - blobs = self.blob_service_client.get_container_client( - self.container_name - ).list_blobs(name_starts_with=prefix) + blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client.upload_blob(json.dumps(file_metadata, indent=2), overwrite=True) + return True + except Exception as e: + logger.error("Error saving file metadata for %s: %s", document_id, e) + return False - total_size = 0 - for blob in blobs: - if not blob.name.endswith(".placeholder"): - total_size += blob.size or 0 + async def get_file_hash_metadata(self, user_id: str, document_id: str) -> Optional[Dict[str, Any]]: + try: + blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" + blob_client = 
self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + metadata_json = blob_client.download_blob().readall().decode("utf-8") + return json.loads(metadata_json) + except ResourceNotFoundError: + return None + except Exception as e: + logger.error("Error getting file metadata for %s: %s", document_id, e) + return None - return total_size + async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> bool: + try: + blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client.delete_blob() + return True + except ResourceNotFoundError: + return True except Exception as e: - logger.error(f"Error calculating storage usage for user {user_id}: {e}") - return 0 + logger.error("Error deleting file metadata for %s: %s", document_id, e) + return False + + async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool: + """Write blob by storage path 'container/blob'.""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob) + blob_client.upload_blob(content, overwrite=True, content_type=content_type) + return True + + async def get_bytes_by_path(self, path: str) -> Tuple[bytes, str]: + """Read blob by storage path 'container/blob'. Returns (bytes, filename).""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + + blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob) + content = blob_client.download_blob().readall() + filename = os.path.basename(blob) or "download" + return content, filename + + async def delete_by_path(self, path: str) -> bool: + """Delete blob by storage path 'container/blob'.""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob) + blob_client.delete_blob() + return True - async def sync_user_profile_stats(self, user_id: str) -> bool: - """Synchronize user profile statistics with actual storage""" + +# ============================================================================= +# Local filesystem storage +# ============================================================================= + + +class LocalStorageService: + """Local filesystem storage implementation. + + Layout: + {LOCAL_STORAGE_BASE_PATH}/{STORAGE_CONTAINER_NAME}/users/{user_id}/... 
+ """ + + def __init__(self): + self.base_path = Path(settings.LOCAL_STORAGE_BASE_PATH).resolve() + self.container_name = settings.STORAGE_CONTAINER_NAME + (self.base_path / self.container_name).mkdir(parents=True, exist_ok=True) + + def _container_root(self) -> Path: + return self.base_path / self.container_name + + def _user_root(self, user_id: str) -> Path: + return self._container_root() / "users" / str(user_id) + + def _profile_path(self, user_id: str) -> Path: + return self._user_root(user_id) / "profile.json" + + def _doc_path(self, user_id: str, doc_id: str, filename: str) -> Path: + return self._user_root(user_id) / "documents" / f"{doc_id}_{filename}" + + def _metadata_path(self, user_id: str, doc_id: str) -> Path: + return self._user_root(user_id) / "metadata" / f"{doc_id}_metadata.json" + + async def create_user_directory(self, user_id: str) -> bool: try: - documents = await self.list_user_documents(user_id) - actual_storage = await self.calculate_user_storage_usage(user_id) + (self._user_root(user_id) / "documents").mkdir(parents=True, exist_ok=True) + (self._user_root(user_id) / "metadata").mkdir(parents=True, exist_ok=True) - profile = await self.get_user_profile(user_id) - if profile: - profile["document_count"] = len(documents) - profile["storage_used"] = actual_storage - profile["last_updated"] = datetime.now(timezone.utc).isoformat() - await self.save_user_profile(user_id, profile) - return True - return False + if not self._profile_path(user_id).exists(): + profile_data = { + "user_id": user_id, + "created_at": datetime.now(timezone.utc).isoformat(), + "last_updated": datetime.now(timezone.utc).isoformat(), + "document_count": 0, + "storage_used": 0, + } + await self.save_user_profile(user_id, profile_data) + return True except Exception as e: - logger.error(f"Error syncing profile stats for user {user_id}: {e}") + logger.error("Error creating user directory for %s: %s", user_id, e) return False - async def save_file_hash_metadata( - self, user_id: str, document_id: str, file_metadata: Dict[str, Any] - ) -> bool: - """Save file hash metadata for duplicate detection""" + async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool: try: - blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - - metadata_json = json.dumps(file_metadata, indent=2) - blob_client.upload_blob(metadata_json, overwrite=True) + self._profile_path(user_id).parent.mkdir(parents=True, exist_ok=True) + self._profile_path(user_id).write_text(json.dumps(profile_data, indent=2), encoding="utf-8") return True except Exception as e: - logger.error(f"Error saving file metadata for {document_id}: {e}") + logger.error("Error saving user profile for %s: %s", user_id, e) return False - async def get_file_hash_metadata( - self, user_id: str, document_id: str - ) -> Optional[Dict[str, Any]]: - """Get file hash metadata for a specific document""" + async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]: + try: + p = self._profile_path(user_id) + if not p.exists(): + return None + return json.loads(p.read_text(encoding="utf-8")) + except Exception as e: + logger.error("Error getting user profile for %s: %s", user_id, e) + return None + + async def upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]: try: - blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" - blob_client = 
self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name + await self.create_user_directory(user_id) + doc_id = str(uuid.uuid4()) + doc_path = self._doc_path(user_id, doc_id, filename) + doc_path.parent.mkdir(parents=True, exist_ok=True) + doc_path.write_bytes(file_content) + + file_metadata = create_file_metadata( + filename, + file_content, + { + "document_id": doc_id, + "user_id": user_id, + "upload_date": datetime.now(timezone.utc).isoformat(), + }, ) + await self.save_file_hash_metadata(user_id, doc_id, file_metadata) - metadata_json = blob_client.download_blob().readall().decode("utf-8") - return json.loads(metadata_json) - except ResourceNotFoundError: - return None + profile = await self.get_user_profile(user_id) + if profile: + profile["document_count"] = int(profile.get("document_count", 0)) + 1 + profile["storage_used"] = int(profile.get("storage_used", 0)) + len(file_content) + profile["last_updated"] = datetime.now(timezone.utc).isoformat() + await self.save_user_profile(user_id, profile) + + return doc_id except Exception as e: - logger.error(f"Error getting file metadata for {document_id}: {e}") + logger.error("Error uploading document %s for user %s: %s", filename, user_id, e) return None - async def get_all_user_file_hashes(self, user_id: str) -> List[Dict[str, Any]]: - """Get all file hash metadata for a user for duplicate detection""" + async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]: try: - prefix = f"users/{user_id}/metadata/" - blobs = self.blob_service_client.get_container_client( - self.container_name - ).list_blobs(name_starts_with=prefix) + p = self._doc_path(user_id, doc_id, filename) + if not p.exists(): + return None + return p.read_bytes() + except Exception as e: + logger.error("Error getting document %s for user %s: %s", doc_id, user_id, e) + return None - all_metadata = [] - for blob in blobs: - if blob.name.endswith("_metadata.json"): - try: - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob.name - ) - metadata_json = ( - blob_client.download_blob().readall().decode("utf-8") - ) - metadata = json.loads(metadata_json) - all_metadata.append(metadata) - except Exception as e: - logger.warning(f"Error reading metadata from {blob.name}: {e}") - continue - - return all_metadata + async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: + try: + docs_dir = self._user_root(user_id) / "documents" + if not docs_dir.exists(): + return [] + + documents: List[Dict[str, Any]] = [] + for p in docs_dir.iterdir(): + if not p.is_file(): + continue + name = p.name + if "_" not in name: + continue + doc_id, filename = name.split("_", 1) + stat = p.stat() + + hash_metadata = await self.get_file_hash_metadata(user_id, doc_id) + doc_info: Dict[str, Any] = { + "document_id": doc_id, + "filename": filename, + "file_size": stat.st_size, + "upload_date": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(), + "last_modified": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(), + } + if hash_metadata: + doc_info["file_hash"] = hash_metadata.get("file_hash") + doc_info["signature"] = hash_metadata.get("signature") + documents.append(doc_info) + + # stable ordering + documents.sort(key=lambda d: d.get("upload_date", ""), reverse=True) + return documents except Exception as e: - logger.error(f"Error getting all file hashes for user {user_id}: {e}") + logger.error("Error listing user documents for %s: %s", 
user_id, e) return [] - async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> bool: - """Delete file hash metadata when document is deleted""" + async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool: try: - blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - blob_client.delete_blob() + p = self._doc_path(user_id, doc_id, filename) + doc_size = p.stat().st_size if p.exists() else 0 + if p.exists(): + p.unlink() + await self.delete_file_hash_metadata(user_id, doc_id) + + profile = await self.get_user_profile(user_id) + if profile: + profile["document_count"] = max(0, int(profile.get("document_count", 0)) - 1) + profile["storage_used"] = max(0, int(profile.get("storage_used", 0)) - int(doc_size)) + profile["last_updated"] = datetime.now(timezone.utc).isoformat() + await self.save_user_profile(user_id, profile) return True - except ResourceNotFoundError: - # Metadata doesn't exist, which is fine + except Exception as e: + logger.error("Error deleting document %s for user %s: %s", doc_id, user_id, e) + return False + + async def save_file_hash_metadata(self, user_id: str, document_id: str, file_metadata: Dict[str, Any]) -> bool: + try: + p = self._metadata_path(user_id, document_id) + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text(json.dumps(file_metadata, indent=2), encoding="utf-8") return True except Exception as e: - logger.error(f"Error deleting file metadata for {document_id}: {e}") + logger.error("Error saving file metadata for %s: %s", document_id, e) return False + async def get_file_hash_metadata(self, user_id: str, document_id: str) -> Optional[Dict[str, Any]]: + try: + p = self._metadata_path(user_id, document_id) + if not p.exists(): + return None + return json.loads(p.read_text(encoding="utf-8")) + except Exception as e: + logger.error("Error getting file metadata for %s: %s", document_id, e) + return None + + async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> bool: + try: + p = self._metadata_path(user_id, document_id) + if p.exists(): + p.unlink() + return True + except Exception as e: + logger.error("Error deleting file metadata for %s: %s", document_id, e) + return False + + async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool: + """Write file by storage path 'container/blob'.""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + + if container != self.container_name: + raise FileNotFoundError("Container not found") + + p = (self.base_path / container / blob).resolve() + # Prevent path traversal + if not str(p).startswith(str((self.base_path / container).resolve())): + raise FileNotFoundError("Invalid path") + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(content) + return True + + async def get_bytes_by_path(self, path: str) -> Tuple[bytes, str]: + """Read file by storage path 'container/blob'. Returns (bytes, filename).""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + + # Only allow access to our configured local container. 
+ if container != self.container_name: + raise FileNotFoundError("Container not found") + + p = (self.base_path / container / blob).resolve() + # Prevent path traversal + if not str(p).startswith(str((self.base_path / container).resolve())): + raise FileNotFoundError("Invalid path") + if not p.exists() or not p.is_file(): + raise FileNotFoundError("File not found") + + return p.read_bytes(), (p.name or "download") + + async def delete_by_path(self, path: str) -> bool: + """Delete file by storage path 'container/blob'.""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + + if container != self.container_name: + raise FileNotFoundError("Container not found") + + p = (self.base_path / container / blob).resolve() + if not str(p).startswith(str((self.base_path / container).resolve())): + raise FileNotFoundError("Invalid path") + if not p.exists() or not p.is_file(): + raise FileNotFoundError("File not found") + p.unlink() + return True + + +# ============================================================================= +# Factory +# ============================================================================= + + +def _build_storage_service() -> Optional[StorageService]: + stype = (settings.STORAGE_MODE or "azure").lower().strip() + if stype == "local": + try: + return LocalStorageService() + except Exception as e: + logger.exception("Failed to initialize LocalStorageService: %s", e) + return None + if stype == "azure": + try: + if not settings.AZURE_STORAGE_ACCOUNT_NAME or not settings.AZURE_STORAGE_ACCOUNT_KEY: + raise ValueError("STORAGE_MODE=azure requires AZURE_STORAGE_ACCOUNT_NAME and AZURE_STORAGE_ACCOUNT_KEY") + connection_string = ( + "DefaultEndpointsProtocol=https;" + f"AccountName={settings.AZURE_STORAGE_ACCOUNT_NAME};" + f"AccountKey={settings.AZURE_STORAGE_ACCOUNT_KEY};" + "EndpointSuffix=core.windows.net" + ) + return AzureStorageService( + connection_string=connection_string, + container_name=settings.STORAGE_CONTAINER_NAME, + ) + except Exception as e: + logger.exception("Failed to initialize AzureStorageService (connection string): %s", e) + return None + if stype == "entra": + try: + if not settings.AZURE_STORAGE_ACCOUNT_NAME: + raise ValueError("STORAGE_MODE=entra requires AZURE_STORAGE_ACCOUNT_NAME") + account_url = f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net" + return AzureStorageService( + account_url=account_url, + container_name=settings.STORAGE_CONTAINER_NAME, + ) + except Exception as e: + logger.exception("Failed to initialize AzureStorageService (Entra): %s", e) + return None + + logger.warning("Unsupported STORAGE_MODE=%s; storage disabled", stype) + return None + -# Global storage service instance -storage_service = ( - AzureStorageService() if settings.AZURE_STORAGE_ACCOUNT_NAME else None -) +storage_service: Optional[StorageService] = _build_storage_service() diff --git a/backend/api/services/user_db.py b/backend/api/services/user_db.py index 00300523..0624b084 100644 --- a/backend/api/services/user_db.py +++ b/backend/api/services/user_db.py @@ -1,102 +1,81 @@ -"""User database service using Azure Blob Storage""" +"""backend.api.services.user_db + +User database service. + +Historically this project stored users inside Azure Blob Storage directly. +To support multiple storage backends (local / azure / entra) we now build the +user DB on top of the common `storage_service` abstraction. 
+ +Storage keys used: + system/user_registry.json + +This file intentionally avoids importing Azure SDK packages so that local +deployments can run without them. +""" + +from __future__ import annotations import json import uuid from datetime import datetime -from typing import Dict, List, Optional, Any +from typing import Any, Dict, List, Optional -from azure.identity import DefaultAzureCredential -from azure.storage.blob import BlobServiceClient -from azure.core.exceptions import ResourceNotFoundError from passlib.context import CryptContext -from ..core.config import settings from ..models.auth import UserCreate, UserRead +from .storage import storage_service class UserDatabaseService: - """Service for managing user data in Azure Blob Storage""" + """Service for managing user data in the configured storage backend.""" def __init__(self): - if not settings.AZURE_STORAGE_ACCOUNT_NAME and not settings.AZURE_STORAGE_CONNECTION_STRING: - raise ValueError("AZURE_STORAGE_ACCOUNT_NAME or AZURE_STORAGE_CONNECTION_STRING must be configured") - - if settings.AZURE_STORAGE_ACCOUNT_NAME: - account_url = f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net" - credential = DefaultAzureCredential() - self.blob_service_client = BlobServiceClient( - account_url=account_url, credential=credential - ) - elif settings.AZURE_STORAGE_CONNECTION_STRING: - self.blob_service_client = BlobServiceClient.from_connection_string( - settings.AZURE_STORAGE_CONNECTION_STRING + if not storage_service: + raise RuntimeError( + "Storage is not configured. User database is unavailable." ) - - self.container_name = settings.AZURE_STORAGE_CONTAINER_NAME + self.storage = storage_service self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - self._ensure_container_exists() - - def _ensure_container_exists(self): - """Ensure the storage container exists""" - try: - self.blob_service_client.create_container(self.container_name) - except Exception: - pass - def _get_password_hash(self, password: str) -> str: - """Get password hash""" return self.pwd_context.hash(password) def _verify_password(self, plain_password: str, hashed_password: str) -> bool: - """Verify a password against a hash""" return self.pwd_context.verify(plain_password, hashed_password) + def _registry_path(self) -> str: + # Keep legacy layout: /system/user_registry.json + return f"{self.storage.container_name}/system/user_registry.json" + async def _load_user_registry(self) -> Dict[str, Any]: - """Load the user registry from blob storage""" + """Load the user registry from storage.""" try: - blob_name = "system/user_registry.json" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - - blob_data = blob_client.download_blob().readall() - return json.loads(blob_data.decode("utf-8")) - except ResourceNotFoundError: - # Create empty registry if it doesn't exist - return {"users": {}, "email_index": {}} - except Exception as e: - print(f"Error loading user registry: {e}") + content, _filename = await self.storage.get_bytes_by_path(self._registry_path()) + return json.loads(content.decode("utf-8")) + except Exception: + # Create empty registry if it doesn't exist / cannot be read return {"users": {}, "email_index": {}} async def _save_user_registry(self, registry: Dict[str, Any]) -> bool: - """Save the user registry to blob storage""" + """Save the user registry to storage.""" try: - blob_name = "system/user_registry.json" - blob_client = self.blob_service_client.get_blob_client( 
- container=self.container_name, blob=blob_name + payload = json.dumps(registry, indent=2).encode("utf-8") + return await self.storage.put_bytes_by_path( + self._registry_path(), + payload, + content_type="application/json", ) - - registry_json = json.dumps(registry, indent=2) - blob_client.upload_blob(registry_json, overwrite=True) - return True - except Exception as e: - print(f"Error saving user registry: {e}") + except Exception: return False async def create_user(self, user_data: UserCreate) -> Optional[UserRead]: - """Create a new user""" try: registry = await self._load_user_registry() - # Check if user already exists if user_data.email in registry["email_index"]: - return None # User already exists + return None - # Generate unique user ID user_id = str(uuid.uuid4()) - - # Create user record user_record = { "id": user_id, "email": user_data.email, @@ -109,89 +88,59 @@ async def create_user(self, user_data: UserCreate) -> Optional[UserRead]: "last_login": None, } - # Add to registry registry["users"][user_id] = user_record registry["email_index"][user_data.email] = user_id - # Save registry - if await self._save_user_registry(registry): - # Create user directory structure in storage - from .storage import storage_service - # from .dual_milvus_manager import dual_milvus_manager - # from .base_knowledge_manager import base_knowledge_manager - - if storage_service: - await storage_service.create_user_directory(user_id) - - # User collections will be created on first login via auth router - - # Note: Base knowledge is initialized once at startup, not per user + if not await self._save_user_registry(registry): + return None - return UserRead( - id=user_id, - email=user_record["email"], - full_name=user_record["full_name"], - is_active=user_record["is_active"], - is_superuser=user_record["is_superuser"], - created_at=user_record["created_at"], - ) + await self.storage.create_user_directory(user_id) - return None - except Exception as e: - print(f"Error creating user: {e}") + return UserRead( + id=user_id, + email=user_record["email"], + full_name=user_record["full_name"], + is_active=user_record["is_active"], + is_superuser=user_record["is_superuser"], + created_at=user_record["created_at"], + ) + except Exception: return None async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]: - """Get a user by email""" try: registry = await self._load_user_registry() - if email not in registry["email_index"]: return None - user_id = registry["email_index"][email] return registry["users"].get(user_id) - except Exception as e: - print(f"Error getting user by email: {e}") + except Exception: return None async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: - """Get a user by ID""" try: registry = await self._load_user_registry() return registry["users"].get(user_id) - except Exception as e: - print(f"Error getting user by ID: {e}") + except Exception: return None - async def authenticate_user( - self, email: str, password: str, sso: bool - ) -> Optional[Dict[str, Any]]: - """Authenticate a user""" + async def authenticate_user(self, email: str, password: str, sso: bool) -> Optional[Dict[str, Any]]: try: user = await self.get_user_by_email(email) if not user: return None - if not sso and not self._verify_password(password, user["hashed_password"]): return None - # Update last login user["last_login"] = datetime.utcnow().isoformat() await self.update_user(user["id"], {"last_login": user["last_login"]}) - return user - except Exception as e: - print(f"Error 
authenticating user: {e}") + except Exception: return None - async def update_user( - self, user_id: str, update_data: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: - """Update user data""" + async def update_user(self, user_id: str, update_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: try: registry = await self._load_user_registry() - if user_id not in registry["users"]: return None @@ -201,61 +150,35 @@ async def update_user( if await self._save_user_registry(registry): return user_record - return None - except Exception as e: - print(f"Error updating user: {e}") + except Exception: return None async def deactivate_user(self, user_id: str) -> bool: - """Deactivate a user""" - try: - result = await self.update_user(user_id, {"is_active": False}) - return result is not None - except Exception as e: - print(f"Error deactivating user: {e}") - return False + result = await self.update_user(user_id, {"is_active": False}) + return result is not None async def list_users(self, skip: int = 0, limit: int = 100) -> List[Dict[str, Any]]: - """List all users (for admin purposes)""" - try: - registry = await self._load_user_registry() - users = list(registry["users"].values()) - - users.sort(key=lambda x: x.get("created_at", ""), reverse=True) - - return users[skip : skip + limit] - except Exception as e: - print(f"Error listing users: {e}") - return [] + registry = await self._load_user_registry() + users = list(registry.get("users", {}).values()) + users.sort(key=lambda x: x.get("created_at", ""), reverse=True) + return users[skip : skip + limit] async def get_all_users(self) -> List[Dict[str, Any]]: - """Get all users (for admin purposes)""" - try: - registry = await self._load_user_registry() - users = list(registry["users"].values()) - - users.sort(key=lambda x: x.get("created_at", ""), reverse=True) - - return users - except Exception as e: - print(f"Error getting all users: {e}") - return [] + registry = await self._load_user_registry() + users = list(registry.get("users", {}).values()) + users.sort(key=lambda x: x.get("created_at", ""), reverse=True) + return users async def get_user_count(self) -> int: - """Get total number of users""" - try: - registry = await self._load_user_registry() - return len(registry["users"]) - except Exception as e: - print(f"Error getting user count: {e}") - return 0 + registry = await self._load_user_registry() + return len(registry.get("users", {})) -# Global user database service instance -user_db_service = ( - UserDatabaseService() if settings.AZURE_STORAGE_ACCOUNT_NAME else None -) +# Global instance +try: + user_db_service: Optional[UserDatabaseService] = UserDatabaseService() +except Exception: + user_db_service = None -# Alias for backward compatibility user_db = user_db_service diff --git a/backend/api/sr/router.py b/backend/api/sr/router.py index 1db2b27e..e36b095d 100644 --- a/backend/api/sr/router.py +++ b/backend/api/sr/router.py @@ -22,34 +22,12 @@ from ..core.config import settings from ..core.security import get_current_active_user -from ..services.user_db import user_db as user_db_service +from ..services.user_db import user_db_service from ..services.sr_db_service import srdb_service from ..core.cit_utils import load_sr_and_check router = APIRouter() -# Helper to get database connection string -def _get_db_conn_str() -> Optional[str]: - """ - Get database connection string for PostgreSQL. - - If POSTGRES_URI is set, returns it directly (local development). 
- If Entra ID env variables are configured (POSTGRES_HOST, POSTGRES_DATABASE, POSTGRES_USER), - returns None to signal that connect_postgres() should use Entra ID authentication. - """ - if settings.POSTGRES_URI: - return settings.POSTGRES_URI - - # If Entra ID config is available, return None to let connect_postgres use token auth - if settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER: - return None - - # No configuration available - raise ValueError( - "PostgreSQL not configured. Set POSTGRES_URI for local development, " - "or POSTGRES_HOST, POSTGRES_DATABASE, and POSTGRES_USER for Entra ID authentication." - ) - class SystematicReviewCreate(BaseModel): name: str description: Optional[str] = None @@ -98,7 +76,6 @@ async def create_systematic_review( One of criteria_file or criteria_yaml may be provided. If both are provided, criteria_file takes precedence. The created SR is stored in PostgreSQL and the creating user is added as the first member. """ - db_conn_str = _get_db_conn_str() if not name or not name.strip(): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="name is required") @@ -134,7 +111,6 @@ async def create_systematic_review( try: sr_doc = await run_in_threadpool( srdb_service.create_systematic_review, - db_conn_str, name, description, criteria_str, @@ -187,9 +163,8 @@ async def add_user_to_systematic_review( The endpoint checks that the requester is a member of the SR. """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -203,7 +178,7 @@ async def add_user_to_systematic_review( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing data user_email") try: - res = await run_in_threadpool(srdb_service.add_user, db_conn_str, sr_id, target_user_id, current_user.get("id")) + res = await run_in_threadpool(srdb_service.add_user, sr_id, target_user_id, current_user.get("id")) except HTTPException: raise except Exception as e: @@ -225,10 +200,8 @@ async def remove_user_from_systematic_review( The endpoint checks that the requester is a member of the SR (or owner). The owner cannot be removed via this endpoint. """ - - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -246,7 +219,7 @@ async def remove_user_from_systematic_review( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot remove the owner from the systematic review") try: - res = await run_in_threadpool(srdb_service.remove_user, db_conn_str, sr_id, target_user_id, current_user.get("id")) + res = await run_in_threadpool(srdb_service.remove_user, sr_id, target_user_id, current_user.get("id")) except HTTPException: raise except Exception as e: @@ -263,12 +236,11 @@ async def list_systematic_reviews_for_user( List all systematic reviews the current user has access to (is a member of). Hidden/deleted SRs (visible == False) are excluded. 
""" - db_conn_str = _get_db_conn_str() user_id = current_user.get("email") results = [] try: - docs = await run_in_threadpool(srdb_service.list_systematic_reviews_for_user, db_conn_str, user_id) + docs = await run_in_threadpool(srdb_service.list_systematic_reviews_for_user, user_id) except HTTPException: raise except Exception as e: @@ -301,9 +273,8 @@ async def get_systematic_review(sr_id: str, current_user: Dict[str, Any] = Depen Get a single systematic review by id. User must be a member to view. """ - db_conn_str = _get_db_conn_str() try: - doc, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + doc, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -336,9 +307,8 @@ async def get_systematic_review_criteria_parsed( Returns an empty dict if no parsed criteria are available. """ - db_conn_str = _get_db_conn_str() try: - doc, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + doc, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -363,9 +333,8 @@ async def update_systematic_review_criteria( The parsed criteria (dict) and the raw YAML are both saved to the SR document. """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -402,7 +371,7 @@ async def update_systematic_review_criteria( # perform update try: - doc = await run_in_threadpool(srdb_service.update_criteria, db_conn_str, sr_id, criteria_obj, criteria_str, current_user.get("id")) + doc = await run_in_threadpool(srdb_service.update_criteria, sr_id, criteria_obj, criteria_str, current_user.get("id")) except HTTPException: raise except Exception as e: @@ -432,9 +401,8 @@ async def delete_systematic_review(sr_id: str, current_user: Dict[str, Any] = De Only the owner may delete a systematic review. """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -445,7 +413,7 @@ async def delete_systematic_review(sr_id: str, current_user: Dict[str, Any] = De raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may delete this systematic review") try: - res = await run_in_threadpool(srdb_service.soft_delete_systematic_review, db_conn_str, sr_id, requester_id) + res = await run_in_threadpool(srdb_service.soft_delete_systematic_review, sr_id, requester_id) except HTTPException: raise except Exception as e: @@ -462,9 +430,8 @@ async def undelete_systematic_review(sr_id: str, current_user: Dict[str, Any] = Only the owner may undelete a systematic review. 
""" - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False, require_visible=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False, require_visible=False) except HTTPException: raise except Exception as e: @@ -475,7 +442,7 @@ async def undelete_systematic_review(sr_id: str, current_user: Dict[str, Any] = raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may undelete this systematic review") try: - res = await run_in_threadpool(srdb_service.undelete_systematic_review, db_conn_str, sr_id, requester_id) + res = await run_in_threadpool(srdb_service.undelete_systematic_review, sr_id, requester_id) except HTTPException: raise except Exception as e: @@ -495,9 +462,8 @@ async def hard_delete_systematic_review(sr_id: str, current_user: Dict[str, Any] Only the owner may perform a hard delete. """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False, require_visible=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False, require_visible=False) except HTTPException: raise except Exception as e: @@ -526,7 +492,7 @@ async def hard_delete_systematic_review(sr_id: str, current_user: Dict[str, Any] cleanup_result = {"status": "cleanup_import_failed", "error": str(e)} try: - res = await run_in_threadpool(srdb_service.hard_delete_systematic_review, db_conn_str, sr_id, requester_id) + res = await run_in_threadpool(srdb_service.hard_delete_systematic_review, sr_id, requester_id) deleted_count = res.get("deleted_count") if not deleted_count: # If backend reported zero deletions, raise NotFound to match prior behavior diff --git a/backend/configs/models.yaml b/backend/configs/models.yaml new file mode 100644 index 00000000..75953bc9 --- /dev/null +++ b/backend/configs/models.yaml @@ -0,0 +1,7 @@ +GPT-4.1-Mini: + deployment: gpt-4.1-mini + api_version: 2025-01-01-preview + +GPT-5-Mini: + deployment: gpt-5-mini + api_version: 2025-04-01-preview diff --git a/backend/deploy.sh b/backend/deploy.sh index 6de15442..1f769e6d 100755 --- a/backend/deploy.sh +++ b/backend/deploy.sh @@ -95,6 +95,7 @@ fi # Create necessary directories echo -e "${BLUE}📁 Creating volume directories...${NC}" mkdir -p volumes/{postgres-cits} +mkdir -p uploads/users echo -e "${GREEN}🚀 Starting services...${NC}" diff --git a/backend/docker-compose.yml b/backend/docker-compose.yml index 76b9fb6a..3cca1cbd 100644 --- a/backend/docker-compose.yml +++ b/backend/docker-compose.yml @@ -11,7 +11,6 @@ services: - "8000:8000" environment: - GROBID_SERVICE_URL=http://grobid-service:8070 - - POSTGRES_URI=postgres://admin:password@pgdb-service:5432/postgres env_file: - .env depends_on: @@ -19,6 +18,7 @@ services: - pgdb-service volumes: - ./api:/app/api + - ./configs:/app/configs - ./uploads:/app/uploads healthcheck: test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"] @@ -38,10 +38,10 @@ services: restart: unless-stopped healthcheck: test: ["CMD", "curl", "-f", "http://localhost:8070/api/isalive"] - interval: 30s + interval: 10s timeout: 10s retries: 5 - start_period: 30s + start_period: 120s # ============================================================================= # POSTGRESQL - Database (Citations & Systematic Reviews) diff --git 
a/backend/main.py b/backend/main.py index e89158e2..4173c9f1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -31,17 +31,10 @@ async def startup_event(): print("📚 Initializing systematic review database...", flush=True) # Ensure systematic review table exists in PostgreSQL try: - # Check if Entra ID or POSTGRES_URI is configured - has_entra_config = settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER - has_uri_config = settings.POSTGRES_URI - - if has_entra_config or has_uri_config: - # Pass connection string if available, otherwise None for Entra ID auth - db_conn_str = settings.POSTGRES_URI if has_uri_config else None - await run_in_threadpool(srdb_service.ensure_table_exists, db_conn_str) - print("✓ Systematic review table initialized", flush=True) - else: - print("⚠️ PostgreSQL not configured - skipping SR table initialization", flush=True) + # POSTGRES_URI is deprecated; the DB connection is handled by postgres_server + # using POSTGRES_MODE/POSTGRES_* settings. + await run_in_threadpool(srdb_service.ensure_table_exists) + print("✓ Systematic review table initialized", flush=True) except Exception as e: print(f"⚠️ Failed to ensure SR table exists: {e}", flush=True) print("🎯 CAN-SR Backend ready!", flush=True) @@ -88,8 +81,9 @@ async def health_check(): "status": "ok", "service": settings.PROJECT_NAME, "version": settings.VERSION, - "storage_type": settings.STORAGE_TYPE, - "azure_storage_configured": bool(settings.AZURE_STORAGE_CONNECTION_STRING), + "azure_openai_mode": settings.AZURE_OPENAI_MODE, + "postgres_mode": settings.POSTGRES_MODE, + "storage_mode": settings.STORAGE_MODE, "azure_openai_configured": azure_openai_client.is_configured(), "default_chat_model": settings.DEFAULT_CHAT_MODEL, "available_models": azure_openai_client.get_available_models(), diff --git a/backend/requirements.txt b/backend/requirements.txt index 517a0efa..8b8daf90 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -3,6 +3,7 @@ fastapi>=0.104.1 uvicorn>=0.23.2 pydantic>=2.7.0 pydantic-settings>=2.3.0 +PyYAML==6.0.2 email-validator>=2.0.0 python-multipart==0.0.22 python-dotenv>=1.0.1 diff --git a/frontend/components/can-sr/stacking-card.tsx b/frontend/components/can-sr/stacking-card.tsx index 3cf78543..1a73eac6 100644 --- a/frontend/components/can-sr/stacking-card.tsx +++ b/frontend/components/can-sr/stacking-card.tsx @@ -36,10 +36,10 @@ export default function StackingCard({ title, description, href, className }: St {open ? 'Minimize' : 'Expand'} */} - - + + Open - + @@ -58,5 +58,5 @@ export default function StackingCard({ title, description, href, className }: St - ) + ); } diff --git a/frontend/public/images/backgrounds/homepage.jpg b/frontend/public/images/backgrounds/homepage.jpg index 7f9b249f..3fa1aac2 100644 Binary files a/frontend/public/images/backgrounds/homepage.jpg and b/frontend/public/images/backgrounds/homepage.jpg differ diff --git a/frontend/public/images/backgrounds/homepage_hc.jpg b/frontend/public/images/backgrounds/homepage_hc.jpg new file mode 100644 index 00000000..7f9b249f Binary files /dev/null and b/frontend/public/images/backgrounds/homepage_hc.jpg differ diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json index e7ff3a26..74f65609 100644 --- a/frontend/tsconfig.json +++ b/frontend/tsconfig.json @@ -15,7 +15,7 @@ "moduleResolution": "bundler", "resolveJsonModule": true, "isolatedModules": true, - "jsx": "react-jsx", + "jsx": "preserve", "incremental": true, "plugins": [ {
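
Editor's sketch, appended after the patch: the user_db.py changes above route every registry read and write through the shared storage_service, using the get_bytes_by_path / put_bytes_by_path / create_user_directory calls visible in the diff. The minimal local-filesystem backend below is a hypothetical illustration of the contract those calls appear to assume; it is not the repository's actual .services.storage implementation, and the class name and default values are invented for the example.

# Hypothetical local-filesystem backend illustrating the storage contract that
# user_db.py relies on. Only the method names and return shapes are taken from
# the patch; everything else here is assumed.
import asyncio
from pathlib import Path
from typing import Optional, Tuple


class LocalStorageService:
    def __init__(self, base_path: str = "uploads", container_name: str = "storage"):
        # base_path mirrors the ./uploads volume mounted in docker-compose.yml;
        # the real container name comes from configuration, not this default.
        self.base_path = Path(base_path)
        self.container_name = container_name

    def _resolve(self, path: str) -> Path:
        # Keys such as "<container>/system/user_registry.json" resolve under base_path.
        return self.base_path / path

    async def get_bytes_by_path(self, path: str) -> Tuple[bytes, str]:
        # Matches the call site: content, _filename = await storage.get_bytes_by_path(...)
        target = self._resolve(path)
        data = await asyncio.to_thread(target.read_bytes)
        return data, target.name

    async def put_bytes_by_path(self, path: str, data: bytes, content_type: Optional[str] = None) -> bool:
        # content_type is accepted for parity with blob backends; unused locally.
        target = self._resolve(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        await asyncio.to_thread(target.write_bytes, data)
        return True

    async def create_user_directory(self, user_id: str) -> None:
        # Called by create_user() after the registry is saved.
        (self.base_path / self.container_name / "users" / user_id).mkdir(parents=True, exist_ok=True)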
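
A second sketch, on the Postgres side of the patch: with _get_db_conn_str removed and the db_conn_str arguments dropped, srdb_service now connects through the project's postgres_server layer, selected by POSTGRES_MODE (see the startup comment in backend/main.py above). The connect_postgres name and the password-vs-Entra split come from the removed docstring; the concrete function below, the psycopg2 usage, the mode values, and the POSTGRES_PASSWORD attribute are assumptions, though the token scope shown is the standard one for Azure Database for PostgreSQL.

# Assumed shape of a POSTGRES_MODE-aware connect_postgres(); the real
# implementation lives in the project's postgres_server module, not shown here.
import psycopg2
from azure.identity import DefaultAzureCredential

AZURE_PG_TOKEN_SCOPE = "https://ossrdbms-aad.database.azure.com/.default"


def connect_postgres(settings):
    common = dict(
        host=settings.POSTGRES_HOST,
        dbname=settings.POSTGRES_DATABASE,
        user=settings.POSTGRES_USER,
    )
    if settings.POSTGRES_MODE in ("docker", "local"):
        # docker/local: plain username/password authentication.
        return psycopg2.connect(password=settings.POSTGRES_PASSWORD, **common)
    if settings.POSTGRES_MODE == "azure":
        # azure: the password setting is ignored; an Entra ID access token is used instead.
        token = DefaultAzureCredential().get_token(AZURE_PG_TOKEN_SCOPE).token
        return psycopg2.connect(password=token, sslmode="require", **common)
    raise ValueError(f"Unsupported POSTGRES_MODE: {settings.POSTGRES_MODE}")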
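
Finally, a note on backend/configs/models.yaml: it maps a model label to an Azure OpenAI deployment name and API version, docker-compose.yml now mounts ./configs into the container at /app/configs, and PyYAML is added to requirements.txt, presumably to parse this file. A minimal, assumed loader (the function name and call site are invented for illustration):

# Minimal sketch for reading configs/models.yaml; only the file format and the
# /app/configs mount come from the patch, the loader itself is assumed.
from pathlib import Path

import yaml


def load_model_registry(path: str = "configs/models.yaml") -> dict:
    """Return a mapping like {"GPT-4.1-Mini": {"deployment": ..., "api_version": ...}}."""
    with Path(path).open("r", encoding="utf-8") as fh:
        return yaml.safe_load(fh) or {}


if __name__ == "__main__":
    models = load_model_registry()
    # e.g. models["GPT-4.1-Mini"] == {"deployment": "gpt-4.1-mini",
    #                                 "api_version": "2025-01-01-preview"}
    print(sorted(models))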