diff --git a/.gitignore b/.gitignore
index fe155de0..df559c58 100644
--- a/.gitignore
+++ b/.gitignore
@@ -164,4 +164,7 @@ frontend/app/can-sr/setup/test.yaml
 backend/api/citations/test.ipynb
 file1.csv
 */criteria_config_measles_updated.yaml
-AGENTS_ROADMAP.md
\ No newline at end of file
+AGENTS_ROADMAP.md
+
+*/logfile
+logfile
\ No newline at end of file
diff --git a/backend/.dockerignore b/backend/.dockerignore
index 5813e6e2..ea71973c 100644
--- a/backend/.dockerignore
+++ b/backend/.dockerignore
@@ -23,4 +23,5 @@ dmypy.json
 .DS_Store
 deploy.sh
 *.sh
+!entrypoint.sh
 *.pem
\ No newline at end of file
diff --git a/backend/Dockerfile b/backend/Dockerfile
index 610e9252..6a01e65f 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -36,14 +36,26 @@ RUN mkdir -p uploads
 
 # Create non-root user for security
 RUN useradd -m -u 1001 appuser && chown -R appuser:appuser /app
-USER appuser
+
+
+COPY sshd_config /etc/ssh/
+COPY entrypoint.sh /entrypoint.sh
+
+RUN apt-get update \
+    && apt-get install -y --no-install-recommends dialog \
+    && apt-get install -y --no-install-recommends openssh-server \
+    && echo "root:Docker!" | chpasswd \
+    && chmod u+x /entrypoint.sh
+
+USER root
 
 # Expose port
-EXPOSE 8000
+EXPOSE 8000 2222
 
 # Health check
 HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
     CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"
 
 # Run the application
+ENTRYPOINT ["/entrypoint.sh"]
 CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
""" - db_conn_str = settings.POSTGRES_URI + db_conn_str = _get_db_conn_str() try: sr, screening, _ = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) except HTTPException: @@ -83,13 +110,13 @@ async def upload_screening_csv( except Exception as e: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") - # Shared DB connection string - db_conn = settings.POSTGRES_URI - if not db_conn: + # Check admin DSN (use centralized settings) - need either Entra ID config or POSTGRES_URI + if not _is_postgres_configured(): raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="PostgreSQL connection not configured. Set POSTGRES_URI in configuration/environment.", + detail="Postgres not configured. Set POSTGRES_HOST/DATABASE/USER for Entra ID auth, or POSTGRES_URI for local dev.", ) + admin_dsn = _get_db_conn_str() # Read CSV content include_columns = None @@ -114,14 +141,14 @@ async def upload_screening_csv( try: old = (sr.get("screening_db") or {}).get("table_name") if old: - await run_in_threadpool(cits_dp_service.drop_table, db_conn, old) + await run_in_threadpool(cits_dp_service.drop_table, admin_dsn, old) except Exception: # best-effort only pass # Create table and insert rows in threadpool try: - inserted = await run_in_threadpool(_create_table_and_insert_sync, db_conn, table_name, include_columns, normalized_rows) + inserted = await run_in_threadpool(_create_table_and_insert_sync, admin_dsn, table_name, include_columns, normalized_rows) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -131,7 +158,7 @@ async def upload_screening_csv( try: screening_info = { "screening_db": { - "connection_string": db_conn, + "connection_string": admin_dsn, "table_name": table_name, "created_at": datetime.utcnow().isoformat(), "rows": inserted, @@ -141,7 +168,7 @@ async def upload_screening_csv( # Update SR document with screening DB info using PostgreSQL await run_in_threadpool( srdb_service.update_screening_db_info, - settings.POSTGRES_URI, + _get_db_conn_str(), sr_id, screening_info["screening_db"] ) @@ -171,7 +198,7 @@ async def list_citation_ids( Returns a simple list of integers (the 'id' primary key from the citations table). """ - db_conn_str = settings.POSTGRES_URI + db_conn_str = _get_db_conn_str() try: sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) except HTTPException: @@ -210,7 +237,7 @@ async def get_citation_by_id( Returns: a JSON object representing the citation row (keys are DB column names). """ - db_conn_str = settings.POSTGRES_URI + db_conn_str = _get_db_conn_str() try: sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) except HTTPException: @@ -263,7 +290,7 @@ async def build_combined_citation( the format ": \\n" for each included column, in the order provided. """ - db_conn_str = settings.POSTGRES_URI + db_conn_str = _get_db_conn_str() try: sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) except HTTPException: @@ -318,7 +345,7 @@ async def upload_citation_fulltext( to the storage path (container/blob). 
""" - db_conn_str = settings.POSTGRES_URI + db_conn_str = _get_db_conn_str() try: sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) except HTTPException: @@ -414,7 +441,7 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An - Caller must be the SR owner. """ - db_conn_str = settings.POSTGRES_URI + db_conn_str = _get_db_conn_str() try: sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) except HTTPException: @@ -429,9 +456,7 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An if not screening: return {"status": "no_screening_db", "message": "No screening table configured for this SR", "deleted_table": False, "deleted_files": 0} - db_conn = screening.get("connection_string") - if not db_conn: - return {"status": "no_screening_db", "message": "Incomplete screening DB metadata", "deleted_table": False, "deleted_files": 0} + db_conn = None table_name = screening.get("table_name") if not table_name: @@ -521,7 +546,7 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An try: await run_in_threadpool( srdb_service.clear_screening_db_info, - settings.POSTGRES_URI, + _get_db_conn_str(), sr_id ) except Exception: @@ -557,7 +582,7 @@ async def export_citations_csv( Content-Disposition. """ - db_conn_str = settings.POSTGRES_URI + db_conn_str = _get_db_conn_str() try: sr, screening, db_conn = await load_sr_and_check( sr_id, current_user, db_conn_str, srdb_service diff --git a/backend/api/core/cit_utils.py b/backend/api/core/cit_utils.py index 4cfd3e56..12442fa8 100644 --- a/backend/api/core/cit_utils.py +++ b/backend/api/core/cit_utils.py @@ -13,10 +13,22 @@ from fastapi import HTTPException, status from fastapi.concurrency import run_in_threadpool +from .config import settings + + +def _is_postgres_configured(db_conn_str: Optional[str] = None) -> bool: + """ + Check if PostgreSQL is configured via Entra ID env vars or connection string. + """ + has_entra_config = settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER + has_uri_config = db_conn_str or settings.POSTGRES_URI + return bool(has_entra_config or has_uri_config) + + async def load_sr_and_check( sr_id: str, current_user: Dict[str, Any], - db_conn_str: str, + db_conn_str: Optional[str], srdb_service, require_screening: bool = True, require_visible: bool = True, @@ -27,7 +39,7 @@ async def load_sr_and_check( Args: sr_id: SR id string current_user: current user dict (must contain "id" and "email") - db_conn_str: PostgreSQL connection string + db_conn_str: PostgreSQL connection string (can be None if using Entra ID auth) srdb_service: SR DB service instance (must implement get_systematic_review and user_has_sr_permission) require_screening: if True, also ensure the SR has a configured screening_db and return its connection string require_visible: if True, require the SR 'visible' flag to be True; set False for endpoints like hard-delete @@ -37,18 +49,6 @@ async def load_sr_and_check( Raises HTTPException with appropriate status codes on failure so routers can just propagate. 
""" - # ensure DB helper present and call it - if not db_conn_str: - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Server misconfiguration: PostgreSQL connection not available", - ) - try: - await run_in_threadpool(srdb_service.ensure_db_available, db_conn_str) - except HTTPException: - raise - except Exception as e: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(e)) # fetch SR try: @@ -81,8 +81,6 @@ async def load_sr_and_check( if require_screening: if not screening: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No screening database configured for this systematic review") - db_conn = screening.get("connection_string") - if not db_conn: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Screening DB connection info missing") + db_conn = None return sr, screening, db_conn diff --git a/backend/api/core/config.py b/backend/api/core/config.py index 696a03a2..3fb7d32a 100644 --- a/backend/api/core/config.py +++ b/backend/api/core/config.py @@ -28,6 +28,9 @@ class Settings(BaseSettings): # Storage settings STORAGE_TYPE: str = os.getenv("STORAGE_TYPE", "azure") + AZURE_STORAGE_ACCOUNT_NAME: Optional[str] = os.getenv( + "AZURE_STORAGE_ACCOUNT_NAME" + ) AZURE_STORAGE_CONNECTION_STRING: Optional[str] = os.getenv( "AZURE_STORAGE_CONNECTION_STRING" ) @@ -108,8 +111,16 @@ def convert_max_file_size(cls, v): DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true" # Database and external system environment variables - # Postgres DSN used for systematic reviews and screening databases - POSTGRES_URI: str = os.getenv("POSTGRES_URI") + # Postgres settings for Entra ID authentication + POSTGRES_HOST: Optional[str] = os.getenv("POSTGRES_HOST") + POSTGRES_DATABASE: Optional[str] = os.getenv("POSTGRES_DATABASE") + POSTGRES_USER: Optional[str] = os.getenv("POSTGRES_USER") # Entra ID user (e.g., user@tenant.onmicrosoft.com) + POSTGRES_PORT: int = int(os.getenv("POSTGRES_PORT", "5432")) + POSTGRES_SSL_MODE: Optional[str] = os.getenv("POSTGRES_SSL_MODE") + POSTGRES_PASSWORD: Optional[str] = os.getenv("POSTGRES_PASSWORD") + AZURE_DB: bool = os.getenv("AZURE_DB", "false").lower() == "true" + # Legacy: Postgres DSN used for systematic reviews and screening databases (fallback) + POSTGRES_URI: Optional[str] = os.getenv("POSTGRES_URI") # Databricks settings DATABRICKS_INSTANCE: str = os.getenv("DATABRICKS_INSTANCE") @@ -124,6 +135,9 @@ def convert_max_file_size(cls, v): REDIRECT_URI: str = os.getenv("REDIRECT_URI") SSO_LOGIN_URL: str = os.getenv("SSO_LOGIN_URL") + # Entra + USE_ENTRA_AUTH: bool = os.getenv("USE_ENTRA_AUTH", "false").lower() == "true" + class Config: case_sensitive = True env_file = ".env" diff --git a/backend/api/services/azure_openai_client.py b/backend/api/services/azure_openai_client.py index 73791d40..dd89e685 100644 --- a/backend/api/services/azure_openai_client.py +++ b/backend/api/services/azure_openai_client.py @@ -1,10 +1,32 @@ """Azure OpenAI client service for chat completions""" +import time from typing import Dict, List, Any, Optional +from azure.identity import DefaultAzureCredential, get_bearer_token_provider from openai import AzureOpenAI from ..core.config import settings +# Token cache TTL in seconds (9 minutes) +TOKEN_CACHE_TTL = 9 * 60 + + +class CachedTokenProvider: + """Simple in-memory token cache wrapper""" + + def __init__(self, token_provider): + self._token_provider = token_provider + self._cached_token: Optional[str] = None + self._token_expiry: float 
diff --git a/backend/api/services/azure_openai_client.py b/backend/api/services/azure_openai_client.py
index 73791d40..dd89e685 100644
--- a/backend/api/services/azure_openai_client.py
+++ b/backend/api/services/azure_openai_client.py
@@ -1,10 +1,32 @@
 """Azure OpenAI client service for chat completions"""
 
+import time
 from typing import Dict, List, Any, Optional
 
+from azure.identity import DefaultAzureCredential, get_bearer_token_provider
 from openai import AzureOpenAI
 
 from ..core.config import settings
 
+# Token cache TTL in seconds (9 minutes)
+TOKEN_CACHE_TTL = 9 * 60
+
+
+class CachedTokenProvider:
+    """Simple in-memory token cache wrapper"""
+
+    def __init__(self, token_provider):
+        self._token_provider = token_provider
+        self._cached_token: Optional[str] = None
+        self._token_expiry: float = 0
+
+    def __call__(self) -> str:
+        """Return cached token or fetch a new one if expired"""
+        current_time = time.time()
+        if self._cached_token is None or current_time >= self._token_expiry:
+            self._cached_token = self._token_provider()
+            self._token_expiry = current_time + TOKEN_CACHE_TTL
+        return self._cached_token
+
 
 class AzureOpenAIClient:
     """Client for Azure OpenAI chat completions"""
@@ -12,18 +34,27 @@ class AzureOpenAIClient:
     def __init__(self):
         self.default_model = settings.DEFAULT_CHAT_MODEL
 
+        # Create token provider for Azure OpenAI using DefaultAzureCredential
+        # Wrapped with caching to avoid fetching a new token on every request
+        if not settings.AZURE_OPENAI_API_KEY and not settings.USE_ENTRA_AUTH:
+            raise ValueError("Azure OpenAI API key or Entra auth must be configured")
+
+        if settings.USE_ENTRA_AUTH:
+            self._credential = DefaultAzureCredential()
+            self._token_provider = CachedTokenProvider(
+                get_bearer_token_provider(
+                    self._credential, "https://cognitiveservices.azure.com/.default"
+                )
+            )
+
         self.model_configs = {
             "gpt-4.1-mini": {
-                "api_key": settings.AZURE_OPENAI_GPT41_MINI_API_KEY
-                or settings.AZURE_OPENAI_API_KEY,
                 "endpoint": settings.AZURE_OPENAI_GPT41_MINI_ENDPOINT
                 or settings.AZURE_OPENAI_ENDPOINT,
                 "deployment": settings.AZURE_OPENAI_GPT41_MINI_DEPLOYMENT,
                 "api_version": settings.AZURE_OPENAI_GPT41_MINI_API_VERSION,
             },
             "gpt-5-mini": {
-                "api_key": settings.AZURE_OPENAI_GPT5_MINI_API_KEY
-                or settings.AZURE_OPENAI_API_KEY,
                 "endpoint": settings.AZURE_OPENAI_GPT5_MINI_ENDPOINT
                 or settings.AZURE_OPENAI_ENDPOINT,
                 "deployment": settings.AZURE_OPENAI_GPT5_MINI_DEPLOYMENT,
@@ -43,16 +74,22 @@ def _get_official_client(self, model: str) -> AzureOpenAI:
         """Get official Azure OpenAI client instance"""
         if model not in self._official_clients:
             config = self._get_model_config(model)
-            if not config.get("api_key"):
+            if not config.get("endpoint"):
                 raise ValueError(
-                    f"Azure OpenAI API key not configured for model {model}"
+                    f"Azure OpenAI endpoint not configured for model {model}"
                 )
 
-            self._official_clients[model] = AzureOpenAI(
-                api_key=config["api_key"],
-                azure_endpoint=config["endpoint"],
-                api_version=config["api_version"],
-            )
+            azure_openai_kwargs = {
+                "azure_endpoint": config["endpoint"],
+                "api_version": config["api_version"],
+            }
+            if settings.USE_ENTRA_AUTH:
+                azure_openai_kwargs["azure_ad_token_provider"] = self._token_provider
+
+            if settings.AZURE_OPENAI_API_KEY:
+                azure_openai_kwargs["api_key"] = settings.AZURE_OPENAI_API_KEY
+
+            self._official_clients[model] = AzureOpenAI(**azure_openai_kwargs)
 
         return self._official_clients[model]
 
@@ -99,7 +136,6 @@ async def chat_completion(
 
         try:
             client = self._get_official_client(model)
-
             request_kwargs = {
                 "model": deployment,
                 "messages": messages,
@@ -301,7 +337,7 @@ def get_available_models(self) -> List[str]:
         return [
             model
             for model, config in self.model_configs.items()
-            if config.get("api_key") and config.get("endpoint")
+            if config.get("endpoint")
         ]
 
     def is_configured(self) -> bool:
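
As a usage sketch, CachedTokenProvider wraps any zero-argument token callable, such as the one returned by azure-identity's get_bearer_token_provider, and memoizes the result for the nine-minute TTL. Worth noting that azure-identity credentials also cache tokens internally in recent releases, so the wrapper mainly avoids per-request provider calls. Import path assumed from this repository's layout:

    from azure.identity import DefaultAzureCredential, get_bearer_token_provider

    from api.services.azure_openai_client import CachedTokenProvider  # import path assumed

    credential = DefaultAzureCredential()
    provider = CachedTokenProvider(
        get_bearer_token_provider(
            credential, "https://cognitiveservices.azure.com/.default"
        )
    )
    token = provider()  # first call fetches a token
    token = provider()  # calls within ~9 minutes return the cached value
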
""" from typing import Any, Dict, List, Optional +import psycopg2 +import psycopg2.extras import json import re import os @@ -26,14 +28,7 @@ except Exception: settings = None - -def _ensure_psycopg2(): - try: - import psycopg2 - import psycopg2.extras # noqa: F401 - return psycopg2 - except Exception: - raise RuntimeError("psycopg2 is not installed on the server environment") +from .postgres_auth import postgres_server # ----------------------- @@ -138,17 +133,7 @@ def __init__(self): # ----------------------- # Low level connection helpers # ----------------------- - def _ensure_psycopg2(self): - return _ensure_psycopg2() - def _connect(self, db_conn_str: str): - """ - Connect and return a psycopg2 connection. Raises RuntimeError if psycopg2 missing. - Caller is responsible for closing the connection. - """ - psycopg2 = self._ensure_psycopg2() - conn = psycopg2.connect(db_conn_str) - return conn # ----------------------- # Generic column ops @@ -162,7 +147,7 @@ def create_column(self, db_conn_str: str, col: str, col_type: str, table_name: s table_name = _validate_ident(table_name, kind="table_name") conn = None try: - conn = self._connect(db_conn_str) + conn = postgres_server.conn cur = conn.cursor() try: cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" {col_type}') @@ -173,20 +158,10 @@ def create_column(self, db_conn_str: str, col: str, col_type: str, table_name: s except Exception: pass conn.commit() - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass + finally: if conn: - try: - conn.close() - except Exception: - pass + pass def update_jsonb_column( self, @@ -202,7 +177,7 @@ def update_jsonb_column( table_name = _validate_ident(table_name, kind="table_name") conn = None try: - conn = self._connect(db_conn_str) + conn = postgres_server.conn cur = conn.cursor() try: cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" JSONB') @@ -214,21 +189,11 @@ def update_jsonb_column( cur.execute(f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s', (json.dumps(data), int(citation_id))) rows = cur.rowcount conn.commit() - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass + return rows or 0 finally: if conn: - try: - conn.close() - except Exception: - pass + pass def update_text_column( self, @@ -244,7 +209,7 @@ def update_text_column( table_name = _validate_ident(table_name, kind="table_name") conn = None try: - conn = self._connect(db_conn_str) + conn = postgres_server.conn cur = conn.cursor() try: cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" TEXT') @@ -256,21 +221,11 @@ def update_text_column( cur.execute(f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s', (text_value, int(citation_id))) rows = cur.rowcount conn.commit() - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass + return rows or 0 finally: if conn: - try: - conn.close() - except Exception: - pass + pass # ----------------------- # Citation row helpers @@ -286,7 +241,7 @@ def dump_citations_csv(self, db_conn_str: str, table_name: str = "citations") -> table_name = _validate_ident(table_name, kind="table_name") conn = None try: - conn = self._connect(db_conn_str) + conn = postgres_server.conn cur = conn.cursor() buf = io.StringIO() @@ -297,32 +252,21 @@ def dump_citations_csv(self, db_conn_str: str, table_name: str = "citations") -> ) csv_text = buf.getvalue() - try: - cur.close() - except Exception: - pass - try: - 
conn.close() - except Exception: - pass + return csv_text.encode("utf-8") finally: if conn: - try: - conn.close() - except Exception: - pass + pass def get_citation_by_id(self, db_conn_str: str, citation_id: int, table_name: str = "citations") -> Optional[Dict[str, Any]]: """ Return a dict mapping column -> value for the citation row, or None. """ table_name = _validate_ident(table_name, kind="table_name") - psycopg2 = self._ensure_psycopg2() conn = None try: - conn = psycopg2.connect(db_conn_str) + conn = postgres_server.conn try: cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) except Exception: @@ -330,45 +274,26 @@ def get_citation_by_id(self, db_conn_str: str, citation_id: int, table_name: str cur.execute(f'SELECT * FROM "{table_name}" WHERE id = %s', (citation_id,)) row = cur.fetchone() if row is None: - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass return None if isinstance(row, dict): result = row else: cols = [desc[0] for desc in cur.description] result = {cols[i]: row[i] for i in range(len(cols))} - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass + return result finally: if conn: - try: - conn.close() - except Exception: - pass + pass def list_citation_ids(self, db_conn_str: str, filter_step=None, table_name: str = "citations") -> List[int]: """ Return list of integer primary keys (id) from citations table ordered by id. """ table_name = _validate_ident(table_name, kind="table_name") - psycopg2 = self._ensure_psycopg2() conn = None try: - conn = psycopg2.connect(db_conn_str) + conn = postgres_server.conn cur = conn.cursor() if filter_step is not None: @@ -402,49 +327,28 @@ def list_citation_ids(self, db_conn_str: str, filter_step=None, table_name: str cur.execute(query) rows = cur.fetchall() - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass + return [int(r[0]) for r in rows] finally: if conn: - try: - conn.close() - except Exception: - pass + pass def list_fulltext_urls(self, db_conn_str: str, table_name: str = "citations") -> List[str]: """ Return list of fulltext_url values (non-null) from citations table. """ table_name = _validate_ident(table_name, kind="table_name") - psycopg2 = self._ensure_psycopg2() conn = None try: - conn = psycopg2.connect(db_conn_str) + conn = postgres_server.conn cur = conn.cursor() cur.execute(f'SELECT fulltext_url FROM "{table_name}" WHERE fulltext_url IS NOT NULL') rows = cur.fetchall() - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass + return [r[0] for r in rows if r and r[0]] finally: if conn: - try: - conn.close() - except Exception: - pass + pass def update_citation_fulltext(self, db_conn_str: str, citation_id: int, fulltext_path: str) -> int: """ @@ -474,14 +378,14 @@ def attach_fulltext( md5 = hashlib.md5(file_bytes).hexdigest() if file_bytes is not None else "" # update both columns in one statement - conn = self._connect(db_conn_str) + conn = postgres_server.conn cur = conn.cursor() cur.execute(f'UPDATE "{table_name}" SET "fulltext_url" = %s WHERE id = %s', (azure_path, int(citation_id))) rows = cur.rowcount conn.commit() - cur.close() - conn.close() + + return rows # ----------------------- @@ -492,10 +396,9 @@ def get_column_value(self, db_conn_str: str, citation_id: int, column: str, tabl Return the value stored in `column` for the citation row (or None). 
""" table_name = _validate_ident(table_name, kind="table_name") - psycopg2 = self._ensure_psycopg2() conn = None try: - conn = psycopg2.connect(db_conn_str) + conn = postgres_server.conn try: cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) except Exception: @@ -503,35 +406,17 @@ def get_column_value(self, db_conn_str: str, citation_id: int, column: str, tabl cur.execute(f'SELECT "{column}" FROM "{table_name}" WHERE id = %s', (citation_id,)) row = cur.fetchone() if not row: - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass return None # row may be dict or tuple if isinstance(row, dict): val = list(row.values())[0] if row else None else: val = row[0] if row and len(row) > 0 else None - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass + return val finally: if conn: - try: - conn.close() - except Exception: - pass + pass def set_column_value(self, db_conn_str: str, citation_id: int, column: str, value: Any, table_name: str = "citations") -> int: """ @@ -549,21 +434,14 @@ def drop_table(self, db_conn_str: str, table_name: str, cascade: bool = True) -> table_name = _validate_ident(table_name, kind="table_name") conn = None try: - conn = self._connect(db_conn_str) + conn = postgres_server.conn conn.autocommit = True cur = conn.cursor() cas = " CASCADE" if cascade else "" cur.execute(f'DROP TABLE IF EXISTS "{table_name}"{cas}') - try: - cur.close() - except Exception: - pass finally: if conn: - try: - conn.close() - except Exception: - pass + pass def create_table_and_insert_sync( self, @@ -578,11 +456,9 @@ def create_table_and_insert_sync( is per-upload (e.g. sr___citations) inside the shared DB. """ table_name = _validate_ident(table_name, kind="table_name") - - psycopg2 = self._ensure_psycopg2() conn = None try: - conn = psycopg2.connect(db_conn_str) + conn = postgres_server.conn cur = conn.cursor() # Create table @@ -632,21 +508,11 @@ def _row_has_data(row: dict) -> bool: inserted = len(values) conn.commit() - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass + return inserted finally: if conn: - try: - conn.close() - except Exception: - pass + pass # NOTE: legacy per-database helpers (drop_database, create_db_and_table_sync) were # intentionally removed in favor of per-upload tables in a shared database. diff --git a/backend/api/services/postgres_auth.py b/backend/api/services/postgres_auth.py new file mode 100644 index 00000000..9334c520 --- /dev/null +++ b/backend/api/services/postgres_auth.py @@ -0,0 +1,105 @@ +""" +PostgreSQL authentication helper using Azure Entra ID (DefaultAzureCredential). + +This module provides a centralized way to connect to Azure Database for PostgreSQL +using Entra ID authentication, with fallback to connection string for local development. 
+""" + +from typing import Optional + +import psycopg2 +from ..core.config import settings +import logging +import datetime +from azure.identity import DefaultAzureCredential + +logger = logging.getLogger(__name__) + + +class PostgresServer: + """Manages a persistent PostgreSQL connection with automatic Azure token refresh.""" + + _AZURE_POSTGRES_SCOPE = "https://ossrdbms-aad.database.windows.net/.default" + _TOKEN_REFRESH_BUFFER_SECONDS = 60 + + def __init__(self): + self._verify_config() + self._credential = DefaultAzureCredential() if settings.AZURE_DB else None + self._token: Optional[str] = None + self._token_expiration: int = 0 + self._conn = None + + @property + def conn(self): + """Return an open connection, reconnecting only when necessary.""" + if self._conn is None or self._conn.closed: + print("local database") + self._conn = self._connect() + elif settings.AZURE_DB and self._is_token_expired(): + logger.info("Azure token expired — reconnecting to PostgreSQL") + print("cloud database") + self.close() + self._conn = self._connect() + print(self._conn) + return self._conn + + def close(self): + """Safely close the current connection (idempotent).""" + if self._conn and not self._conn.closed: + try: + self._conn.close() + except Exception: + logger.warning("Failed to close PostgreSQL connection", exc_info=True) + self._conn = None + + @staticmethod + def _verify_config(): + """Validate that all required PostgreSQL settings are present.""" + required = [settings.POSTGRES_HOST, settings.POSTGRES_DATABASE, settings.POSTGRES_USER] + if not all(required): + raise RuntimeError("POSTGRES_HOST, POSTGRES_DATABASE, and POSTGRES_USER are required") + if not settings.AZURE_DB and not settings.POSTGRES_PASSWORD: + raise RuntimeError("POSTGRES_PASSWORD is required when AZURE_DB is False") + + def _is_token_expired(self) -> bool: + """Check whether the cached Azure token needs refreshing.""" + now = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) + return not self._token or now >= self._token_expiration + + def _refresh_azure_token(self) -> str: + """Return a valid Azure token, fetching a new one only if expired.""" + if self._is_token_expired(): + logger.info("Fetching fresh Azure PostgreSQL token") + token = self._credential.get_token(self._AZURE_POSTGRES_SCOPE) + self._token = token.token + self._token_expiration = token.expires_on - self._TOKEN_REFRESH_BUFFER_SECONDS + return self._token + + def _build_connect_kwargs(self) -> dict: + """Assemble psycopg2.connect() keyword arguments from settings.""" + kwargs = { + "host": settings.POSTGRES_HOST, + "database": settings.POSTGRES_DATABASE, + "user": settings.POSTGRES_USER, + "port": settings.POSTGRES_PORT, + } + if settings.POSTGRES_SSL_MODE: + kwargs["sslmode"] = settings.POSTGRES_SSL_MODE + if settings.AZURE_DB: + kwargs["password"] = self._refresh_azure_token() + elif settings.POSTGRES_PASSWORD: + kwargs["password"] = settings.POSTGRES_PASSWORD + return kwargs + + def _connect(self): + """Create a new psycopg2 connection.""" + return psycopg2.connect(**self._build_connect_kwargs()) + + def __repr__(self) -> str: + status = "open" if self._conn and not self._conn.closed else "closed" + return ( + f"" + ) + +postgres_server = PostgresServer() \ No newline at end of file diff --git a/backend/api/services/sr_db_service.py b/backend/api/services/sr_db_service.py index 5b9d9c3c..1dfaedce 100644 --- a/backend/api/services/sr_db_service.py +++ b/backend/api/services/sr_db_service.py @@ -16,16 +16,10 @@ from fastapi import 
diff --git a/backend/api/services/sr_db_service.py b/backend/api/services/sr_db_service.py
index 5b9d9c3c..1dfaedce 100644
--- a/backend/api/services/sr_db_service.py
+++ b/backend/api/services/sr_db_service.py
@@ -16,16 +16,10 @@
 from fastapi import HTTPException, status
 
-logger = logging.getLogger(__name__)
-
+from .postgres_auth import postgres_server
+from ..core.config import settings
 
-def _ensure_psycopg2():
-    try:
-        import psycopg2
-        import psycopg2.extras  # noqa: F401
-        return psycopg2
-    except Exception:
-        raise RuntimeError("psycopg2 is not installed on the server environment")
+logger = logging.getLogger(__name__)
 
 
 class SRDBService:
@@ -33,39 +27,6 @@
     def __init__(self):
         # Service is stateless; connection strings passed per-call
         pass
 
-    def _ensure_psycopg2(self):
-        return _ensure_psycopg2()
-
-    def _connect(self, db_conn_str: str):
-        """
-        Connect and return a psycopg2 connection. Raises RuntimeError if psycopg2 missing.
-        Caller is responsible for closing the connection.
-        """
-        psycopg2 = self._ensure_psycopg2()
-        conn = psycopg2.connect(db_conn_str)
-        return conn
-
-    def ensure_db_available(self, db_conn_str: Optional[str] = None) -> None:
-        """
-        Raise an HTTPException (503) if the PostgreSQL connection is not available.
-        Routers call this to provide consistent error messages when Postgres is not configured.
-        """
-        if not db_conn_str:
-            raise HTTPException(
-                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                detail="PostgreSQL connection not configured. Set POSTGRES_URI environment variable.",
-            )
-        # Try to connect to verify availability
-        try:
-            conn = self._connect(db_conn_str)
-            conn.close()
-        except Exception as e:
-            logger.warning(f"PostgreSQL connection failed: {e}")
-            raise HTTPException(
-                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-                detail=f"PostgreSQL connection failed: {e}",
-            )
-
     def ensure_table_exists(self, db_conn_str: str) -> None:
         """
         Ensure the systematic_reviews table exists in PostgreSQL.
@@ -74,7 +35,7 @@ def ensure_table_exists(self, db_conn_str: str) -> None:
         """
         conn = None
         try:
-            conn = self._connect(db_conn_str)
+            conn = postgres_server.conn
             cur = conn.cursor()
 
             create_table_sql = """
@@ -97,14 +58,7 @@ def ensure_table_exists(self, db_conn_str: str) -> None:
             cur.execute(create_table_sql)
             conn.commit()
 
-            try:
-                cur.close()
-            except Exception:
-                pass
-            try:
-                conn.close()
-            except Exception:
-                pass
+
             logger.info("Ensured systematic_reviews table exists")
 
         except Exception as e:
@@ -112,10 +66,7 @@ def ensure_table_exists(self, db_conn_str: str) -> None:
             raise
         finally:
             if conn:
-                try:
-                    conn.close()
-                except Exception:
-                    pass
+                pass
 
     def build_criteria_parsed(self, criteria_obj: Optional[Dict[str, Any]]) -> Dict[str, Any]:
         """
@@ -217,8 +168,6 @@ def create_systematic_review(
         """
         Create a new SR document and insert into the table. Returns the created document.
         """
-        if not db_conn_str:
-            raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Systematic review DB not configured")
 
         sr_id = str(uuid.uuid4())
         now = datetime.utcnow().isoformat()
@@ -229,7 +178,7 @@ def create_systematic_review(
 
         conn = None
         try:
-            conn = self._connect(db_conn_str)
+            conn = postgres_server.conn
             cur = conn.cursor()
 
             insert_sql = """
@@ -276,14 +225,7 @@ def create_systematic_review(
             if sr_doc.get('updated_at') and isinstance(sr_doc['updated_at'], dt):
                 sr_doc['updated_at'] = sr_doc['updated_at'].isoformat()
 
-            try:
-                cur.close()
-            except Exception:
-                pass
-            try:
-                conn.close()
-            except Exception:
-                pass
+
             return sr_doc
 
@@ -292,10 +234,7 @@ def create_systematic_review(
             raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create systematic review: {e}")
         finally:
             if conn:
-                try:
-                    conn.close()
-                except Exception:
-                    pass
+                pass
 
     def add_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]:
         """
@@ -303,8 +242,7 @@ def add_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_
         requester must be a member or owner.
         Returns a dict with update result metadata.
         """
-        if not db_conn_str:
-            raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Systematic review DB not configured")
+
 
         sr = self.get_systematic_review(db_conn_str, sr_id)
         if not sr or not sr.get("visible", True):
@@ -317,7 +255,7 @@ def add_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_
 
         conn = None
         try:
-            conn = self._connect(db_conn_str)
+            conn = postgres_server.conn
             cur = conn.cursor()
 
             # Get current users array
@@ -343,14 +281,7 @@ def add_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_
             modified_count = cur.rowcount
             conn.commit()
 
-            try:
-                cur.close()
-            except Exception:
-                pass
-            try:
-                conn.close()
-            except Exception:
-                pass
+
             return {"matched_count": 1, "modified_count": modified_count, "added_user_id": target_user_id}
 
@@ -361,18 +292,14 @@ def add_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_
             raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to add user: {e}")
         finally:
             if conn:
-                try:
-                    conn.close()
-                except Exception:
-                    pass
+                pass
 
     def remove_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]:
         """
         Remove a user id from the SR's users list. Owner cannot be removed.
         Enforces requester permissions (must be a member or owner).
         """
-        if not db_conn_str:
-            raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Systematic review DB not configured")
+
 
         sr = self.get_systematic_review(db_conn_str, sr_id)
         if not sr or not sr.get("visible", True):
@@ -388,7 +315,7 @@ def remove_user(self, db_conn_str: str, sr_id: str, target_user_id: str, request
 
         conn = None
         try:
-            conn = self._connect(db_conn_str)
+            conn = postgres_server.conn
             cur = conn.cursor()
 
             # Get current users array
@@ -414,14 +341,7 @@ def remove_user(self, db_conn_str: str, sr_id: str, target_user_id: str, request
             modified_count = cur.rowcount
             conn.commit()
 
-            try:
-                cur.close()
-            except Exception:
-                pass
-            try:
-                conn.close()
-            except Exception:
-                pass
+
             return {"matched_count": 1, "modified_count": modified_count, "removed_user_id": target_user_id}
 
@@ -432,10 +352,7 @@ def remove_user(self, db_conn_str: str, sr_id: str, target_user_id: str, request
             raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to remove user: {e}")
         finally:
             if conn:
-                try:
-                    conn.close()
-                except Exception:
-                    pass
+                pass
 
     def user_has_sr_permission(self, db_conn_str: str, sr_id: str, user_id: str) -> bool:
         """
@@ -444,8 +361,7 @@ def user_has_sr_permission(self, db_conn_str: str, sr_id: str, user_id: str) ->
         Note: this check deliberately ignores the SR's 'visible' flag so membership
         checks work regardless of whether the SR is hidden/soft-deleted.
         """
-        if not db_conn_str:
-            raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Systematic review DB not configured")
+
 
         doc = self.get_systematic_review(db_conn_str, sr_id, ignore_visibility=True)
         if not doc:
@@ -462,8 +378,7 @@ def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str,
         The requester must be a member or owner.
         Returns the updated SR document.
         """
-        if not db_conn_str:
-            raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Systematic review DB not configured")
+
 
         sr = self.get_systematic_review(db_conn_str, sr_id)
         if not sr or not sr.get("visible", True):
@@ -476,7 +391,7 @@ def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str,
 
         conn = None
         try:
-            conn = self._connect(db_conn_str)
+            conn = postgres_server.conn
             cur = conn.cursor()
 
             updated_at = datetime.utcnow().isoformat()
@@ -501,14 +416,7 @@ def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str,
 
             conn.commit()
 
-            try:
-                cur.close()
-            except Exception:
-                pass
-            try:
-                conn.close()
-            except Exception:
-                pass
+
             # Return fresh doc
             doc = self.get_systematic_review(db_conn_str, sr_id)
@@ -521,21 +429,17 @@ def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str,
             raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update criteria: {e}")
         finally:
             if conn:
-                try:
-                    conn.close()
-                except Exception:
-                    pass
+                pass
 
     def list_systematic_reviews_for_user(self, db_conn_str: str, user_email: str) -> List[Dict[str, Any]]:
         """
         Return all SR documents where the user is a member (regardless of visible flag).
         """
-        if not db_conn_str:
-            raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Systematic review DB not configured")
+
 
         conn = None
         try:
-            conn = self._connect(db_conn_str)
+            conn = postgres_server.conn
             cur = conn.cursor()
 
             # Query using jsonb operator to check if user_email is in users array
@@ -569,14 +473,7 @@ def list_systematic_reviews_for_user(self, db_conn_str: str, user_email: str) ->
                 results.append(doc)
 
-            try:
-                cur.close()
-            except Exception:
-                pass
-            try:
-                conn.close()
-            except Exception:
-                pass
+
             return results
 
@@ -585,22 +482,18 @@ def list_systematic_reviews_for_user(self, db_conn_str: str, user_email: str) ->
             raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to list systematic reviews: {e}")
         finally:
             if conn:
-                try:
-                    conn.close()
-                except Exception:
-                    pass
+                pass
 
     def get_systematic_review(self, db_conn_str: str, sr_id: str, ignore_visibility: bool = False) -> Optional[Dict[str, Any]]:
         """
         Return SR document by id. Returns None if not found.
         If ignore_visibility is False, only returns visible SRs.
         """
-        if not db_conn_str:
-            raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Systematic review DB not configured")
+
 
         conn = None
         try:
-            conn = self._connect(db_conn_str)
+            conn = postgres_server.conn
             cur = conn.cursor()
 
             if ignore_visibility:
@@ -612,14 +505,6 @@ def get_systematic_review(self, db_conn_str: str, sr_id: str, ignore_visibility:
             row = cur.fetchone()
             if not row:
-                try:
-                    cur.close()
-                except Exception:
-                    pass
-                try:
-                    conn.close()
-                except Exception:
-                    pass
                 return None
 
             cols = [desc[0] for desc in cur.description]
@@ -639,14 +524,7 @@ def get_systematic_review(self, db_conn_str: str, sr_id: str, ignore_visibility:
             if doc.get('updated_at') and isinstance(doc['updated_at'], dt):
                 doc['updated_at'] = doc['updated_at'].isoformat()
 
-            try:
-                cur.close()
-            except Exception:
-                pass
-            try:
-                conn.close()
-            except Exception:
-                pass
+
             return doc
 
@@ -655,18 +533,14 @@ def get_systematic_review(self, db_conn_str: str, sr_id: str, ignore_visibility:
             return None
         finally:
             if conn:
-                try:
-                    conn.close()
-                except Exception:
-                    pass
+                pass
 
     def set_visibility(self, db_conn_str: str, sr_id: str, visible: bool, requester_id: str) -> Dict[str, Any]:
         """
         Set the visible flag on the SR. Only owner is allowed to change visibility.
         Returns update metadata.
         """
-        if not db_conn_str:
-            raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Systematic review DB not configured")
+
 
         sr = self.get_systematic_review(db_conn_str, sr_id, ignore_visibility=True)
         if not sr:
@@ -677,7 +551,7 @@ def set_visibility(self, db_conn_str: str, sr_id: str, visible: bool, requester_
 
         conn = None
         try:
-            conn = self._connect(db_conn_str)
+            conn = postgres_server.conn
             cur = conn.cursor()
 
             updated_at = datetime.utcnow().isoformat()
@@ -688,14 +562,7 @@ def set_visibility(self, db_conn_str: str, sr_id: str, visible: bool, requester_
             modified_count = cur.rowcount
             conn.commit()
 
-            try:
-                cur.close()
-            except Exception:
-                pass
-            try:
-                conn.close()
-            except Exception:
-                pass
+
             return {"matched_count": 1, "modified_count": modified_count, "visible": visible}
 
@@ -704,10 +571,7 @@ def set_visibility(self, db_conn_str: str, sr_id: str, visible: bool, requester_
             raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to set visibility: {e}")
         finally:
             if conn:
-                try:
-                    conn.close()
-                except Exception:
-                    pass
+                pass
 
     def soft_delete_systematic_review(self, db_conn_str: str, sr_id: str, requester_id: str) -> Dict[str, Any]:
         """
@@ -726,8 +590,7 @@ def hard_delete_systematic_review(self, db_conn_str: str, sr_id: str, requester_
         Permanently remove the SR document. Only owner may hard delete.
         Returns deletion metadata (deleted_count).
         """
-        if not db_conn_str:
-            raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Systematic review DB not configured")
+
 
         sr = self.get_systematic_review(db_conn_str, sr_id, ignore_visibility=True)
         if not sr:
@@ -738,22 +601,13 @@ def hard_delete_systematic_review(self, db_conn_str: str, sr_id: str, requester_
 
         conn = None
         try:
-            conn = self._connect(db_conn_str)
+            conn = postgres_server.conn
             cur = conn.cursor()
 
             cur.execute("DELETE FROM systematic_reviews WHERE id = %s", (sr_id,))
             deleted_count = cur.rowcount
             conn.commit()
-
-            try:
-                cur.close()
-            except Exception:
-                pass
-            try:
-                conn.close()
-            except Exception:
-                pass
-
+
             return {"deleted_count": deleted_count}
 
         except Exception as e:
@@ -761,22 +615,18 @@ def hard_delete_systematic_review(self, db_conn_str: str, sr_id: str, requester_
             raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to hard-delete systematic review: {e}")
         finally:
             if conn:
-                try:
-                    conn.close()
-                except Exception:
-                    pass
+                pass
 
     def update_screening_db_info(self, db_conn_str: str, sr_id: str, screening_db: Dict[str, Any]) -> None:
         """
         Update the screening_db field in the SR document with screening database metadata.
         """
-        if not db_conn_str:
-            raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Systematic review DB not configured")
+
 
         conn = None
         try:
-            conn = self._connect(db_conn_str)
+            conn = postgres_server.conn
             cur = conn.cursor()
 
             updated_at = datetime.utcnow().isoformat()
@@ -786,35 +636,24 @@ def update_screening_db_info(self, db_conn_str: str, sr_id: str, screening_db: D
             )
             conn.commit()
 
-            try:
-                cur.close()
-            except Exception:
-                pass
-            try:
-                conn.close()
-            except Exception:
-                pass
+
         except Exception as e:
             logger.exception(f"Failed to update screening DB info: {e}")
             raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update screening DB info: {e}")
         finally:
             if conn:
-                try:
-                    conn.close()
-                except Exception:
-                    pass
+                pass
""" - if not db_conn_str: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Systematic review DB not configured") + conn = None try: - conn = self._connect(db_conn_str) + conn = postgres_server.conn cur = conn.cursor() updated_at = datetime.utcnow().isoformat() @@ -824,24 +663,14 @@ def clear_screening_db_info(self, db_conn_str: str, sr_id: str) -> None: ) conn.commit() - try: - cur.close() - except Exception: - pass - try: - conn.close() - except Exception: - pass + except Exception as e: logger.exception(f"Failed to clear screening DB info: {e}") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to clear screening DB info: {e}") finally: if conn: - try: - conn.close() - except Exception: - pass + pass # module-level instance diff --git a/backend/api/services/storage.py b/backend/api/services/storage.py index 9c34ee29..b7595416 100644 --- a/backend/api/services/storage.py +++ b/backend/api/services/storage.py @@ -6,6 +6,7 @@ from datetime import datetime, timezone from typing import Dict, List, Optional, Any +from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient from azure.core.exceptions import ResourceNotFoundError @@ -19,12 +20,19 @@ class AzureStorageService: """Service for managing user data in Azure Blob Storage""" def __init__(self): - if not settings.AZURE_STORAGE_CONNECTION_STRING: - raise ValueError("Azure Storage connection string not configured") - - self.blob_service_client = BlobServiceClient.from_connection_string( - settings.AZURE_STORAGE_CONNECTION_STRING - ) + if not settings.AZURE_STORAGE_ACCOUNT_NAME and not settings.AZURE_STORAGE_CONNECTION_STRING: + raise ValueError("AZURE_STORAGE_ACCOUNT_NAME or AZURE_STORAGE_CONNECTION_STRING must be configured") + + if settings.AZURE_STORAGE_ACCOUNT_NAME: + account_url = f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net" + credential = DefaultAzureCredential() + self.blob_service_client = BlobServiceClient( + account_url=account_url, credential=credential + ) + elif settings.AZURE_STORAGE_CONNECTION_STRING: + self.blob_service_client = BlobServiceClient.from_connection_string( + settings.AZURE_STORAGE_CONNECTION_STRING + ) self.container_name = settings.AZURE_STORAGE_CONTAINER_NAME self._ensure_container_exists() @@ -354,5 +362,5 @@ async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> boo # Global storage service instance storage_service = ( - AzureStorageService() if settings.AZURE_STORAGE_CONNECTION_STRING else None + AzureStorageService() if settings.AZURE_STORAGE_ACCOUNT_NAME else None ) diff --git a/backend/api/services/user_db.py b/backend/api/services/user_db.py index d42a9a66..00300523 100644 --- a/backend/api/services/user_db.py +++ b/backend/api/services/user_db.py @@ -5,6 +5,7 @@ from datetime import datetime from typing import Dict, List, Optional, Any +from azure.identity import DefaultAzureCredential from azure.storage.blob import BlobServiceClient from azure.core.exceptions import ResourceNotFoundError from passlib.context import CryptContext @@ -17,12 +18,20 @@ class UserDatabaseService: """Service for managing user data in Azure Blob Storage""" def __init__(self): - if not settings.AZURE_STORAGE_CONNECTION_STRING: - raise ValueError("Azure Storage connection string not configured") + if not settings.AZURE_STORAGE_ACCOUNT_NAME and not settings.AZURE_STORAGE_CONNECTION_STRING: + raise ValueError("AZURE_STORAGE_ACCOUNT_NAME or AZURE_STORAGE_CONNECTION_STRING must 
diff --git a/backend/api/services/user_db.py b/backend/api/services/user_db.py
index d42a9a66..00300523 100644
--- a/backend/api/services/user_db.py
+++ b/backend/api/services/user_db.py
@@ -5,6 +5,7 @@
 from datetime import datetime
 from typing import Dict, List, Optional, Any
 
+from azure.identity import DefaultAzureCredential
 from azure.storage.blob import BlobServiceClient
 from azure.core.exceptions import ResourceNotFoundError
 from passlib.context import CryptContext
@@ -17,12 +18,20 @@ class UserDatabaseService:
     """Service for managing user data in Azure Blob Storage"""
 
     def __init__(self):
-        if not settings.AZURE_STORAGE_CONNECTION_STRING:
-            raise ValueError("Azure Storage connection string not configured")
+        if not settings.AZURE_STORAGE_ACCOUNT_NAME and not settings.AZURE_STORAGE_CONNECTION_STRING:
+            raise ValueError("AZURE_STORAGE_ACCOUNT_NAME or AZURE_STORAGE_CONNECTION_STRING must be configured")
+
+        if settings.AZURE_STORAGE_ACCOUNT_NAME:
+            account_url = f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net"
+            credential = DefaultAzureCredential()
+            self.blob_service_client = BlobServiceClient(
+                account_url=account_url, credential=credential
+            )
+        elif settings.AZURE_STORAGE_CONNECTION_STRING:
+            self.blob_service_client = BlobServiceClient.from_connection_string(
+                settings.AZURE_STORAGE_CONNECTION_STRING
+            )
 
-        self.blob_service_client = BlobServiceClient.from_connection_string(
-            settings.AZURE_STORAGE_CONNECTION_STRING
-        )
         self.container_name = settings.AZURE_STORAGE_CONTAINER_NAME
         self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -108,8 +117,8 @@ async def create_user(self, user_data: UserCreate) -> Optional[UserRead]:
         if await self._save_user_registry(registry):
             # Create user directory structure in storage
             from .storage import storage_service
-            from .dual_milvus_manager import dual_milvus_manager
-            from .base_knowledge_manager import base_knowledge_manager
+            # from .dual_milvus_manager import dual_milvus_manager
+            # from .base_knowledge_manager import base_knowledge_manager
 
             if storage_service:
                 await storage_service.create_user_directory(user_id)
@@ -245,7 +254,7 @@ async def get_user_count(self) -> int:
 
 # Global user database service instance
 user_db_service = (
-    UserDatabaseService() if settings.AZURE_STORAGE_CONNECTION_STRING else None
+    UserDatabaseService() if settings.AZURE_STORAGE_ACCOUNT_NAME else None
 )
 
 # Alias for backward compatibility
""" db_conn_str = _get_db_conn_str() - await run_in_threadpool(srdb_service.ensure_db_available, db_conn_str) user_id = current_user.get("email") results = [] diff --git a/backend/entrypoint.sh b/backend/entrypoint.sh new file mode 100644 index 00000000..b599c0a2 --- /dev/null +++ b/backend/entrypoint.sh @@ -0,0 +1,10 @@ +#!/bin/sh +set -e + +# Get env vars in the Dockerfile to show up in the SSH session +eval $(printenv | sed -n "s/^\([^=]\+\)=\(.*\)$/export \1=\2/p" | sed 's/"/\\\"/g' | sed '/=/s//="/' | sed 's/$/"/' >> /etc/profile) + +echo "Starting SSH ..." +service ssh start + +exec "$@" \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index bd2c8c78..e89158e2 100644 --- a/backend/main.py +++ b/backend/main.py @@ -31,11 +31,17 @@ async def startup_event(): print("📚 Initializing systematic review database...", flush=True) # Ensure systematic review table exists in PostgreSQL try: - if settings.POSTGRES_URI: - await run_in_threadpool(srdb_service.ensure_table_exists, settings.POSTGRES_URI) + # Check if Entra ID or POSTGRES_URI is configured + has_entra_config = settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER + has_uri_config = settings.POSTGRES_URI + + if has_entra_config or has_uri_config: + # Pass connection string if available, otherwise None for Entra ID auth + db_conn_str = settings.POSTGRES_URI if has_uri_config else None + await run_in_threadpool(srdb_service.ensure_table_exists, db_conn_str) print("✓ Systematic review table initialized", flush=True) else: - print("⚠️ POSTGRES_URI not configured - skipping SR table initialization", flush=True) + print("⚠️ PostgreSQL not configured - skipping SR table initialization", flush=True) except Exception as e: print(f"⚠️ Failed to ensure SR table exists: {e}", flush=True) print("🎯 CAN-SR Backend ready!", flush=True) diff --git a/backend/sshd_config b/backend/sshd_config new file mode 100644 index 00000000..9c224d52 --- /dev/null +++ b/backend/sshd_config @@ -0,0 +1,12 @@ +Port 2222 +ListenAddress 0.0.0.0 +LoginGraceTime 180 +X11Forwarding yes +Ciphers aes128-cbc,3des-cbc,aes256-cbc,aes128-ctr,aes192-ctr,aes256-ctr +MACs hmac-sha1,hmac-sha1-96 +StrictModes yes +SyslogFacility DAEMON +PasswordAuthentication yes +PermitEmptyPasswords no +PermitRootLogin yes +Subsystem sftp internal-sftp \ No newline at end of file diff --git a/frontend/Dockerfile b/frontend/Dockerfile index 7d59e395..03f48d16 100644 --- a/frontend/Dockerfile +++ b/frontend/Dockerfile @@ -16,7 +16,7 @@ RUN npm ci FROM base AS build # Public build-time environment variables -ARG NEXT_PUBLIC_BACKEND_URL="https://grep-exp-can-sr-api-dv.phac-aspc.gc.ca" +ARG NEXT_PUBLIC_BACKEND_URL="https://was-sdse-spib-hail-api-dt.azurewebsites.net" ENV NEXT_PUBLIC_BACKEND_URL=${NEXT_PUBLIC_BACKEND_URL} @@ -31,8 +31,8 @@ FROM base AS run ENV NODE_ENV=production ENV PORT=$PORT -RUN addgroup --system --gid 1001 nodejs -RUN adduser --system --uid 1001 nextjs +RUN addgroup --gid 1001 nodejs +RUN adduser --uid 1001 --ingroup nodejs --disabled-password --gecos "" nextjs RUN mkdir .next RUN chown nextjs:nodejs .next diff --git a/frontend/next.config.ts b/frontend/next.config.ts index 9f367351..eb104311 100644 --- a/frontend/next.config.ts +++ b/frontend/next.config.ts @@ -2,10 +2,7 @@ import type { NextConfig } from 'next' const nextConfig: NextConfig = { /* config options here */ - eslint: { - // Disable ESLint during builds for production deployment - ignoreDuringBuilds: true, - }, + output: 'standalone', typescript: { // Disable TypeScript 
diff --git a/frontend/next.config.ts b/frontend/next.config.ts
index 9f367351..eb104311 100644
--- a/frontend/next.config.ts
+++ b/frontend/next.config.ts
@@ -2,10 +2,7 @@ import type { NextConfig } from 'next'
 
 const nextConfig: NextConfig = {
   /* config options here */
-  eslint: {
-    // Disable ESLint during builds for production deployment
-    ignoreDuringBuilds: true,
-  },
+  output: 'standalone',
   typescript: {
     // Disable TypeScript errors during builds for production deployment
     ignoreBuildErrors: true,
@@ -18,6 +15,7 @@ const nextConfig: NextConfig = {
     }
     return config
   },
+  turbopack: {},
 }
 
 export default nextConfig
diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json
index d8b93235..e7ff3a26 100644
--- a/frontend/tsconfig.json
+++ b/frontend/tsconfig.json
@@ -1,7 +1,11 @@
 {
   "compilerOptions": {
     "target": "ES2017",
-    "lib": ["dom", "dom.iterable", "esnext"],
+    "lib": [
+      "dom",
+      "dom.iterable",
+      "esnext"
+    ],
     "allowJs": true,
     "skipLibCheck": true,
     "strict": true,
@@ -11,7 +15,7 @@
     "moduleResolution": "bundler",
     "resolveJsonModule": true,
     "isolatedModules": true,
-    "jsx": "preserve",
+    "jsx": "react-jsx",
     "incremental": true,
     "plugins": [
       {
@@ -19,9 +23,19 @@
       }
     ],
     "paths": {
-      "@/*": ["./*"]
+      "@/*": [
+        "./*"
+      ]
     }
   },
-  "include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
-  "exclude": ["node_modules"]
+  "include": [
+    "next-env.d.ts",
+    "**/*.ts",
+    "**/*.tsx",
+    ".next/types/**/*.ts",
+    ".next/dev/types/**/*.ts"
+  ],
+  "exclude": [
+    "node_modules"
+  ]
 }