diff --git a/.gitignore b/.gitignore index df559c58..c5b3239e 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ yarn-debug.log* yarn-error.log* lerna-debug.log* .pnpm-debug.log* +/planning # Diagnostic reports (https://nodejs.org/api/report.html) report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json diff --git a/DEPLOY.md b/DEPLOY.md index 139c9f5b..c74ca383 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -298,7 +298,7 @@ AZURE_OPENAI_GPT4O_MINI_ENDPOINT=your-gpt4o-mini-endpoint # Storage AZURE_STORAGE_CONNECTION_STRING=your-connection-string -AZURE_STORAGE_CONTAINER_NAME=can-sr-storage +STORAGE_CONTAINER_NAME=can-sr-storage # Authentication SECRET_KEY=your-secret-key-change-in-production @@ -306,7 +306,11 @@ ACCESS_TOKEN_EXPIRE_MINUTES=10080 # Databases (configured in docker-compose.yml) MONGODB_URI=mongodb://sr-mongodb-service:27017/mongodb-sr -POSTGRES_URI=postgres://admin:password@cit-pgdb-service:5432/postgres-cits +POSTGRES_MODE=docker +POSTGRES_HOST=pgdb-service +POSTGRES_DATABASE=postgres +POSTGRES_USER=admin +POSTGRES_PASSWORD=password # Databricks (for database search) DATABRICKS_INSTANCE=your-databricks-instance diff --git a/README.md b/README.md index 11a6f505..7273c941 100644 --- a/README.md +++ b/README.md @@ -221,11 +221,53 @@ AZURE_OPENAI_ENDPOINT=your-endpoint AZURE_OPENAI_DEPLOYMENT_NAME=gpt-4o # Storage -AZURE_STORAGE_CONNECTION_STRING=your-connection-string +# STORAGE_MODE is strict: local | azure | entra +STORAGE_MODE=local + +# Storage container name +# - local: folder name under LOCAL_STORAGE_BASE_PATH +# - azure/entra: blob container name +STORAGE_CONTAINER_NAME=can-sr-storage + +# local storage +LOCAL_STORAGE_BASE_PATH=uploads + +# azure storage (account name + key) +# STORAGE_MODE=azure +AZURE_STORAGE_ACCOUNT_NAME=youraccount +AZURE_STORAGE_ACCOUNT_KEY=your-key + +# entra storage (Managed Identity / DefaultAzureCredential) +# STORAGE_MODE=entra +AZURE_STORAGE_ACCOUNT_NAME=youraccount # Databases -MONGODB_URI=mongodb://localhost:27017/mongodb-sr -POSTGRES_URI=postgres://admin:password@localhost:5432/postgres-cits +MONGODB_URI=mongodb://sr-mongodb-service:27017/mongodb-sr + +# Postgres configuration +POSTGRES_MODE=docker # docker | local | azure + +# Canonical Postgres connection settings (single set) +# - docker/local: POSTGRES_PASSWORD is required +# - azure: POSTGRES_PASSWORD is ignored (Entra token auth via DefaultAzureCredential) +POSTGRES_HOST=pgdb-service +POSTGRES_DATABASE=postgres +POSTGRES_USER=admin +POSTGRES_PASSWORD=password + +# Local Postgres (developer machine) +# POSTGRES_MODE=local +# POSTGRES_HOST=localhost +# POSTGRES_DATABASE=grep +# POSTGRES_USER=postgres +# POSTGRES_PASSWORD=123 + +# Azure Database for PostgreSQL (Entra auth) +# POSTGRES_MODE=azure +# POSTGRES_HOST=...postgres.database.azure.com +# POSTGRES_DATABASE=grep +# POSTGRES_USER= +# POSTGRES_PASSWORD= # not used in azure mode # Databricks (for database search) DATABRICKS_INSTANCE=your-instance diff --git a/backend/Dockerfile b/backend/Dockerfile index 6a01e65f..8dd10676 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -9,7 +9,7 @@ WORKDIR /app RUN apt-get clean && \ rm -rf /var/lib/apt/lists/* && \ apt-get update && \ - apt-get install -y --no-install-recommends \ + apt-get install -y \ gcc \ g++ \ git \ @@ -17,6 +17,8 @@ RUN apt-get clean && \ curl \ make \ libc6-dev \ + dialog \ + openssh-server \ && apt-get clean && \ rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/* @@ -26,7 +28,7 @@ ENV PIP_DEFAULT_TIMEOUT=300 # Install Python dependencies COPY requirements.txt . 
-RUN pip install -r requirements.txt +RUN pip install -r requirements.txt --no-cache-dir # Copy application code COPY . . @@ -41,10 +43,7 @@ RUN useradd -m -u 1001 appuser && chown -R appuser:appuser /app COPY sshd_config /etc/ssh/ COPY entrypoint.sh /entrypoint.sh -RUN apt-get update \ - && apt-get install -y --no-install-recommends dialog \ - && apt-get install -y --no-install-recommends openssh-server \ - && echo "root:Docker!" | chpasswd \ +RUN echo "root:Docker!" | chpasswd \ && chmod u+x /entrypoint.sh USER root diff --git a/backend/README.md b/backend/README.md index 1af57f14..7a88eca9 100644 --- a/backend/README.md +++ b/backend/README.md @@ -15,7 +15,7 @@ CAN-SR Backend provides a production-ready REST API for managing systematic revi - **PDF Processing** - Full-text extraction using GROBID - **Azure OpenAI Integration** - GPT-4o, GPT-4o-mini, GPT-3.5-turbo for AI features - **JWT Authentication** - Secure user authentication -- **Azure Blob Storage** - Scalable document storage +- **Storage** - Local filesystem or Azure Blob Storage (account key or Entra) ## Architecture @@ -51,12 +51,54 @@ AZURE_OPENAI_API_KEY=your-azure-openai-api-key AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com AZURE_OPENAI_DEPLOYMENT_NAME=gpt-4o -# Azure Storage (Required) -AZURE_STORAGE_CONNECTION_STRING=DefaultEndpointsProtocol=https;AccountName=... +# Storage +# STORAGE_MODE is strict: local | azure | entra +STORAGE_MODE=local + +# Storage container name +# - local: folder name under LOCAL_STORAGE_BASE_PATH +# - azure/entra: blob container name +STORAGE_CONTAINER_NAME=can-sr-storage + +# local storage +LOCAL_STORAGE_BASE_PATH=uploads + +# azure storage (account name + key) +# STORAGE_MODE=azure +AZURE_STORAGE_ACCOUNT_NAME=youraccount +AZURE_STORAGE_ACCOUNT_KEY=your-key + +# entra storage (Managed Identity / DefaultAzureCredential) +# STORAGE_MODE=entra +AZURE_STORAGE_ACCOUNT_NAME=youraccount # Databases (Docker defaults - change for production) -MONGODB_URI=mongodb://sr-mongodb-service:27017/mongodb-sr -POSTGRES_URI=postgres://admin:password@cit-pgdb-service:5432/postgres-cits + + +# Postgres configuration +POSTGRES_MODE=docker # docker | local | azure + +# Canonical Postgres connection settings (single set) +# - docker/local: POSTGRES_PASSWORD is required +# - azure: POSTGRES_PASSWORD is ignored (Entra token auth via DefaultAzureCredential) +POSTGRES_HOST=pgdb-service +POSTGRES_DATABASE=postgres +POSTGRES_USER=admin +POSTGRES_PASSWORD=password + +# Local Postgres (developer machine) +# POSTGRES_MODE=local +# POSTGRES_HOST=localhost +# POSTGRES_DATABASE=grep +# POSTGRES_USER=postgres +# POSTGRES_PASSWORD=123 + +# Azure Database for PostgreSQL (Entra auth) +# POSTGRES_MODE=azure +# POSTGRES_HOST= +# POSTGRES_DATABASE= +# POSTGRES_USER= +# POSTGRES_PASSWORD= # not used in azure mode # GROBID Service GROBID_SERVICE_URL=http://grobid-service:8070 @@ -215,14 +257,21 @@ docker compose restart api | `AZURE_OPENAI_API_KEY` | Azure OpenAI API key | `abc123...` | | `AZURE_OPENAI_ENDPOINT` | Azure OpenAI endpoint URL | `https://your-resource.openai.azure.com` | | `AZURE_OPENAI_DEPLOYMENT_NAME` | Model deployment name | `gpt-4o` | -| `AZURE_STORAGE_CONNECTION_STRING` | Azure Blob Storage connection | `DefaultEndpointsProtocol=https;...` | +| `STORAGE_MODE` | Storage backend selector | `local` | +| `LOCAL_STORAGE_BASE_PATH` | Local storage base path (when local) | `uploads` | +| `AZURE_STORAGE_ACCOUNT_KEY` | Azure Storage account key (when STORAGE_MODE=azure) | `your-key` | +|
`AZURE_STORAGE_ACCOUNT_NAME` | Azure Storage account name (when STORAGE_MODE=azure or entra) | `mystorageacct` | | `SECRET_KEY` | JWT token signing key | `your-secure-secret-key` | ### Optional Variables | Variable | Description | Default | |----------|-------------|---------| | `MONGODB_URI` | MongoDB connection string | `mongodb://sr-mongodb-service:27017/mongodb-sr` | -| `POSTGRES_URI` | PostgreSQL connection string | `postgres://admin:password@cit-pgdb-service:5432/postgres-cits` | +| `POSTGRES_MODE` | Postgres connection mode: `docker` \| `local` \| `azure` | `docker` | +| `POSTGRES_HOST` | Postgres host (docker: service name; local: localhost; azure: FQDN) | `pgdb-service` | +| `POSTGRES_DATABASE` | Postgres database name | `postgres` | +| `POSTGRES_USER` | Postgres user (azure: Entra UPN or role) | `admin` | +| `POSTGRES_PASSWORD` | Postgres password (ignored when POSTGRES_MODE=azure) | `password` | | `GROBID_SERVICE_URL` | GROBID service URL | `http://grobid-service:8070` | | `DATABRICKS_INSTANCE` | Databricks workspace URL | - | | `DATABRICKS_TOKEN` | Databricks access token | - | diff --git a/backend/api/citations/router.py b/backend/api/citations/router.py index 247a2859..9c1e0ff9 100644 --- a/backend/api/citations/router.py +++ b/backend/api/citations/router.py @@ -6,7 +6,7 @@ - All endpoints require a Systematic Review (sr_id) that the user is a member of. - Uploading a CSV will: - Parse the CSV - - Create a new Postgres *table* in the shared database (POSTGRES_URI) + - Create a new Postgres *table* in the shared database - Insert citation rows from the CSV into that table - Save the table name + connection string into the Systematic Review record @@ -37,32 +37,25 @@ router = APIRouter() -def _get_db_conn_str() -> Optional[str]: +def _is_postgres_configured() -> bool: """ - Get database connection string for PostgreSQL. - - If POSTGRES_URI is set, returns it directly (local development). - If Entra ID env variables are configured (POSTGRES_HOST, POSTGRES_DATABASE, POSTGRES_USER), - returns None to signal that connect_postgres() should use Entra ID authentication. + Check if PostgreSQL is configured via the POSTGRES_MODE profile. """ - if settings.POSTGRES_URI: - return settings.POSTGRES_URI - - # If Entra ID config is available, return None to let connect_postgres use token auth - if settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER: - return None - - # No configuration available - return None, let downstream handle the error - return None + try: + prof = settings.postgres_profile() + except Exception: + return False + if not (prof.get("database") and prof.get("user")): + return False -def _is_postgres_configured() -> bool: - """ - Check if PostgreSQL is configured via Entra ID env vars or connection string.
- """ - has_entra_config = settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER - has_uri_config = settings.POSTGRES_URI - return bool(has_entra_config or has_uri_config) + if prof.get("mode") in ("local", "docker") and not prof.get("password"): + return False + + if prof.get("mode") == "azure" and not prof.get("host"): + return False + + return True class UploadResult(BaseModel): @@ -84,8 +77,8 @@ def _parse_dsn(dsn: str) -> Dict[str, str]: return parse_dsn(dsn) except Exception: return {} -def _create_table_and_insert_sync(db_conn_str: str, table_name: str, columns: List[str], rows: List[Dict[str, Any]]) -> int: - return cits_dp_service.create_table_and_insert_sync(db_conn_str, table_name, columns, rows) +def _create_table_and_insert_sync(table_name: str, columns: List[str], rows: List[Dict[str, Any]]) -> int: + return cits_dp_service.create_table_and_insert_sync(table_name, columns, rows) @router.post("/{sr_id}/upload-csv", response_model=UploadResult) @@ -98,25 +91,23 @@ async def upload_screening_csv( Upload a CSV of citations for title/abstract screening and create a dedicated Postgres table. Requirements: - - POSTGRES_URI must be configured. + - Postgres must be configured via POSTGRES_MODE and POSTGRES_* env vars. - The SR must exist and the user must be a member of the SR (or owner). """ - db_conn_str = _get_db_conn_str() try: - sr, screening, _ = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") - # Check admin DSN (use centralized settings) - need either Entra ID config or POSTGRES_URI + # Check admin config (use centralized settings) if not _is_postgres_configured(): raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Postgres not configured. Set POSTGRES_HOST/DATABASE/USER for Entra ID auth, or POSTGRES_URI for local dev.", + detail="Postgres not configured. 
Set POSTGRES_MODE and POSTGRES_* env vars.", ) - admin_dsn = _get_db_conn_str() # Read CSV content include_columns = None @@ -141,14 +132,14 @@ async def upload_screening_csv( try: old = (sr.get("screening_db") or {}).get("table_name") if old: - await run_in_threadpool(cits_dp_service.drop_table, admin_dsn, old) + await run_in_threadpool(cits_dp_service.drop_table, old) except Exception: # best-effort only pass # Create table and insert rows in threadpool try: - inserted = await run_in_threadpool(_create_table_and_insert_sync, admin_dsn, table_name, include_columns, normalized_rows) + inserted = await run_in_threadpool(_create_table_and_insert_sync, table_name, include_columns, normalized_rows) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -158,7 +149,6 @@ async def upload_screening_csv( try: screening_info = { "screening_db": { - "connection_string": admin_dsn, "table_name": table_name, "created_at": datetime.utcnow().isoformat(), "rows": inserted, @@ -168,7 +158,6 @@ async def upload_screening_csv( # Update SR document with screening DB info using PostgreSQL await run_in_threadpool( srdb_service.update_screening_db_info, - _get_db_conn_str(), sr_id, screening_info["screening_db"] ) @@ -198,9 +187,8 @@ async def list_citation_ids( Returns a simple list of integers (the 'id' primary key from the citations table). """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -212,7 +200,7 @@ async def list_citation_ids( table_name = (screening or {}).get("table_name") or "citations" try: - ids = await run_in_threadpool(cits_dp_service.list_citation_ids, db_conn, filter_step, table_name) + ids = await run_in_threadpool(cits_dp_service.list_citation_ids, filter_step, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -237,9 +225,8 @@ async def get_citation_by_id( Returns: a JSON object representing the citation row (keys are DB column names). """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -248,7 +235,7 @@ async def get_citation_by_id( table_name = (screening or {}).get("table_name") or "citations" try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -290,9 +277,8 @@ async def build_combined_citation( the format ": \\n" for each included column, in the order provided. 
""" - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -304,7 +290,7 @@ async def build_combined_citation( table_name = (screening or {}).get("table_name") or "citations" try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -345,9 +331,8 @@ async def upload_citation_fulltext( to the storage path (container/blob). """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -368,7 +353,7 @@ async def upload_citation_fulltext( table_name = (screening or {}).get("table_name") or "citations" try: - existing_row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + existing_row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -403,7 +388,7 @@ async def upload_citation_fulltext( # Update citation row in Postgres try: - updated = await run_in_threadpool(cits_dp_service.attach_fulltext, db_conn, citation_id, storage_path, content, table_name) + updated = await run_in_threadpool(cits_dp_service.attach_fulltext, citation_id, storage_path, content, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -441,9 +426,8 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An - Caller must be the SR owner. 
""" - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -456,15 +440,13 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An if not screening: return {"status": "no_screening_db", "message": "No screening table configured for this SR", "deleted_table": False, "deleted_files": 0} - db_conn = None - table_name = screening.get("table_name") if not table_name: return {"status": "no_screening_db", "message": "Incomplete screening DB metadata", "deleted_table": False, "deleted_files": 0} # 1) collect fulltext URLs from the screening DB try: - urls = await run_in_threadpool(cits_dp_service.list_fulltext_urls, db_conn, table_name) + urls = await run_in_threadpool(cits_dp_service.list_fulltext_urls, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -519,12 +501,13 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An parsed_ok = False if not parsed_ok: - # fallback: try direct deletion via blob client (best-effort) + # fallback: delete by storage path (works for both azure/local) if storage_service: try: - blob_client = storage_service.blob_service_client.get_blob_client(container=container, blob=blob) - blob_client.delete_blob() + await storage_service.delete_by_path(f"{container}/{blob}") deleted_files += 1 + except FileNotFoundError: + failed_files += 1 except Exception: failed_files += 1 else: @@ -535,7 +518,7 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An # 3) drop the screening table try: - await run_in_threadpool(cits_dp_service.drop_table, db_conn, table_name) + await run_in_threadpool(cits_dp_service.drop_table, table_name) table_dropped = True except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) @@ -546,7 +529,6 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An try: await run_in_threadpool( srdb_service.clear_screening_db_info, - _get_db_conn_str(), sr_id ) except Exception: @@ -582,10 +564,9 @@ async def export_citations_csv( Content-Disposition. 
""" - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check( - sr_id, current_user, db_conn_str, srdb_service + sr, screening = await load_sr_and_check( + sr_id, current_user, srdb_service ) except HTTPException: raise @@ -595,7 +576,7 @@ async def export_citations_csv( detail=f"Failed to load systematic review: {e}", ) - if not screening or not db_conn: + if not screening: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="No screening database configured for this systematic review", @@ -604,7 +585,7 @@ async def export_citations_csv( table_name = (screening or {}).get("table_name") or "citations" try: - csv_bytes = await run_in_threadpool(cits_dp_service.dump_citations_csv, db_conn, table_name) + csv_bytes = await run_in_threadpool(cits_dp_service.dump_citations_csv, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: diff --git a/backend/api/core/bounding_box_matcher.py b/backend/api/core/bounding_box_matcher.py new file mode 100644 index 00000000..a0fe5e4d --- /dev/null +++ b/backend/api/core/bounding_box_matcher.py @@ -0,0 +1,435 @@ +import re +from typing import Dict, Any, List, Optional, Tuple + + +def normalize_text(text: str) -> str: + """Normalize text for comparison (remove extra whitespace, newlines, etc.)""" + # Replace multiple whitespace/newlines with single space + text = re.sub(r"\s+", " ", text) + # Remove special characters that might differ + text = text.strip() + return text + + +def find_text_in_paragraphs( + search_text: str, paragraphs: List[Dict[str, Any]], threshold: float = 0.6 +) -> List[Dict[str, Any]]: + """ + Find matching paragraphs for the search text. + + Args: + search_text: Text to search for + paragraphs: List of paragraph dictionaries from raw_analysis.json + threshold: Minimum similarity threshold (0-1) + + Returns: + List of matching paragraphs with bounding box information + """ + normalized_search = normalize_text(search_text) + matches = [] + + for para in paragraphs: + para_content = para.get("content", "") + normalized_para = normalize_text(para_content) + + # Check for substring match (search text in paragraph or paragraph in search) + if normalized_search in normalized_para or normalized_para in normalized_search: + # Calculate similarity using character overlap + search_chars = set(normalized_search.lower()) + para_chars = set(normalized_para.lower()) + overlap = len(search_chars & para_chars) + similarity = overlap / max(len(search_chars), len(para_chars), 1) + + if similarity >= threshold: + match_info = { + "paragraph_content": para_content, + "similarity": similarity, + "bounding_regions": para.get("boundingRegions", []), + "role": para.get("role"), + "spans": para.get("spans", []), + } + matches.append(match_info) + # Also check for partial matches (if search text is long, check if significant portion matches) + elif len(normalized_search) > 20: + # Try to find if a significant portion of the search text appears in the paragraph + search_words = set(normalized_search.lower().split()) + para_words = set(normalized_para.lower().split()) + if len(search_words) > 0: + word_overlap = len(search_words & para_words) + word_similarity = word_overlap / len(search_words) + if word_similarity >= 0.5: # At least 50% of words match + match_info = { + "paragraph_content": para_content, + "similarity": word_similarity, + "bounding_regions": para.get("boundingRegions", []), + "role": para.get("role"), + 
"spans": para.get("spans", []), + } + matches.append(match_info) + + return matches + + +def find_text_in_lines( + search_text: str, pages: List[Dict[str, Any]], threshold: float = 0.5 +) -> List[Dict[str, Any]]: + """ + Find matching lines for the search text. + + Args: + search_text: Text to search for + pages: List of page dictionaries from raw_analysis.json + threshold: Minimum similarity threshold + + Returns: + List of matching lines with bounding box information + """ + normalized_search = normalize_text(search_text) + matches = [] + + for page in pages: + page_number = page.get("pageNumber", 0) + lines = page.get("lines", []) + + for line in lines: + line_content = line.get("content", "") + normalized_line = normalize_text(line_content) + + # Check for substring match + if ( + normalized_search in normalized_line + or normalized_line in normalized_search + ): + search_chars = set(normalized_search.lower()) + line_chars = set(normalized_line.lower()) + overlap = len(search_chars & line_chars) + similarity = overlap / max(len(search_chars), len(line_chars), 1) + + if similarity >= threshold: + match_info = { + "line_content": line_content, + "page_number": page_number, + "similarity": similarity, + "polygon": line.get("polygon", []), + "spans": line.get("spans", []), + } + matches.append(match_info) + # Also check word-level matching for longer texts + elif len(normalized_search) > 15: + search_words = set(normalized_search.lower().split()) + line_words = set(normalized_line.lower().split()) + if len(search_words) > 0: + word_overlap = len(search_words & line_words) + word_similarity = word_overlap / len(search_words) + if word_similarity >= 0.4: + match_info = { + "line_content": line_content, + "page_number": page_number, + "similarity": word_similarity, + "polygon": line.get("polygon", []), + "spans": line.get("spans", []), + } + matches.append(match_info) + + return matches + + +def match_reference_to_bounding_box( + reference_text: str, + raw_analysis: Dict[str, Any], + para_threshold: float = 0.6, + line_threshold: float = 0.5, +) -> Dict[str, Any]: + """ + Match a single reference text to bounding boxes in the raw analysis. 
+ + Args: + reference_text: The text excerpt to match + raw_analysis: The complete raw analysis dictionary from Azure Document Intelligence + para_threshold: Minimum similarity threshold for paragraph matching + line_threshold: Minimum similarity threshold for line matching + + Returns: + Dictionary with matching information including bounding boxes + """ + paragraphs = raw_analysis.get("paragraphs", []) + pages = raw_analysis.get("pages", []) + + # Try to find in paragraphs first (more accurate) + para_matches = find_text_in_paragraphs(reference_text, paragraphs, para_threshold) + + # Also try to find in lines (for more granular matching) + line_matches = find_text_in_lines(reference_text, pages, line_threshold) + + result = { + "text": reference_text, + "paragraph_matches": para_matches, + "line_matches": line_matches, + "best_match": None, + } + + # Determine best match (prefer paragraph matches with highest similarity) + if para_matches: + best_para = max(para_matches, key=lambda x: x["similarity"]) + bounding_regions = best_para["bounding_regions"] + # Format bounding box info + bbox_info = [] + page_number = None + for region in bounding_regions: + page_num = region.get("pageNumber", region.get("page_number")) + if page_num and not page_number: + page_number = page_num + bbox_info.append( + {"page_number": page_num, "polygon": region.get("polygon", [])} + ) + + result["best_match"] = { + "type": "paragraph", + "content": best_para["paragraph_content"], + "similarity": best_para["similarity"], + "page_number": page_number, # Add page_number at top level for easy access + "bounding_regions": bbox_info, + "role": best_para.get("role"), + } + elif line_matches: + best_line = max(line_matches, key=lambda x: x["similarity"]) + result["best_match"] = { + "type": "line", + "content": best_line["line_content"], + "similarity": best_line["similarity"], + "page_number": best_line["page_number"], + "polygon": best_line["polygon"], + } + + return result + + +def match_references_to_bounding_boxes( + references: List[Dict[str, Any]], + raw_analysis: Dict[str, Any], + para_threshold: float = 0.6, + line_threshold: float = 0.5, +) -> List[Dict[str, Any]]: + """ + Match multiple reference texts to bounding boxes in the raw analysis. + + Args: + references: List of reference dictionaries, each with at least a 'text' field + raw_analysis: The complete raw analysis dictionary from Azure Document Intelligence + para_threshold: Minimum similarity threshold for paragraph matching + line_threshold: Minimum similarity threshold for line matching + + Returns: + List of dictionaries with matching information for each reference + """ + results = [] + + for ref_idx, ref in enumerate(references): + ref_text = ref.get("text", "") + if not ref_text: + continue + + match_result = match_reference_to_bounding_box( + ref_text, raw_analysis, para_threshold, line_threshold + ) + + # Add reference index and preserve any additional fields from the original reference + match_result["reference_index"] = ref_idx + if "context" in ref: + match_result["context"] = ref["context"] + + results.append(match_result) + + return results + + +def extract_figure_references(text: str) -> List[Tuple[str, str]]: + """ + Extract figure references from text using regex patterns. + + Args: + text: Text to search for figure references + + Returns: + List of tuples (figure_reference, figure_id) like [("Figure 1.1", "1.1"), ("Fig. 
2.3", "2.3")] + """ + # Patterns to match figure references + patterns = [ + r"\bFigure\s+(\d+(?:\.\d+)*)", # "Figure 1.1", "Figure 2" + r"\bFig\.?\s+(\d+(?:\.\d+)*)", # "Fig 1.1", "Fig. 2.3" + r"\bFIGURE\s+(\d+(?:\.\d+)*)", # "FIGURE 1.1" + r"\bFIG\.?\s+(\d+(?:\.\d+)*)", # "FIG 1.1" + ] + + figure_refs = [] + for pattern in patterns: + matches = re.findall(pattern, text, re.IGNORECASE) + for match in matches: + # Create the full reference text and the figure ID + if "Figure" in pattern.upper(): + full_ref = f"Figure {match}" + elif "Fig." in pattern: + full_ref = f"Fig. {match}" + elif "FIG." in pattern: + full_ref = f"FIG. {match}" + else: + full_ref = f"Fig {match}" + figure_refs.append((full_ref, match)) + + return figure_refs + + +def find_figure_by_id( + figure_id: str, figures: List[Dict[str, Any]] +) -> Optional[Dict[str, Any]]: + """ + Find a figure by its ID in the figures list. + Includes fallback logic for mismatched figure numbering between document text and Azure Doc Intelligence. + + Args: + figure_id: Figure ID to search for (e.g., "1.1", "2.3", or just "1", "2") + figures: List of figure dictionaries + + Returns: + Figure dictionary if found, None otherwise + """ + # First try exact match + for figure in figures: + if figure.get("id") == figure_id: + return figure + + # If no exact match, try fallback strategies + # Strategy 1: If figure_id is just a number (like "1"), try to find figures that start with that number + if figure_id.isdigit(): + target_num = int(figure_id) + for figure in figures: + fig_id = figure.get("id", "") + # Try patterns like "1.1", "1.2", "2.1", etc. + if fig_id.startswith(f"{target_num}.") or fig_id == str(target_num): + print(f"[FIGURE_MATCHING] Fallback match: '{figure_id}' -> '{fig_id}'") + return figure + + # Strategy 2: Sequential mapping - Azure often assigns IDs like "1.1", "2.1", "3.1" sequentially + # If document says "Fig. 1" and "Fig. 2", map them to the first N figures in order + if figure_id.isdigit(): + target_num = int(figure_id) + # Sort figures by their ID to get sequential order + sorted_figures = sorted(figures, key=lambda f: f.get("id", "")) + + # If we have fewer figures than the target number, try direct indexing + if target_num <= len(sorted_figures): + selected_figure = sorted_figures[target_num - 1] # 1-indexed to 0-indexed + print( + f"[FIGURE_MATCHING] Sequential mapping: '{figure_id}' -> figure at index {target_num - 1} (ID: {selected_figure.get('id')})" + ) + return selected_figure + + # Strategy 3: Try to match by page proximity and caption similarity + # For documents where figure numbering doesn't match, try to find figures on the same page + # This is a more advanced strategy that could be implemented later + + print( + f"[FIGURE_MATCHING] No figure found for ID: '{figure_id}' (checked {len(figures)} figures)" + ) + return None + + +def match_figure_references_to_bounding_boxes( + references: List[Dict[str, Any]], + raw_analysis: Dict[str, Any], + figures: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: + """ + Match figure references to figure bounding boxes in addition to text references. + Enhanced to ensure all references get bounding boxes. 
+ + Args: + references: List of reference dictionaries from LLM responses + raw_analysis: Raw analysis from Azure Document Intelligence + figures: List of figure metadata dictionaries + + Returns: + Updated references list with both text and figure bounding box matches + """ + print( + f"[FIGURE_MATCHING] Starting enhanced figure reference matching with {len(figures)} figures" + ) + figure_ids = [f.get("id", "unknown") for f in figures] + print(f"[FIGURE_MATCHING] Available figure IDs: {figure_ids}") + + # First, run text matching on ALL references to ensure they all get bounding boxes + text_matched_refs = match_references_to_bounding_boxes(references, raw_analysis) + print( + f"[FIGURE_MATCHING] Text matching completed for {len(text_matched_refs)} references" + ) + + # Now enhance references that contain figure references with figure information + enhanced_references = [] + + for ref in text_matched_refs: + ref_text = ref.get("text", "") + figure_refs = extract_figure_references(ref_text) + + print(f"[FIGURE_MATCHING] Processing reference: '{ref_text[:100]}...'") + print(f"[FIGURE_MATCHING] Extracted figure refs: {figure_refs}") + + if figure_refs: + # This reference contains figure references + # Enhance the existing text-matched reference with figure information + enhanced_ref = dict(ref) # Copy the text-matched reference + + # Add figure information to the best_match + if ref.get("best_match"): + enhanced_best_match = dict(ref["best_match"]) + # Add figure metadata to the existing best_match + for full_ref, figure_id in figure_refs: + print(f"[FIGURE_MATCHING] Looking for figure ID: '{figure_id}'") + figure = find_figure_by_id(figure_id, figures) + if figure: + print( + f"[FIGURE_MATCHING] ✅ Found matching figure: {figure.get('id')} on page {figure.get('page')}" + ) + # Add figure information to the best_match + enhanced_best_match["figure_id"] = figure_id + enhanced_best_match["figure_reference"] = full_ref + enhanced_best_match["figure_caption"] = figure.get("caption") + # Keep the text bounding box but add figure type information + enhanced_best_match["has_figure_reference"] = True + break # Only handle the first figure reference found + + enhanced_ref["best_match"] = enhanced_best_match + else: + # No text match found, but figure reference exists - create figure-only reference + for full_ref, figure_id in figure_refs: + figure = find_figure_by_id(figure_id, figures) + if figure: + print( + f"[FIGURE_MATCHING] Creating figure-only reference for: '{figure_id}'" + ) + enhanced_ref["best_match"] = { + "type": "figure", + "similarity": 1.0, + "page_number": figure.get("page"), + "bounding_regions": figure.get("bounding_regions", []), + "polygon": ( + figure.get("bounding_regions", [{}])[0].get( + "polygon", [] + ) + if figure.get("bounding_regions") + else [] + ), + "caption": figure.get("caption"), + "figure_id": figure_id, + "figure_reference": full_ref, + } + break + + enhanced_references.append(enhanced_ref) + else: + # No figure references, keep the text-matched reference as-is + enhanced_references.append(ref) + + print( + f"[FIGURE_MATCHING] Enhanced {len(enhanced_references)} references with figure information" + ) + return enhanced_references \ No newline at end of file diff --git a/backend/api/core/cit_utils.py b/backend/api/core/cit_utils.py index 12442fa8..1e33ce13 100644 --- a/backend/api/core/cit_utils.py +++ b/backend/api/core/cit_utils.py @@ -16,19 +16,33 @@ from .config import settings -def _is_postgres_configured(db_conn_str: Optional[str] = None) -> bool: +def 
_is_postgres_configured() -> bool: """ - Check if PostgreSQL is configured via Entra ID env vars or connection string. + Check if PostgreSQL is configured via the POSTGRES_MODE profile. """ - has_entra_config = settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER - has_uri_config = db_conn_str or settings.POSTGRES_URI - return bool(has_entra_config or has_uri_config) + try: + prof = settings.postgres_profile() + except Exception: + return False + + # minimal requirements + if not (prof.get("database") and prof.get("user")): + return False + + # password required for local/docker + if prof.get("mode") in ("local", "docker") and not prof.get("password"): + return False + + # azure requires host + if prof.get("mode") == "azure" and not prof.get("host"): + return False + + return True async def load_sr_and_check( sr_id: str, current_user: Dict[str, Any], - db_conn_str: Optional[str], srdb_service, require_screening: bool = True, require_visible: bool = True, @@ -39,20 +53,19 @@ async def load_sr_and_check( Args: sr_id: SR id string current_user: current user dict (must contain "id" and "email") - db_conn_str: PostgreSQL connection string (can be None if using Entra ID auth) srdb_service: SR DB service instance (must implement get_systematic_review and user_has_sr_permission) require_screening: if True, also ensure the SR has a configured screening_db and return its connection string require_visible: if True, require the SR 'visible' flag to be True; set False for endpoints like hard-delete Returns: - (sr_doc, screening_obj or None, db_conn_string or None) + (sr_doc, screening_obj or None) Raises HTTPException with appropriate status codes on failure so routers can just propagate. """ # fetch SR try: - sr = await run_in_threadpool(srdb_service.get_systematic_review, db_conn_str, sr_id, not require_visible) + sr = await run_in_threadpool(srdb_service.get_systematic_review, sr_id, not require_visible) except HTTPException: raise except Exception as e: @@ -67,7 +80,7 @@ async def load_sr_and_check( # permission check (user must be member or owner) user_id = current_user.get("email") try: - has_perm = await run_in_threadpool(srdb_service.user_has_sr_permission, db_conn_str, sr_id, user_id) + has_perm = await run_in_threadpool(srdb_service.user_has_sr_permission, sr_id, user_id) except HTTPException: raise except Exception as e: @@ -77,10 +90,9 @@ async def load_sr_and_check( raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to view/modify this systematic review") screening = sr.get("screening_db") if isinstance(sr, dict) else None - db_conn = None if require_screening: if not screening: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No screening database configured for this systematic review") - db_conn = None - return sr, screening, db_conn + + return sr, screening diff --git a/backend/api/core/config.py b/backend/api/core/config.py index 3fb7d32a..37796fb1 100644 --- a/backend/api/core/config.py +++ b/backend/api/core/config.py @@ -1,4 +1,5 @@ import os +from pathlib import Path from typing import List, Optional from pydantic import Field, field_validator from pydantic_settings import BaseSettings @@ -21,22 +22,30 @@ class Settings(BaseSettings): DESCRIPTION: str = os.getenv( "DESCRIPTION", "AI-powered systematic review platform for Government of Canada" ) - IS_DEPLOYED: bool = os.getenv("IS_DEPLOYED") + IS_DEPLOYED: bool = os.getenv("IS_DEPLOYED", "false").lower() == "true" # CORS CORS_ORIGINS: str = 
os.getenv("CORS_ORIGINS", "*") # Storage settings - STORAGE_TYPE: str = os.getenv("STORAGE_TYPE", "azure") - AZURE_STORAGE_ACCOUNT_NAME: Optional[str] = os.getenv( - "AZURE_STORAGE_ACCOUNT_NAME" - ) - AZURE_STORAGE_CONNECTION_STRING: Optional[str] = os.getenv( - "AZURE_STORAGE_CONNECTION_STRING" - ) - AZURE_STORAGE_CONTAINER_NAME: str = os.getenv( - "AZURE_STORAGE_CONTAINER_NAME", "can-sr-storage" - ) + # Storage selection (strict): local | azure | entra + STORAGE_MODE: str = os.getenv("STORAGE_MODE", "azure").lower().strip() + # Canonical storage container name used across all storage types. + # - local: folder name under LOCAL_STORAGE_BASE_PATH + # - azure/entra: blob container name + STORAGE_CONTAINER_NAME: str = os.getenv("STORAGE_CONTAINER_NAME", "can-sr-storage") + # Azure Storage + # - STORAGE_MODE=azure requires account name + account key + # - STORAGE_MODE=entra requires only account name (uses DefaultAzureCredential) + AZURE_STORAGE_ACCOUNT_NAME: Optional[str] = os.getenv("AZURE_STORAGE_ACCOUNT_NAME") + AZURE_STORAGE_ACCOUNT_KEY: Optional[str] = os.getenv("AZURE_STORAGE_ACCOUNT_KEY") + + # Local storage settings (used when STORAGE_MODE=local) + # In docker, default path is backed by the compose volume: ./uploads:/app/uploads + # Default to a relative directory so it works both locally and in docker: + # - locally: /backend/uploads + # - in docker: /app/uploads + LOCAL_STORAGE_BASE_PATH: str = os.getenv("LOCAL_STORAGE_BASE_PATH", "uploads") # File upload settings MAX_FILE_SIZE: int = Field(default=52428800) # 50MB in bytes @@ -60,40 +69,11 @@ def convert_max_file_size(cls, v): # Azure OpenAI settings (Primary - GPT-4o) AZURE_OPENAI_API_KEY: Optional[str] = os.getenv("AZURE_OPENAI_API_KEY") AZURE_OPENAI_ENDPOINT: Optional[str] = os.getenv("AZURE_OPENAI_ENDPOINT") - AZURE_OPENAI_API_VERSION: str = os.getenv( - "AZURE_OPENAI_API_VERSION", "2025-01-01-preview" - ) - AZURE_OPENAI_DEPLOYMENT_NAME: str = os.getenv( - "AZURE_OPENAI_DEPLOYMENT_NAME", "gpt-4o" - ) - - # GPT-4.1-mini configuration - AZURE_OPENAI_GPT41_MINI_API_KEY: Optional[str] = os.getenv( - "AZURE_OPENAI_GPT41_MINI_API_KEY" - ) - AZURE_OPENAI_GPT41_MINI_ENDPOINT: Optional[str] = os.getenv( - "AZURE_OPENAI_GPT41_MINI_ENDPOINT" - ) - AZURE_OPENAI_GPT41_MINI_DEPLOYMENT: str = os.getenv( - "AZURE_OPENAI_GPT41_MINI_DEPLOYMENT", "gpt-4.1-mini" - ) - AZURE_OPENAI_GPT41_MINI_API_VERSION: str = os.getenv( - "AZURE_OPENAI_GPT41_MINI_API_VERSION", "2025-01-01-preview" - ) - - # GPT-5-mini configuration - AZURE_OPENAI_GPT5_MINI_API_KEY: Optional[str] = os.getenv( - "AZURE_OPENAI_GPT5_MINI_API_KEY" - ) - AZURE_OPENAI_GPT5_MINI_ENDPOINT: Optional[str] = os.getenv( - "AZURE_OPENAI_GPT5_MINI_ENDPOINT" - ) - AZURE_OPENAI_GPT5_MINI_DEPLOYMENT: str = os.getenv( - "AZURE_OPENAI_GPT5_MINI_DEPLOYMENT", "gpt-5-mini" - ) - AZURE_OPENAI_GPT5_MINI_API_VERSION: str = os.getenv( - "AZURE_OPENAI_GPT5_MINI_API_VERSION", "2025-08-07" - ) + # Azure OpenAI auth selection + # key -> uses AZURE_OPENAI_ENDPOINT + AZURE_OPENAI_API_KEY + # entra -> uses AZURE_OPENAI_ENDPOINT + DefaultAzureCredential + # Backwards/alternate env var: OPENAI_TYPE + AZURE_OPENAI_MODE: str = os.getenv("AZURE_OPENAI_MODE", "key").lower().strip() # Default model to use DEFAULT_CHAT_MODEL: str = os.getenv("DEFAULT_CHAT_MODEL", "gpt-5-mini") @@ -110,37 +90,80 @@ def convert_max_file_size(cls, v): ) DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true" - # Database and external system environment variables - # Postgres settings for Entra ID authentication + # 
------------------------------------------------------------------------- + # Postgres configuration + # ------------------------------------------------------------------------- + # Select primary Postgres mode. + # docker/local use password auth; azure uses Entra token auth. + POSTGRES_MODE: str = os.getenv("POSTGRES_MODE", "docker").lower().strip() # docker|local|azure + + # Canonical Postgres connection settings (single profile; values vary by environment) POSTGRES_HOST: Optional[str] = os.getenv("POSTGRES_HOST") POSTGRES_DATABASE: Optional[str] = os.getenv("POSTGRES_DATABASE") - POSTGRES_USER: Optional[str] = os.getenv("POSTGRES_USER") # Entra ID user (e.g., user@tenant.onmicrosoft.com) - POSTGRES_PORT: int = int(os.getenv("POSTGRES_PORT", "5432")) - POSTGRES_SSL_MODE: Optional[str] = os.getenv("POSTGRES_SSL_MODE") + POSTGRES_USER: Optional[str] = os.getenv("POSTGRES_USER") POSTGRES_PASSWORD: Optional[str] = os.getenv("POSTGRES_PASSWORD") - AZURE_DB: bool = os.getenv("AZURE_DB", "false").lower() == "true" - # Legacy: Postgres DSN used for systematic reviews and screening databases (fallback) + + # Optional overrides + POSTGRES_PORT: int = int(os.getenv("POSTGRES_PORT", "5432")) + POSTGRES_SSL_MODE: str = os.getenv("POSTGRES_SSL_MODE", "require") + + # Deprecated (will be removed): legacy Postgres DSN POSTGRES_URI: Optional[str] = os.getenv("POSTGRES_URI") + def postgres_profile(self, mode: Optional[str] = None) -> dict: + """Return resolved Postgres connection settings for a specific mode. + + The application uses a single set of environment variables: + POSTGRES_HOST, POSTGRES_DATABASE, POSTGRES_USER, POSTGRES_PASSWORD. + + POSTGRES_MODE controls *how* authentication is performed: + - docker/local: password auth (POSTGRES_PASSWORD required) + - azure: Entra token auth (password ignored) + sslmode=require + """ + m = (mode or self.POSTGRES_MODE or "").lower().strip() + if m not in {"docker", "local", "azure"}: + raise ValueError("POSTGRES_MODE must be one of: docker, local, azure") + + # Provide sensible defaults for host depending on mode. + default_host = "pgdb-service" if m == "docker" else "localhost" if m == "local" else None + + prof = { + "mode": m, + "host": self.POSTGRES_HOST or default_host, + "database": self.POSTGRES_DATABASE, + "user": self.POSTGRES_USER, + # For azure we intentionally ignore password (token auth) + "password": None if m == "azure" else self.POSTGRES_PASSWORD, + "port": self.POSTGRES_PORT, + "sslmode": (self.POSTGRES_SSL_MODE or "require") if m == "azure" else None, + } + + return prof + + def has_local_fallback(self) -> bool: + # Deprecated: legacy behavior. Kept for compatibility with older code paths. 
+ return False + # Databricks settings - DATABRICKS_INSTANCE: str = os.getenv("DATABRICKS_INSTANCE") - DATABRICKS_TOKEN: str = os.getenv("DATABRICKS_TOKEN") - JOB_ID_EUROPEPMC: str = os.getenv("JOB_ID_EUROPEPMC") - JOB_ID_PUBMED: str = os.getenv("JOB_ID_PUBMED") - JOB_ID_SCOPUS: str = os.getenv("JOB_ID_SCOPUS") + DATABRICKS_INSTANCE: str = os.getenv("DATABRICKS_INSTANCE", "") + DATABRICKS_TOKEN: str = os.getenv("DATABRICKS_TOKEN", "") + JOB_ID_EUROPEPMC: str = os.getenv("JOB_ID_EUROPEPMC", "") + JOB_ID_PUBMED: str = os.getenv("JOB_ID_PUBMED", "") + JOB_ID_SCOPUS: str = os.getenv("JOB_ID_SCOPUS", "") # OAuth - OAUTH_CLIENT_ID: str = os.getenv("OAUTH_CLIENT_ID") - OAUTH_CLIENT_SECRET: str = os.getenv("OAUTH_CLIENT_SECRET") - REDIRECT_URI: str = os.getenv("REDIRECT_URI") - SSO_LOGIN_URL: str = os.getenv("SSO_LOGIN_URL") + OAUTH_CLIENT_ID: str = os.getenv("OAUTH_CLIENT_ID", "") + OAUTH_CLIENT_SECRET: str = os.getenv("OAUTH_CLIENT_SECRET", "") + REDIRECT_URI: str = os.getenv("REDIRECT_URI", "") + SSO_LOGIN_URL: str = os.getenv("SSO_LOGIN_URL", "") # Entra USE_ENTRA_AUTH: bool = os.getenv("USE_ENTRA_AUTH", "false").lower() == "true" class Config: case_sensitive = True - env_file = ".env" + # Resolve to backend/.env regardless of current working directory. + env_file = str(Path(__file__).resolve().parents[2] / ".env") extra = "ignore" # Allow extra environment variables diff --git a/backend/api/core/docint_coords.py b/backend/api/core/docint_coords.py new file mode 100644 index 00000000..564be3de --- /dev/null +++ b/backend/api/core/docint_coords.py @@ -0,0 +1,126 @@ +"""Helpers to normalize Azure Document Intelligence polygons into CAN-SR/Grobid-style boxes. + +Frontend (PDFBoundingBoxViewer) expects per-page boxes with fields: + { page, x, y, width, height } + +Grobid coords are in TEI pixel space from the original PDF rendering. +Azure DI returns polygons in page units (typically "inch" or "pixel"). + +We normalize by: + 1) Converting Azure units to pixel space when possible (inch->72dpi pixels). + 2) Converting polygon -> axis-aligned bounding rect. + +This is best-effort and is intended for highlighting tables/figures similar +to sentence highlighting. +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Tuple + + +def _page_meta_by_number(pages_meta: Any) -> Dict[int, Dict[str, Any]]: + out: Dict[int, Dict[str, Any]] = {} + if not isinstance(pages_meta, list): + return out + for p in pages_meta: + if not isinstance(p, dict): + continue + num = p.get("pageNumber") or p.get("page_number") + try: + num_i = int(num) + except Exception: + continue + out[num_i] = p + return out + + +def _unit_to_scale(unit: Optional[str]) -> float: + """Return multiplier to convert unit coordinates into ~PDF pixels. + + Azure DI commonly reports: + - unit == 'inch' + - unit == 'pixel' + - unit == None + + We assume PDF coordinate space is 72 dpi points. + """ + + if not unit: + return 1.0 + u = str(unit).strip().lower() + if u in ("pixel", "pixels", "px"): + return 1.0 + if u in ("inch", "in"): + return 72.0 + # Unknown units: do not scale. + return 1.0 + + +def polygon_to_bbox(polygon: Any) -> Optional[Tuple[float, float, float, float]]: + """Convert Azure polygon [x1,y1,x2,y2,...] 
to (minx, miny, maxx, maxy).""" + if not isinstance(polygon, list) or len(polygon) < 8: + return None + try: + xs = [float(polygon[i]) for i in range(0, len(polygon), 2)] + ys = [float(polygon[i]) for i in range(1, len(polygon), 2)] + except Exception: + return None + return (min(xs), min(ys), max(xs), max(ys)) + + +def normalize_bounding_regions_to_boxes( + bounding_regions: Any, + pages_meta: Any, +) -> List[Dict[str, Any]]: + """Normalize Azure boundingRegions -> list of {page,x,y,width,height}. + + bounding_regions supports either: + - [{'pageNumber'|'page_number': 1, 'polygon': [...]}, ...] + - [{'page_number': 1, 'polygon': [...]}, ...] + """ + out: List[Dict[str, Any]] = [] + if not isinstance(bounding_regions, list): + return out + + pm = _page_meta_by_number(pages_meta) + + for region in bounding_regions: + if not isinstance(region, dict): + continue + + page = region.get("pageNumber") + if page is None: + page = region.get("page_number") + try: + page_i = int(page) + except Exception: + continue + + poly = region.get("polygon") + bbox = polygon_to_bbox(poly) + if not bbox: + continue + + unit = None + if page_i in pm: + unit = pm[page_i].get("unit") + + s = _unit_to_scale(unit) + minx, miny, maxx, maxy = bbox + minx *= s + miny *= s + maxx *= s + maxy *= s + + out.append( + { + "page": page_i, + "x": minx, + "y": miny, + "width": max(0.0, maxx - minx), + "height": max(0.0, maxy - miny), + } + ) + + return out diff --git a/backend/api/core/security.py b/backend/api/core/security.py index 6b584dc2..508cfcb0 100644 --- a/backend/api/core/security.py +++ b/backend/api/core/security.py @@ -60,7 +60,7 @@ async def create_user(user_data: UserCreate) -> Optional[UserRead]: if not user_db_service: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="User registration is not available without Azure Storage configuration", + detail="User registration is not available without configured storage", ) return await user_db_service.create_user(user_data) @@ -71,7 +71,7 @@ async def update_user(user_id: str, user_in: UserUpdate) -> Optional[UserRead]: if not user_db_service: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="User operations are not available without Azure Storage configuration", + detail="User operations are not available without configured storage", ) update_data = user_in.model_dump(exclude_unset=True) diff --git a/backend/api/extract/prompts.py b/backend/api/extract/prompts.py index e71821ab..447aca76 100644 --- a/backend/api/extract/prompts.py +++ b/backend/api/extract/prompts.py @@ -2,6 +2,7 @@ You are an expert information extractor for scientific full-text articles. You will be given: - A short description of a parameter to extract (what the parameter is and how it is defined). - The full text of a paper with each sentence numbered like: [0] First sentence. [1] Second sentence. etc. +- Optionally, numbered tables (as markdown) and numbered figure captions (with the corresponding figure images provided alongside this message). Task (STRICT): Return a single valid JSON object and nothing else. The JSON MUST contain the following keys: @@ -9,6 +10,8 @@ - "value": the extracted value as a string (or null if not found). - "explanation": a concise explanation (1-4 sentences) describing why this value was chosen or how it was derived. - "evidence_sentences": an array of integers indicating the sentence indices you used as evidence (e.g. [2, 5]). If there are no supporting sentences, return an empty array. 
+- "evidence_tables": an array of integers indicating table numbers used (e.g. [1, 2]) or []. +- "evidence_figures": an array of integers indicating figure numbers used (e.g. [3]) or []. Requirements: - If the parameter is explicitly present, return the value exactly as found (preserve units/format) and list the sentence indices. @@ -21,9 +24,16 @@ - {parameter_name} (a short name for the parameter) - {parameter_description} (detailed description of what to look for) - {fulltext} (the numbered sentences string; e.g. "[0] First sentence\n[1] Next sentence\n...") +- {tables} (numbered markdown tables) +- {figures} (numbered figure captions) Example valid output: -{{"found": true, "value": "5 mg/kg", "explanation": "The Methods section explicitly lists a dose of 5 mg/kg in sentence [12].", "evidence_sentences": [12]}} +{{"found": true, "value": "5 mg/kg", "explanation": "The Methods section explicitly lists a dose of 5 mg/kg in sentence [12].", "evidence_sentences": [12], "evidence_tables": [], "evidence_figures": []}} Do not output anything besides the JSON object. +\n\nParameter name: {parameter_name} +\nParameter description: {parameter_description} +\n\nFull text (numbered sentences):\n{fulltext} +\n\nTables (numbered):\n{tables} +\n\nFigures (numbered; captions correspond to images provided alongside this message):\n{figures} """ diff --git a/backend/api/extract/router.py b/backend/api/extract/router.py index a2b1ff91..3237a46b 100644 --- a/backend/api/extract/router.py +++ b/backend/api/extract/router.py @@ -1,10 +1,12 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import json import re import os from tempfile import NamedTemporaryFile import hashlib from datetime import datetime +from pathlib import Path +import asyncio from fastapi import APIRouter, Depends, HTTPException, status from fastapi.concurrency import run_in_threadpool @@ -25,6 +27,9 @@ from ..services.grobid_service import grobid_service from ..core.cit_utils import load_sr_and_check +from ..services.azure_docint_client import azure_docint_client +from ..core.docint_coords import normalize_bounding_regions_to_boxes + # Import consolidated Postgres helpers if available (optional) from ..services.cit_db_service import cits_dp_service, snake_case_param @@ -44,6 +49,11 @@ class ParameterExtractRequest(BaseModel): temperature: Optional[float] = Field(0.0, ge=0.0, le=1.0) max_tokens: Optional[int] = Field(512, ge=1, le=4000) + # Optional artifacts context (if omitted, server will read from citation row when available) + tables: Optional[str] = Field(None, description="Optional numbered tables text (markdown).") + figures: Optional[str] = Field(None, description="Optional numbered figure captions text.") + attach_figures: Optional[bool] = Field(True, description="If true, attach figure images to the LLM request when available") + @@ -83,9 +93,8 @@ async def extract_parameter_endpoint( derived from the parameter name (prefixed with 'llm_param_'). 
""" - db_conn_str = settings.POSTGRES_URI try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -98,7 +107,7 @@ async def extract_parameter_endpoint( row = None if not fulltext: try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -111,24 +120,108 @@ async def extract_parameter_endpoint( if not fulltext: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Full text not provided and not available for this citation") + # Build tables/figures context: prefer payload fields, else fetch from DB row (if loaded) + tables_text = payload.tables + figures_text = payload.figures + images: List[Tuple[bytes, str]] = [] + + if (tables_text is None or figures_text is None) and row is None: + try: + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) + except Exception: + row = None + + if row and (tables_text is None or figures_text is None): + # tables: embed markdown + if tables_text is None: + tables_md_lines: List[str] = [] + ft_tables = row.get("fulltext_tables") + if isinstance(ft_tables, str): + try: + ft_tables = json.loads(ft_tables) + except Exception: + ft_tables = None + if isinstance(ft_tables, list): + for item in ft_tables: + if not isinstance(item, dict): + continue + idx = item.get("index") + blob_addr = item.get("blob_address") + caption = item.get("caption") + if not idx or not blob_addr: + continue + try: + md_bytes, _ = await storage_service.get_bytes_by_path(blob_addr) + md_txt = md_bytes.decode("utf-8", errors="replace") + header = f"Table [T{idx}]" + (f" caption: {caption}" if caption else "") + tables_md_lines.extend([header, md_txt, ""]) + except Exception: + continue + tables_text = "\n".join(tables_md_lines) if tables_md_lines else "(none)" + + # figures: captions + optionally images + if figures_text is None: + figures_lines: List[str] = [] + ft_figs = row.get("fulltext_figures") + if isinstance(ft_figs, str): + try: + ft_figs = json.loads(ft_figs) + except Exception: + ft_figs = None + if isinstance(ft_figs, list): + for item in ft_figs: + if not isinstance(item, dict): + continue + idx = item.get("index") + blob_addr = item.get("blob_address") + caption = item.get("caption") + if not idx or not blob_addr: + continue + figures_lines.append( + f"Figure [F{idx}] caption: {caption or '(no caption)'} (see attached image F{idx})" + ) + if payload.attach_figures: + try: + img_bytes, _ = await storage_service.get_bytes_by_path(blob_addr) + if img_bytes: + images.append((img_bytes, "image/png")) + except Exception: + continue + figures_text = "\n".join(figures_lines) if figures_lines else "(none)" + + tables_text = tables_text if tables_text is not None else "(none)" + figures_text = figures_text if figures_text is not None else "(none)" + # Build prompt prompt = PARAMETER_PROMPT_JSON.format( parameter_name=payload.parameter_name, parameter_description=payload.parameter_description, - fulltext=fulltext + fulltext=fulltext, + tables=tables_text, + figures=figures_text, ) if not azure_openai_client.is_configured(): raise 
HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Azure OpenAI client is not configured on the server") try: - llm_response = await azure_openai_client.simple_chat( - user_message=prompt, - system_prompt=None, - model=payload.model, - max_tokens=payload.max_tokens or 512, - temperature=payload.temperature or 0.0, - ) + if images: + llm_response = await azure_openai_client.multimodal_chat( + user_text=prompt, + images=images, + system_prompt=None, + model=payload.model, + max_tokens=payload.max_tokens or 512, + temperature=payload.temperature or 0.0, + ) + else: + llm_response = await azure_openai_client.simple_chat( + user_message=prompt, + system_prompt=None, + model=payload.model, + max_tokens=payload.max_tokens or 512, + temperature=payload.temperature or 0.0, + ) except Exception as e: raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"LLM call failed: {e}") @@ -216,12 +309,41 @@ def _extract_json_object(text: str) -> Optional[str]: else: raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="'evidence_sentences' must be a list") + # Normalize evidence tables/figures + def _norm_int_list(v: Any) -> List[int]: + if v is None: + return [] + if not isinstance(v, list): + return [] + out: List[int] = [] + for item in v: + if isinstance(item, int): + out.append(item) + elif isinstance(item, str): + try: + out.append(int(item.strip())) + except Exception: + continue + # stable unique + seen = set() + uniq: List[int] = [] + for x in out: + if x not in seen: + seen.add(x) + uniq.append(x) + return uniq + + evidence_tables = _norm_int_list(parsed.get("evidence_tables")) + evidence_figures = _norm_int_list(parsed.get("evidence_figures")) + # Build the stored JSON stored = { "found": found_val, "value": val, "explanation": explanation, "evidence_sentences": evidence, + "evidence_tables": evidence_tables, + "evidence_figures": evidence_figures, "llm_raw": llm_response[:4000] } @@ -229,7 +351,7 @@ def _extract_json_object(text: str) -> Optional[str]: col_name = snake_case_param(payload.parameter_name) try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, col_name, stored, table_name) + updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, stored, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -256,9 +378,8 @@ async def human_extract_parameter( and does not call any LLM. 
""" - db_conn_str = settings.POSTGRES_URI try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -268,7 +389,7 @@ async def human_extract_parameter( # Ensure citation exists (we won't require full_text for human input but check row presence) try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -323,7 +444,7 @@ async def human_extract_parameter( col_name = f"human_param_{core}" if core else "human_param_param" try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, col_name, stored, table_name) + updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, stored, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -347,9 +468,8 @@ async def extract_fulltext_from_storage( under column "fulltext", and return the generated fulltext_str. """ - db_conn_str = settings.POSTGRES_URI try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -359,7 +479,7 @@ async def extract_fulltext_from_storage( # fetch citation row try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -373,16 +493,14 @@ async def extract_fulltext_from_storage( if not storage_path: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No fulltext storage path found on citation row") - if "/" not in storage_path: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognized storage path format") - - container, blob = storage_path.split("/", 1) - try: - blob_client = storage_service.blob_service_client.get_blob_client(container=container, blob=blob) - content = blob_client.download_blob().readall() + content, _filename = await storage_service.get_bytes_by_path(storage_path) + except FileNotFoundError: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Fulltext file not found in storage") + except ValueError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognized storage path format") except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to download blob: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to download from storage: {e}") # If the citation row already contains an extracted full text in the "fulltext" column, # only use it if the stored md5 matches the pdf we just downloaded. 
@@ -404,24 +522,142 @@ async def extract_fulltext_from_storage( "cached": True, } - # write to temp file and call grobid + # write to temp file and call grobid (+ Azure DI in parallel) tmp = NamedTemporaryFile(delete=False, suffix=".pdf") try: tmp.write(content) tmp.flush() tmp.close() - # process with grobid + # process with grobid (sentence coords) and Azure DI (tables/figures) in parallel + async def _run_grobid(): + return await grobid_service.process_structure(tmp.name) + + async def _run_docint(): + if not azure_docint_client or not azure_docint_client.is_available(): + return {"success": False, "error": "Azure DI not configured", "figures": [], "tables": []} + return await azure_docint_client.extract_citation_artifacts(tmp.name, source_type="file") + try: - coords, pages = await grobid_service.process_structure(tmp.name) + (coords, pages), docint_res = await asyncio.gather( + _run_grobid(), + _run_docint(), + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Grobid processing failed: {e}") + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Fulltext processing failed: {e}") # filter sentence annotations annotations = [a for a in coords if a.get("type") == "s" and a.get("text")] full_text_arr = _list_set([a["text"] for a in annotations]) full_text_str = "\n\n".join([f"[{i}] {x}" for i, x in enumerate(full_text_arr)]) + # ------------------------- + # Upload Azure DI artifacts + # ------------------------- + fulltext_figures: List[Dict[str, Any]] = [] + fulltext_tables: List[Dict[str, Any]] = [] + artifact_coords: List[Dict[str, Any]] = [] + + try: + if docint_res and isinstance(docint_res, dict) and docint_res.get("success"): + pages_meta = docint_res.get("pages") or [] + # Determine artifact base path from the fulltext_url directory + # storage_path is "container/blob". + container, blob = storage_path.split("/", 1) + blob_dir = str(Path(blob).parent).replace("\\", "/") + artifacts_prefix = f"{container}/{blob_dir}/{citation_id}_artifacts" + artifacts_prefix = artifacts_prefix.replace("//", "/").rstrip("/") + + # Figures: write png + for fig in (docint_res.get("figures") or []): + try: + idx = int(fig.get("index")) + except Exception: + continue + artifact_id = f"figure_{idx}.png" + blob_address = f"{artifacts_prefix}/{artifact_id}" + png_bytes = fig.get("png_bytes") or b"" + caption = fig.get("caption") + bbox = fig.get("bounding_box") + boxes = normalize_bounding_regions_to_boxes(bbox, pages_meta) + + # Upload only if we actually got image bytes + if png_bytes: + await storage_service.put_bytes_by_path( + blob_address, + png_bytes, + content_type="image/png", + ) + + fulltext_figures.append( + { + "blob_address": blob_address, + "caption": caption, + "bounding_box": boxes, + "description": None, + "index": idx, + } + ) + + # Also add to overlay coords so the existing PDF viewer logic can + # highlight tables/figures using the same shape as Grobid coords. 
+ for b in boxes or []: + if not isinstance(b, dict): + continue + artifact_coords.append( + { + **b, + "type": "figure", + "artifact_index": idx, + "text": f"Figure F{idx}", + } + ) + + # Tables: write markdown (.md) + for tbl in (docint_res.get("tables") or []): + try: + idx = int(tbl.get("index")) + except Exception: + continue + artifact_id = f"table_{idx}.md" + blob_address = f"{artifacts_prefix}/{artifact_id}" + md = (tbl.get("table_markdown") or "").encode("utf-8") + caption = tbl.get("caption") + bbox = tbl.get("bounding_box") + boxes = normalize_bounding_regions_to_boxes(bbox, pages_meta) + if md: + await storage_service.put_bytes_by_path( + blob_address, + md, + content_type="text/markdown", + ) + fulltext_tables.append( + { + "blob_address": blob_address, + "caption": caption, + "bounding_box": boxes, + "description": None, + "index": idx, + } + ) + + for b in boxes or []: + if not isinstance(b, dict): + continue + artifact_coords.append( + { + **b, + "type": "table", + "artifact_index": idx, + "text": f"Table T{idx}", + } + ) + except Exception: + # Best-effort; DI artifacts should not block fulltext extraction. + fulltext_figures = [] + fulltext_tables = [] + artifact_coords = [] + finally: try: os.unlink(tmp.name) @@ -430,11 +666,14 @@ async def extract_fulltext_from_storage( # persist full_text_str and coordinates/pages into citation row try: - updated1 = await run_in_threadpool(cits_dp_service.update_text_column, db_conn, citation_id, "fulltext", full_text_str, table_name) - updated2 = await run_in_threadpool(cits_dp_service.update_text_column, db_conn, citation_id, "fulltext_md5", current_md5, table_name) - updated3 = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, "fulltext_coords", annotations, table_name) - updated4 = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, "fulltext_pages", pages, table_name) - updated = updated1 or updated2 or updated3 or updated4 + coords_for_overlay = list(annotations) + list(artifact_coords) + updated1 = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "fulltext", full_text_str, table_name) + updated2 = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "fulltext_md5", current_md5, table_name) + updated3 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_coords", coords_for_overlay, table_name) + updated4 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_pages", pages, table_name) + updated5 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_figures", fulltext_figures, table_name) + updated6 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_tables", fulltext_tables, table_name) + updated = updated1 or updated2 or updated3 or updated4 or updated5 or updated6 except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -443,4 +682,12 @@ async def extract_fulltext_from_storage( if not updated: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update") - return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "fulltext": full_text_str, "n_pages": len(pages)} + return { + "status": "success", + "sr_id": sr_id, + "citation_id": citation_id, + "fulltext": full_text_str, + "n_pages": len(pages), + "fulltext_figures": fulltext_figures, + "fulltext_tables": 
fulltext_tables, + } diff --git a/backend/api/files/router.py b/backend/api/files/router.py index 107d565e..b3c769fe 100644 --- a/backend/api/files/router.py +++ b/backend/api/files/router.py @@ -5,7 +5,6 @@ import os import logging from datetime import datetime, timezone -from azure.core.exceptions import ResourceNotFoundError from fastapi import ( APIRouter, @@ -206,9 +205,10 @@ async def download_by_path( path: str, current_user: Dict[str, Any] = Depends(get_current_active_user), ): - """ - Stream a blob directly by storage path in the form 'container/blob_path' - Example storage_path: "container/users/123/docid_filename.pdf" + """Download directly by storage path in the form 'container/blob_path'. + + This works for both Azure and local storage backends. + Example: "users/users/123/documents/_.pdf" """ try: if not storage_service: @@ -217,29 +217,14 @@ async def download_by_path( detail="Storage service not available", ) - # Basic validation: expect "container/blob" - if not path or "/" not in path: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid storage path" - ) - - container, blob = path.split("/", 1) - logger.info(f"Downloading blob: container={container}, blob={blob}") - try: - blob_client = storage_service.blob_service_client.get_blob_client( - container=container, blob=blob - ) - stream = blob_client.download_blob() - content = stream.readall() - except ResourceNotFoundError: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Blob not found" - ) - - filename = os.path.basename(blob) or "download" + content, filename = await storage_service.get_bytes_by_path(path) + except FileNotFoundError: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + except ValueError: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid storage path") + except Exception as e: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to download: {e}") def gen(): yield content diff --git a/backend/api/screen/prompts.py b/backend/api/screen/prompts.py index 1d10e20e..b402f3e0 100644 --- a/backend/api/screen/prompts.py +++ b/backend/api/screen/prompts.py @@ -41,22 +41,34 @@ - Full text (numbered sentences): {fulltext} +- Tables (numbered): +{tables} + +- Figures (numbered; captions correspond to images provided alongside this message): +{figures} + Respond with a JSON object containing these keys: - "selected": the exact option string you selected (must match one of the options above; if none fits, pick the closest option and report a low confidence score) - "explanation": a concise explanation (1-4 sentences) of why you selected that option - "confidence": a floating number between 0 and 1 (inclusive) representing your estimated confidence for the selected option - "evidence_sentences": an array of integers indicating the sentence indices you used as evidence (e.g. [2, 5]). If there is low confidence, return an empty array []. +- "evidence_tables": an array of integers indicating the table numbers you used (e.g. [1, 3]) or [] if none. +- "evidence_figures": an array of integers indicating the figure numbers you used (e.g. [2]) or [] if none. JSON object format: {{ "selected": "", "explanation": "<1-4 sentences explaining the choice>", "confidence": , - "evidence_sentences": [] + "evidence_sentences": [], + "evidence_tables": [], + "evidence_figures": [
] }} Notes: - Keep the response strictly as a JSON object that matches the schema above. - Do not wrap the response in Markdown code fences or add language tags (e.g., ```json). Return only raw JSON. - Use sentence indices from the numbered full text for "evidence_sentences" +- Use table numbers from the Tables section for "evidence_tables" +- Use figure numbers from the Figures section for "evidence_figures" """ \ No newline at end of file diff --git a/backend/api/screen/router.py b/backend/api/screen/router.py index 20fc9986..a36c7c86 100644 --- a/backend/api/screen/router.py +++ b/backend/api/screen/router.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import json import re from datetime import datetime @@ -11,6 +11,7 @@ from ..core.config import settings from ..core.security import get_current_active_user from ..services.azure_openai_client import azure_openai_client +from ..services.storage import storage_service # Import helpers from citations router to fetch citation rows and build combined citations from ..citations import router as citations_router @@ -25,6 +26,32 @@ router = APIRouter() +def _normalize_int_list(v: Any) -> List[int]: + if v is None: + return [] + if not isinstance(v, list): + return [] + out: List[int] = [] + for item in v: + if isinstance(item, int): + out.append(item) + elif isinstance(item, str): + try: + out.append(int(item.strip())) + except Exception: + continue + else: + continue + # stable unique + seen = set() + uniq: List[int] = [] + for x in out: + if x not in seen: + seen.add(x) + uniq.append(x) + return uniq + + class ClassifyRequest(BaseModel): citation_text: Optional[str] = Field( None, description="Optional combined citation text. If omitted the server will build it from the screening DB row." @@ -79,9 +106,8 @@ async def classify_citation( The 'selected' field in the returned JSON is validated against the provided `options`. """ - db_conn_str = settings.POSTGRES_URI try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -91,7 +117,7 @@ async def classify_citation( # Load citation row (needed for l2 fulltext and for building citation_text) try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -103,30 +129,106 @@ async def classify_citation( # Build or use provided citation text (fall back to combined title/abstract when not provided) citation_text = payload.citation_text or citations_router._build_combined_citation_from_row(row, payload.include_columns) + # Ensure LLM client is available + if not azure_openai_client.is_configured(): + raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Azure OpenAI client is not configured on the server") + # Prepare prompt (use full-text template for l2, otherwise TA/L1 template) options_listed = "\n".join([f"{i}. 
{opt}" for i, opt in enumerate(payload.options)]) + llm_response: str + if (payload.screening_step or "").lower() == "l2": fulltext = row.get("fulltext") or "" + + # Tables/Figures context from citation row (populated by extract-fulltext) + tables_md_lines: List[str] = [] + figures_lines: List[str] = [] + images: List[Tuple[bytes, str]] = [] + + # Tables: fetch markdown blobs and embed + ft_tables = row.get("fulltext_tables") + if isinstance(ft_tables, str): + try: + ft_tables = json.loads(ft_tables) + except Exception: + ft_tables = None + if isinstance(ft_tables, list): + for item in ft_tables: + if not isinstance(item, dict): + continue + idx = item.get("index") + blob_addr = item.get("blob_address") + caption = item.get("caption") + if not idx or not blob_addr: + continue + try: + md_bytes, _ = await storage_service.get_bytes_by_path(blob_addr) + except Exception: + continue + md_txt = md_bytes.decode("utf-8", errors="replace") + header = f"Table [T{idx}]" + (f" caption: {caption}" if caption else "") + tables_md_lines.extend([header, md_txt, ""]) + + # Figures: fetch png blobs and attach as images + ft_figs = row.get("fulltext_figures") + if isinstance(ft_figs, str): + try: + ft_figs = json.loads(ft_figs) + except Exception: + ft_figs = None + if isinstance(ft_figs, list): + for item in ft_figs: + if not isinstance(item, dict): + continue + idx = item.get("index") + blob_addr = item.get("blob_address") + caption = item.get("caption") + if not idx or not blob_addr: + continue + figures_lines.append( + f"Figure [F{idx}] caption: {caption or '(no caption)'} (see attached image F{idx})" + ) + try: + img_bytes, _ = await storage_service.get_bytes_by_path(blob_addr) + if img_bytes: + images.append((img_bytes, "image/png")) + except Exception: + continue + prompt = PROMPT_JSON_TEMPLATE_FULLTEXT.format( question=payload.question, options=options_listed, xtra=payload.xtra or "", fulltext=fulltext or citation_text, + tables="\n".join(tables_md_lines) if tables_md_lines else "(none)", + figures="\n".join(figures_lines) if figures_lines else "(none)", ) + + # Prefer multimodal when figures are present + if images: + llm_response = await azure_openai_client.multimodal_chat( + user_text=prompt, + images=images, + system_prompt=None, + model=payload.model, + max_tokens=payload.max_tokens or 2000, + temperature=payload.temperature or 0.0, + ) + else: + llm_response = await azure_openai_client.simple_chat( + user_message=prompt, + system_prompt=None, + model=payload.model, + max_tokens=payload.max_tokens or 2000, + temperature=payload.temperature or 0.0, + ) else: prompt = PROMPT_JSON_TEMPLATE.format( question=payload.question, cit=citation_text, options=options_listed, - xtra=payload.xtra or "" + xtra=payload.xtra or "", ) - - # Ensure LLM client is available - if not azure_openai_client.is_configured(): - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Azure OpenAI client is not configured on the server") - - # Call the LLM (simple non-streaming chat) - try: llm_response = await azure_openai_client.simple_chat( user_message=prompt, system_prompt=None, @@ -134,8 +236,6 @@ async def classify_citation( max_tokens=payload.max_tokens or 2000, temperature=payload.temperature or 0.0, ) - except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"LLM call failed: {e}") # Parse JSON (assume valid JSON) - try/except only logger.info(llm_response) @@ -173,22 +273,17 @@ async def classify_citation( except Exception: confidence = 0.0 - # 
Normalize evidence sentences from model (optional) - evidence_raw = parsed.get("evidence_sentences", []) - evidence: List[int] = [] - if isinstance(evidence_raw, list): - for i in evidence_raw: - try: - idx = int(i) - evidence.append(idx) - except Exception: - continue + evidence = _normalize_int_list(parsed.get("evidence_sentences")) + evidence_tables = _normalize_int_list(parsed.get("evidence_tables")) + evidence_figures = _normalize_int_list(parsed.get("evidence_figures")) classification_json = { "selected": resolved_selected, "explanation": explanation, "confidence": confidence, "evidence_sentences": evidence, + "evidence_tables": evidence_tables, + "evidence_figures": evidence_figures, "llm_raw": llm_response, # raw response for audit } @@ -196,7 +291,7 @@ async def classify_citation( col_name = snake_case_column(payload.question) try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, col_name, classification_json, table_name) + updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, classification_json, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -205,7 +300,7 @@ async def classify_citation( if not updated: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update") - await update_inclusion_decision(sr, citation_id, payload.screening_step, "llm", db_conn) + await update_inclusion_decision(sr, citation_id, payload.screening_step, "llm") return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "classification": classification_json} @@ -223,9 +318,8 @@ async def human_classify_citation( The column name is prefixed with 'human_' to distinguish from automated classifications. 
""" - db_conn_str = settings.POSTGRES_URI try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -235,7 +329,7 @@ async def human_classify_citation( # Ensure citation exists and optionally build combined citation text try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -263,7 +357,7 @@ async def human_classify_citation( col_name = f"human_{col_core}" if col_core else f"human_col" try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, db_conn, citation_id, col_name, classification_json, table_name) + updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, classification_json, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -272,7 +366,7 @@ async def human_classify_citation( if not updated: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update") - await update_inclusion_decision(sr, citation_id, payload.screening_step, "human", db_conn) + await update_inclusion_decision(sr, citation_id, payload.screening_step, "human") return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "classification": classification_json} @@ -281,13 +375,11 @@ async def update_inclusion_decision( citation_id: int, screening_step: str, decision_maker: str, - db_conn: str, - ): table_name = (sr.get("screening_db") or {}).get("table_name") or "citations" try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, db_conn, int(citation_id), table_name) + row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: @@ -312,7 +404,7 @@ async def update_inclusion_decision( col_name = f"{decision_maker}_{screening_step}_decision" try: - updated = await run_in_threadpool(cits_dp_service.update_text_column, db_conn, citation_id, col_name, decision, table_name) + updated = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, col_name, decision, table_name) except RuntimeError as rexc: raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) except Exception as e: diff --git a/backend/api/services/azure_docint_client.py b/backend/api/services/azure_docint_client.py new file mode 100644 index 00000000..4c047a4a --- /dev/null +++ b/backend/api/services/azure_docint_client.py @@ -0,0 +1,822 @@ +""" +Azure Document Intelligence Service for PDF processing + +This service uses Azure's Document Intelligence to extract +structured content from PDF documents. It provides an alternative to Docling for +document processing with potentially better handling of complex layouts, tables, +and structured documents. 
+ +Key Features: +- Superior table extraction +- Form field recognition +- Multi-language support +- Layout analysis with reading order +- Handwriting recognition +- Key-value pair extraction +- Figure/chart detection and extraction with captions +- Downloadable cropped figure images +""" + +import base64 +import os +import uuid +import json +import asyncio +from datetime import datetime +from pathlib import Path +from typing import Dict, Any, Optional, List, Tuple + +from bs4 import BeautifulSoup +from typing import TYPE_CHECKING + +try: + from azure.ai.documentintelligence import DocumentIntelligenceClient + from azure.ai.documentintelligence.models import AnalyzeDocumentRequest + from azure.core.credentials import AzureKeyCredential + + AZURE_DOC_INTELLIGENCE_AVAILABLE = True +except ImportError: + AZURE_DOC_INTELLIGENCE_AVAILABLE = False + # Keep type names defined so type annotations don't crash at runtime when SDK isn't installed. + from typing import Any as DocumentIntelligenceClient # type: ignore + from typing import Any as AnalyzeDocumentRequest # type: ignore + from typing import Any as AzureKeyCredential # type: ignore + print( + "Azure Document Intelligence SDK not installed. Install with: pip install azure-ai-documentintelligence azure-core" + ) + + +class AzureDocIntelligenceService: + """Service for processing documents using Azure Document Intelligence""" + + def __init__(self): + self.base_path = Path(__file__).parent.parent.parent.parent.parent + # Unified directory structure: output/azure_doc_intelligence/{conversion_id}/ + self.output_base_dir = self.base_path / "output" / "azure_doc_intelligence" + self.output_base_dir.mkdir(parents=True, exist_ok=True) + + # Initialize Azure client + self.client = self._init_client() + + def _init_client(self) -> Optional["DocumentIntelligenceClient"]: + """Initialize Azure Document Intelligence client""" + if not AZURE_DOC_INTELLIGENCE_AVAILABLE: + return None + + endpoint = os.getenv("AZURE_DOC_INTELLIGENCE_ENDPOINT") + key = os.getenv("AZURE_DOC_INTELLIGENCE_KEY") + + if not endpoint or not key: + print( + "Azure Document Intelligence credentials not found. 
Set AZURE_DOC_INTELLIGENCE_ENDPOINT and AZURE_DOC_INTELLIGENCE_KEY" + ) + return None + + try: + return DocumentIntelligenceClient( + endpoint=endpoint, credential=AzureKeyCredential(key) + ) + except Exception as e: + print(f"Failed to initialize Azure Document Intelligence client: {e}") + return None + + def _analyze_document_sync( + self, source: str, source_type: str, output_param: Optional[List[str]] + ): + """Synchronous wrapper for document analysis to run in thread pool""" + if source_type == "file": + with open(source, "rb") as f: + file_content = f.read() + # Use the correct API format + poller = self.client.begin_analyze_document( + model_id="prebuilt-layout", + body=file_content, + content_type="application/octet-stream", + output_content_format="markdown", + output=output_param, + ) + else: # URL + # For URL, use AnalyzeDocumentRequest + analyze_request = AnalyzeDocumentRequest(url_source=source) + poller = self.client.begin_analyze_document( + model_id="prebuilt-layout", + analyze_request=analyze_request, + output_content_format="markdown", + output=output_param, + ) + + # Wait for completion (this is the blocking part) + result = poller.result() + + # Extract result_id from the poller + result_id = None + try: + # Extract result_id from operation-location header in initial response + if hasattr(poller, "_polling_method") and hasattr( + poller._polling_method, "_initial_response" + ): + initial_resp = poller._polling_method._initial_response + if hasattr(initial_resp, "http_response") and hasattr( + initial_resp.http_response, "headers" + ): + headers = initial_resp.http_response.headers + # Try different header names (Azure API uses 'operation-location') + for header_name in [ + "operation-location", + "Operation-Location", + ]: + if header_name in headers: + operation_location = headers[header_name] + # Extract result_id from URL: .../analyzeResults/{result_id}?api-version=... + if "analyzeResults" in operation_location: + result_id = operation_location.split("/")[-1].split( + "?" + )[0] + break + except Exception as e: + print(f"Warning: Could not extract result_id: {str(e)}") + + return result, result_id + + async def convert_document_to_markdown( + self, + source: str, + source_type: str = "file", + extract_figures: bool = True, + output_dir: Optional[Path] = None, + ) -> Dict[str, Any]: + """ + Convert document to markdown using Azure Document Intelligence (Non-blocking) + + Args: + source: File path or URL to the document + source_type: Type of source ("file" or "url") + extract_figures: Whether to extract figures + output_dir: Optional output directory. If provided, saves directly here. + If None, uses legacy UUID-based path in output/azure_doc_intelligence/ + """ + return await asyncio.to_thread( + self._convert_document_sync, + source, + source_type, + extract_figures, + output_dir, + ) + + def _convert_document_sync( + self, + source: str, + source_type: str = "file", + extract_figures: bool = True, + output_dir: Optional[Path] = None, + ) -> Dict[str, Any]: + """ + Synchronous implementation of document conversion + + Args: + output_dir: If provided, save output directly to this directory. + If None, uses legacy UUID-based path. 
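        A hedged usage sketch of the public async wrapper (call from an async
        context; the path and the inspected keys are illustrative):

            result = await azure_docint_client.convert_document_to_markdown(
                source="/tmp/paper.pdf",
                source_type="file",
                extract_figures=True,
            )
            if result["success"]:
                md_path = result["markdown_path"]               # .../document.md
                n_tables = result["metadata"]["tables_found"]
                n_figures = result["metadata"]["figures_found"]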
+ """ + if not self.client: + return { + "success": False, + "error": "Azure Document Intelligence client not available", + "conversion_id": str(uuid.uuid4()), + } + + conversion_id = str(uuid.uuid4()) + start_time = datetime.now() + + try: + # Use provided output_dir or fall back to legacy UUID-based path + if output_dir: + conversion_dir = output_dir + conversion_dir.mkdir(parents=True, exist_ok=True) + else: + # Legacy: output/azure_doc_intelligence/{conversion_id}/ + conversion_dir = self.output_base_dir / conversion_id + conversion_dir.mkdir(parents=True, exist_ok=True) + + # Define all file paths within the conversion directory + log_path = conversion_dir / "conversion.log" + raw_json_path = conversion_dir / "raw_analysis.json" + markdown_path = conversion_dir / "document.md" + metadata_path = conversion_dir / "metadata.json" + figures_dir = conversion_dir / "figures" + + self._log_sync( + log_path, f"Starting Azure Document Intelligence conversion: {source}" + ) + + # Analyze document - Updated API format + # Include 'figures' in output if extract_figures is True + output_param = ["figures"] if extract_figures else None + + self._log_sync( + log_path, "Document analysis started (in background thread)..." + ) + + # Run the blocking analysis directly since we are already in a thread + result, result_id = self._analyze_document_sync( + source, + source_type, + output_param, + ) + + self._log_sync(log_path, "Document analysis completed") + + # Convert full result to dictionary for JSON serialization (with ALL bounding boxes) + result_dict = result.as_dict() + + # Add processor field for format detection + result_dict["processor"] = "azure_doc_intelligence" + + # Save the FULL raw JSON response + with open(raw_json_path, "w", encoding="utf-8") as f: + json.dump(result_dict, f, indent=2, ensure_ascii=False) + self._log_sync( + log_path, + f"Saved full raw JSON with bounding boxes to raw_analysis.json", + ) + + # Extract markdown content + markdown_content = result.content if result.content else "" + + # Save markdown + with open(markdown_path, "w", encoding="utf-8") as f: + f.write(markdown_content) + + # Extract and save tables as separate HTML files + self._extract_and_save_tables_sync( + result=result, + conversion_dir=conversion_dir, + markdown_content=markdown_content, + log_path=log_path, + ) + + # Process figures if requested + figures_metadata = [] + + if extract_figures and result.figures: + figures_dir.mkdir(parents=True, exist_ok=True) + self._log_sync( + log_path, f"Found {len(result.figures)} figures to process" + ) + + if result_id: + self._log_sync( + log_path, + f"Extracted result_id: {result_id}", + ) + else: + self._log_sync( + log_path, + "⚠️ Could not extract result_id - figure images will not be downloaded", + ) + + for idx, figure in enumerate(result.figures): + figure_id = ( + figure.id + if hasattr(figure, "id") and figure.id + else f"unknown_{idx}" + ) + + # Extract figure metadata + figure_info = { + "id": figure_id, + "page": ( + figure.bounding_regions[0].page_number + if figure.bounding_regions + else None + ), + "caption": ( + figure.caption.content + if hasattr(figure, "caption") and figure.caption + else None + ), + "spans": ( + [ + {"offset": span.offset, "length": span.length} + for span in figure.spans + ] + if figure.spans + else [] + ), + "bounding_regions": ( + [ + { + "page_number": region.page_number, + "polygon": region.polygon, + } + for region in figure.bounding_regions + ] + if figure.bounding_regions + else [] + ), + } + + # Download the figure 
image if we have a result_id + if result_id: + image_path = self._download_figure_sync( + result_id=result_id, + figure_id=figure_id, + figures_dir=figures_dir, + log_path=log_path, + ) + if image_path: + figure_info["image_path"] = image_path + else: + self._log_sync( + log_path, + f"Skipping image download for figure {figure_id} (no result_id)", + ) + + figures_metadata.append(figure_info) + + self._log_sync(log_path, f"Processed {len(figures_metadata)} figures") + + # Create metadata + end_time = datetime.now() + conversion_time = (end_time - start_time).total_seconds() + + metadata = { + "conversion_id": conversion_id, + "source": source, + "source_type": source_type, + "processor": "azure_doc_intelligence", + "model_id": "prebuilt-layout", + "status": "success", + "conversion_dir": str(conversion_dir), + "markdown_path": str(markdown_path), + "raw_json_path": str(raw_json_path), + "log_path": str(log_path), + "start_time": start_time.isoformat(), + "end_time": end_time.isoformat(), + "conversion_time": conversion_time, + "content_length": len(markdown_content), + "page_count": len(result.pages) if result.pages else 0, + "tables_found": len(result.tables) if result.tables else 0, + "key_value_pairs_found": ( + len(result.key_value_pairs) if result.key_value_pairs else 0 + ), + "figures_found": len(figures_metadata), + "figures": figures_metadata, + } + + # Save metadata + with open(metadata_path, "w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2) + + self._log_sync( + log_path, f"Conversion completed successfully in {conversion_time:.2f}s" + ) + self._log_sync(log_path, f"Pages processed: {metadata['page_count']}") + self._log_sync(log_path, f"Tables found: {metadata['tables_found']}") + self._log_sync( + log_path, f"Key-value pairs found: {metadata['key_value_pairs_found']}" + ) + self._log_sync(log_path, f"Figures found: {metadata['figures_found']}") + + return { + "success": True, + "conversion_id": conversion_id, + "markdown_path": str(markdown_path), + "metadata": { + "content_length": len(markdown_content), + "conversion_time": conversion_time, + "page_count": metadata["page_count"], + "tables_found": metadata["tables_found"], + "key_value_pairs_found": metadata["key_value_pairs_found"], + "figures_found": metadata["figures_found"], + "figures": figures_metadata, + }, + } + + except Exception as e: + error_msg = f"Azure Document Intelligence conversion failed: {str(e)}" + + # Try to log error if log_path was created + try: + self._log_sync(log_path, f"ERROR: {error_msg}") + except: + pass + + # Save error metadata + metadata = { + "conversion_id": conversion_id, + "source": source, + "source_type": source_type, + "processor": "azure_doc_intelligence", + "status": "error", + "error": error_msg, + "conversion_dir": ( + str(conversion_dir) if "conversion_dir" in locals() else None + ), + "log_path": str(log_path) if "log_path" in locals() else None, + "start_time": start_time.isoformat(), + "end_time": datetime.now().isoformat(), + } + + # Save metadata to conversion directory if it exists + try: + if "metadata_path" in locals(): + with open(metadata_path, "w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2) + except: + pass + + return { + "success": False, + "error": error_msg, + "conversion_id": conversion_id, + } + + async def _log(self, log_path: Path, message: str): + """Write log message (Async wrapper)""" + await asyncio.to_thread(self._log_sync, log_path, message) + + def _log_sync(self, log_path: Path, message: str): + """Write log message 
(Synchronous)""" + timestamp = datetime.now().isoformat() + log_entry = f"[{timestamp}] {message}\n" + + with open(log_path, "a", encoding="utf-8") as f: + f.write(log_entry) + + def _extract_and_save_tables_sync( + self, + result: Any, + conversion_dir: Path, + markdown_content: str, + log_path: Path, + ) -> None: + """ + Extract tables from Azure result and save as separate HTML files + + Args: + result: Azure Document Intelligence result + conversion_dir: Conversion directory + markdown_content: The markdown content containing HTML tables + log_path: Log file path + """ + import re + + if not result.tables or len(result.tables) == 0: + self._log_sync(log_path, "No tables found in document") + return + + # Create tables directory + tables_dir = conversion_dir / "tables" + tables_dir.mkdir(parents=True, exist_ok=True) + + self._log_sync(log_path, f"Found {len(result.tables)} tables to extract") + + # Extract HTML tables from markdown content using regex + # Match
<table>...</table> blocks
+        table_pattern = r"<table>.*?</table>
" + html_tables = re.findall(table_pattern, markdown_content, re.DOTALL) + + # Save each table as a separate HTML file + for idx, html_table in enumerate(html_tables, start=1): + try: + table_html_path = tables_dir / f"table-{idx}.html" + with open(table_html_path, "w", encoding="utf-8") as f: + f.write(html_table) + self._log_sync(log_path, f"Saved table {idx} to {table_html_path.name}") + except Exception as e: + self._log_sync(log_path, f"Failed to save table {idx}: {str(e)}") + + self._log_sync( + log_path, f"Extracted {len(html_tables)} tables to tables/ directory" + ) + + def _download_figure_sync( + self, result_id: str, figure_id: str, figures_dir: Path, log_path: Path + ) -> Optional[str]: + """ + Download a single figure image from Azure Document Intelligence + + Args: + result_id: The analysis result ID + figure_id: The figure ID (e.g., "1.1" for page 1, figure 1) + figures_dir: Directory to save the figure + log_path: Log file path + + Returns: + Relative path to the saved figure or None if failed + """ + try: + # Get the figure using the SDK's get_analyze_result_figure method + figure_stream = self.client.get_analyze_result_figure( + model_id="prebuilt-layout", result_id=result_id, figure_id=figure_id + ) + + # Save the figure + figure_filename = f"{figure_id}.png" + figure_path = figures_dir / figure_filename + + with open(figure_path, "wb") as f: + for chunk in figure_stream: + f.write(chunk) + + self._log_sync( + log_path, f"Downloaded figure {figure_id} to {figure_filename}" + ) + return f"figures/{figure_filename}" + + except Exception as e: + self._log_sync(log_path, f"Failed to download figure {figure_id}: {str(e)}") + return None + + async def get_conversion_by_id( + self, conversion_id: str + ) -> Optional[Dict[str, Any]]: + """Get conversion metadata by ID""" + metadata_path = self.output_base_dir / conversion_id / "metadata.json" + + if not metadata_path.exists(): + return None + + try: + with open(metadata_path, "r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return None + + async def get_markdown_content(self, conversion_id: str) -> Optional[str]: + """ + Get markdown content by conversion ID for LLM entity extraction + + Returns the markdown content from document.md + """ + conversion_dir = self.output_base_dir / conversion_id + + markdown_path = conversion_dir / "document.md" + if not markdown_path.exists(): + return None + + try: + with open(markdown_path, "r", encoding="utf-8") as f: + return f.read() + except Exception: + return None + + def is_available(self) -> bool: + """Check if Azure Document Intelligence is available and configured""" + return AZURE_DOC_INTELLIGENCE_AVAILABLE and self.client is not None + + async def get_figures_for_conversion( + self, conversion_id: str + ) -> Optional[List[Dict[str, Any]]]: + """ + Get all figures metadata for a specific conversion + + Args: + conversion_id: The conversion ID + + Returns: + List of figure metadata dictionaries or None if not found + """ + metadata = await self.get_conversion_by_id(conversion_id) + if metadata and "figures" in metadata: + return metadata["figures"] + return None + + async def get_raw_analysis_result( + self, conversion_id: str + ) -> Optional[Dict[str, Any]]: + """ + Get the complete raw JSON analysis result with ALL bounding boxes + + This includes: + - All pages with words, lines, spans, selection marks + - All paragraphs with bounding regions and roles + - All tables with cells and bounding boxes + - All figures with bounding regions + - All sections and 
structural information + + Args: + conversion_id: The conversion ID + + Returns: + Complete analysis result dictionary or None if not found + """ + raw_json_path = self.output_base_dir / conversion_id / "raw_analysis.json" + + if not raw_json_path.exists(): + return None + + try: + with open(raw_json_path, "r", encoding="utf-8") as f: + return json.load(f) + except Exception: + return None + + # ------------------------------------------------------------------ + # CAN-SR helpers (artifacts for citations) + # ------------------------------------------------------------------ + + @staticmethod + def _html_table_to_markdown(html_table: str) -> str: + """Convert a single ...
block into GitHub-flavored markdown. + + Best-effort conversion for prompt inclusion. + """ + soup = BeautifulSoup(html_table, "html.parser") + table = soup.find("table") + if not table: + return html_table + + rows: List[List[str]] = [] + for tr in table.find_all("tr"): + cells = tr.find_all(["th", "td"]) + row = [" ".join(c.get_text(" ", strip=True).split()) for c in cells] + if row: + rows.append(row) + + if not rows: + return "" + + # Normalize row widths + width = max(len(r) for r in rows) + rows = [r + [""] * (width - len(r)) for r in rows] + + header = rows[0] + body = rows[1:] if len(rows) > 1 else [] + + def esc(v: str) -> str: + return (v or "").replace("|", "\\|") + + md_lines = [ + "| " + " | ".join(esc(x) for x in header) + " |", + "| " + " | ".join(["---"] * width) + " |", + ] + for r in body: + md_lines.append("| " + " | ".join(esc(x) for x in r) + " |") + return "\n".join(md_lines) + + @staticmethod + def _extract_html_tables_from_markdown(markdown: str) -> List[str]: + """Extract blocks from Azure markdown content.""" + import re + + if not markdown: + return [] + table_pattern = r"
<table>.*?</table>
" + return re.findall(table_pattern, markdown, re.DOTALL) + + def _download_figure_bytes_sync(self, result_id: str, figure_id: str) -> Optional[bytes]: + """Download a single figure image from Azure DI as bytes (sync).""" + try: + stream = self.client.get_analyze_result_figure( + model_id="prebuilt-layout", result_id=result_id, figure_id=figure_id + ) + chunks: List[bytes] = [] + for chunk in stream: + chunks.append(chunk) + return b"".join(chunks) + except Exception: + return None + + async def extract_citation_artifacts( + self, + source_pdf: str, + source_type: str = "file", + ) -> Dict[str, Any]: + """Run Azure DI and return citation-ready figure/table artifacts. + + Returns: + { + "success": bool, + "error": str|None, + "raw_analysis": dict|None, + "figures": [ + {"index": 1, "azure_id": "1.1", "caption": str|None, "bounding_box": {...}, "png_bytes": bytes} + ], + "tables": [ + {"index": 1, "caption": None, "bounding_box": {...}, "table_markdown": "|...|"} + ] + } + + Notes: + - `description` is intentionally not generated here. + - Bounding boxes are stored as Azure DI polygon regions. + """ + + if not self.client: + return {"success": False, "error": "Azure Document Intelligence client not available"} + + # We need figures in output. + output_param = ["figures"] + + try: + result, result_id = await asyncio.to_thread( + self._analyze_document_sync, source_pdf, source_type, output_param + ) + except Exception as e: + return {"success": False, "error": f"Azure DI analyze failed: {e}"} + + raw_analysis = result.as_dict() if hasattr(result, "as_dict") else {} + raw_analysis["processor"] = "azure_doc_intelligence" + + # Preserve page-level metadata so callers can normalize polygons into + # the same coordinate system as Grobid TEI coords. + # Typical shape: {pageNumber,width,height,unit} + pages_meta = raw_analysis.get("pages") or [] + + markdown_content = getattr(result, "content", None) or "" + + # ----------------- + # Tables + # ----------------- + html_tables = self._extract_html_tables_from_markdown(markdown_content) + md_tables: List[str] = [self._html_table_to_markdown(t) for t in html_tables] + + # Azure gives table bounding regions under raw_analysis['tables'][*]['boundingRegions'] + raw_tables = raw_analysis.get("tables", []) or [] + + tables_out: List[Dict[str, Any]] = [] + for i, md in enumerate(md_tables, start=1): + bbox = None + if i - 1 < len(raw_tables): + bbox = raw_tables[i - 1].get("boundingRegions") or raw_tables[i - 1].get("bounding_regions") + tables_out.append( + { + "index": i, + "caption": None, + "bounding_box": bbox, + "table_markdown": md, + } + ) + + # ----------------- + # Figures + # ----------------- + figures_out: List[Dict[str, Any]] = [] + raw_figures = raw_analysis.get("figures", []) or [] + + # Prefer SDK figures list for caption/bbox, but fall back to raw dict. 
+ sdk_figures = getattr(result, "figures", None) or [] + for idx, fig in enumerate(sdk_figures, start=1): + azure_id = getattr(fig, "id", None) or f"unknown_{idx}" + caption = None + try: + cap = getattr(fig, "caption", None) + caption = getattr(cap, "content", None) if cap else None + except Exception: + caption = None + + bounding_regions = [] + try: + for region in (getattr(fig, "bounding_regions", None) or []): + bounding_regions.append( + {"page_number": region.page_number, "polygon": region.polygon} + ) + except Exception: + bounding_regions = [] + + png_bytes = None + if result_id: + png_bytes = await asyncio.to_thread(self._download_figure_bytes_sync, result_id, azure_id) + + if not png_bytes: + # If we couldn't download, skip storing bytes (still return metadata) + png_bytes = b"" + + figures_out.append( + { + "index": idx, + "azure_id": azure_id, + "caption": caption, + "bounding_box": bounding_regions, + "png_bytes": png_bytes, + } + ) + + # If SDK did not return figures but raw JSON did, include raw ones. + if not figures_out and raw_figures: + for idx, fig in enumerate(raw_figures, start=1): + figures_out.append( + { + "index": idx, + "azure_id": fig.get("id") or f"raw_{idx}", + "caption": (fig.get("caption", {}) or {}).get("content") if isinstance(fig.get("caption"), dict) else None, + "bounding_box": fig.get("boundingRegions") or fig.get("bounding_regions"), + "png_bytes": b"", + } + ) + + return { + "success": True, + "raw_analysis": raw_analysis, + "pages": pages_meta, + "figures": figures_out, + "tables": tables_out, + } + + +# Global instance for routers/services +try: + azure_docint_client = AzureDocIntelligenceService() +except Exception: + azure_docint_client = None # type: ignore \ No newline at end of file diff --git a/backend/api/services/azure_openai_client.py b/backend/api/services/azure_openai_client.py index dd89e685..6481f2f3 100644 --- a/backend/api/services/azure_openai_client.py +++ b/backend/api/services/azure_openai_client.py @@ -1,12 +1,42 @@ -"""Azure OpenAI client service for chat completions""" +"""backend.api.services.azure_openai_client +Azure OpenAI client service for chat completions. + +Supports: +* API key auth (AZURE_OPENAI_MODE=key) +* Entra/managed identity auth (AZURE_OPENAI_MODE=entra) + +Model catalog: +Loaded from YAML at runtime: + /app/configs/models.yaml + +The YAML is a mapping of UI/display keys to {deployment, api_version}, e.g.: + +GPT-5-Mini: + deployment: gpt-5-mini + api_version: 2025-04-01-preview + +DEFAULT_CHAT_MODEL must be one of those keys. 
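A hedged usage sketch (run inside an async context; the display key must be one of
the keys defined in models.yaml, and the message text is illustrative):

    from backend.api.services.azure_openai_client import azure_openai_client

    if azure_openai_client.is_configured():
        reply = await azure_openai_client.simple_chat(
            user_message="Summarize the study design.",
            system_prompt=None,
            model="GPT-5-Mini",   # display key from models.yaml
            max_tokens=512,
            temperature=0.0,
        )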
+""" + +from __future__ import annotations + +import json +import logging import time -from typing import Dict, List, Any, Optional +from pathlib import Path +import base64 +from typing import Any, Dict, List, Optional, Tuple + +import yaml + from azure.identity import DefaultAzureCredential, get_bearer_token_provider from openai import AzureOpenAI from ..core.config import settings +logger = logging.getLogger(__name__) + # Token cache TTL in seconds (9 minutes) TOKEN_CACHE_TTL = 9 * 60 @@ -32,70 +62,157 @@ class AzureOpenAIClient: """Client for Azure OpenAI chat completions""" def __init__(self): + self._config_error: Optional[str] = None self.default_model = settings.DEFAULT_CHAT_MODEL - # Create token provider for Azure OpenAI using DefaultAzureCredential - # Wrapped with caching to avoid fetching a new token on every request - if not settings.AZURE_OPENAI_API_KEY and not settings.USE_ENTRA_AUTH: - raise ValueError("Azure OpenAI API key or Entra auth must be configured") + self._auth_type = self._resolve_auth_type() + self._endpoint = self._resolve_endpoint() + self._api_key = settings.AZURE_OPENAI_API_KEY - if settings.USE_ENTRA_AUTH: - self._credential = DefaultAzureCredential() - self._token_provider = CachedTokenProvider( - get_bearer_token_provider( - self._credential, "https://cognitiveservices.azure.com/.default" + self._token_provider: Optional[CachedTokenProvider] = None + if self._auth_type == "entra": + if not DefaultAzureCredential or not get_bearer_token_provider: + self._config_error = ( + "AZURE_OPENAI_MODE=entra requires azure-identity to be installed" ) - ) + else: + # Create token provider for Azure OpenAI using DefaultAzureCredential + # Wrapped with caching to avoid fetching a new token on every request + credential = DefaultAzureCredential() + self._token_provider = CachedTokenProvider( + get_bearer_token_provider( + credential, "https://cognitiveservices.azure.com/.default" + ) + ) + + self.model_configs = self._load_model_configs() + self.default_model = self._resolve_default_model(self.default_model) + + # Cache official clients by (endpoint, api_version, auth_type) + self._official_clients: Dict[Tuple[str, str, str], AzureOpenAI] = {} + + + # --------------------------------------------------------------------- + # Configuration + # --------------------------------------------------------------------- - self.model_configs = { - "gpt-4.1-mini": { - "endpoint": settings.AZURE_OPENAI_GPT41_MINI_ENDPOINT - or settings.AZURE_OPENAI_ENDPOINT, - "deployment": settings.AZURE_OPENAI_GPT41_MINI_DEPLOYMENT, - "api_version": settings.AZURE_OPENAI_GPT41_MINI_API_VERSION, - }, - "gpt-5-mini": { - "endpoint": settings.AZURE_OPENAI_GPT5_MINI_ENDPOINT - or settings.AZURE_OPENAI_ENDPOINT, - "deployment": settings.AZURE_OPENAI_GPT5_MINI_DEPLOYMENT, - "api_version": settings.AZURE_OPENAI_GPT5_MINI_API_VERSION, - }, - } - - self._official_clients: Dict[str, AzureOpenAI] = {} + + @staticmethod + def _resolve_auth_type() -> str: + """Return key|entra. 
+ + New config: AZURE_OPENAI_MODE + Legacy config: USE_ENTRA_AUTH + """ + t = (getattr(settings, "AZURE_OPENAI_MODE", None) or "").lower().strip() + if t in {"key", "entra"}: + return t + # Legacy fallback + if getattr(settings, "USE_ENTRA_AUTH", False): + return "entra" + return "key" + + @staticmethod + def _resolve_endpoint() -> Optional[str]: + return settings.AZURE_OPENAI_ENDPOINT + + def _strip_outer_quotes(self, s: str) -> str: + s = s.strip() + if len(s) >= 2 and ((s[0] == s[-1] == '"') or (s[0] == s[-1] == "'")): + return s[1:-1] + return s + + def _load_models_yaml(self) -> Dict[str, Any]: + """Load model catalog from /app/configs/models.yaml. + + This file is expected to be mounted in docker-compose so changes can be + applied without rebuilding the image. + """ + path = Path("/app/configs/models.yaml") + if not path.exists(): + logger.warning("Azure OpenAI model catalog not found at %s", path) + return {} + try: + data = yaml.safe_load(path.read_text(encoding="utf-8")) + if not isinstance(data, dict): + logger.warning("Invalid models.yaml format (expected mapping): %s", type(data)) + return {} + return data + except Exception as e: + logger.exception("Failed to load Azure OpenAI model catalog from %s: %s", path, e) + return {} + + def _load_model_configs(self) -> Dict[str, Dict[str, str]]: + """Build model configs keyed by UI/display name.""" + models = self._load_models_yaml() + cfg: Dict[str, Dict[str, str]] = {} + for display_name, meta in models.items(): + if not isinstance(meta, dict): + continue + deployment = meta.get("deployment") + api_version = meta.get("api_version") + if not deployment or not api_version: + continue + cfg[str(display_name)] = { + "endpoint": self._endpoint or "", + "deployment": str(deployment), + "api_version": str(api_version), + } + + if cfg: + return cfg + + return {} + + def _resolve_default_model(self, desired: str) -> str: + if desired in self.model_configs: + return desired + # If configured default doesn't exist, fall back to first configured model + for k in self.model_configs.keys(): + return k + return desired def _get_model_config(self, model: str) -> Dict[str, str]: """Get configuration for a specific model""" if model in self.model_configs: return self.model_configs[model] - return self.model_configs["gpt-5-mini"] + # fallback to first configured model + for _, cfg in self.model_configs.items(): + return cfg + raise ValueError("No Azure OpenAI models are configured") def _get_official_client(self, model: str) -> AzureOpenAI: """Get official Azure OpenAI client instance""" - if model not in self._official_clients: - config = self._get_model_config(model) - if not config.get("endpoint"): - raise ValueError( - f"Azure OpenAI endpoint not configured for model {model}" - ) - - azure_openai_kwargs = { - "azure_endpoint": config["endpoint"], - "api_version": config["api_version"], + config = self._get_model_config(model) + endpoint = config.get("endpoint") + api_version = config.get("api_version") + if not endpoint or not api_version: + raise ValueError(f"Azure OpenAI endpoint/api_version not configured for model {model}") + + cache_key = (endpoint, api_version, self._auth_type) + if cache_key not in self._official_clients: + azure_openai_kwargs: Dict[str, Any] = { + "azure_endpoint": endpoint, + "api_version": api_version, } - if settings.USE_ENTRA_AUTH: - azure_openai_kwargs["azure_ad_token_provider"] = self._token_provider - - if settings.AZURE_OPENAI_API_KEY: - azure_openai_kwargs["api_key"] = settings.AZURE_OPENAI_API_KEY - 
self._official_clients[model] = AzureOpenAI(**azure_openai_kwargs)
+            if self._auth_type == "entra":
+                if not self._token_provider:
+                    raise ValueError(self._config_error or "Azure AD token provider not configured")
+                # The official SDK parameter is azure_ad_token_provider.
+                azure_openai_kwargs["azure_ad_token_provider"] = self._token_provider
+            else:
+                # key auth
+                if not self._api_key:
+                    raise ValueError("AZURE_OPENAI_MODE=key requires AZURE_OPENAI_API_KEY")
+                azure_openai_kwargs["api_key"] = self._api_key
+
+            self._official_clients[cache_key] = AzureOpenAI(**azure_openai_kwargs)
 
-        return self._official_clients[model]
+        return self._official_clients[cache_key]
 
     def _build_messages(
         self, user_message: str, system_prompt: Optional[str] = None
-    ) -> List[Dict[str, str]]:
+    ) -> List[Dict[str, Any]]:
         """Build message list for chat completion"""
         messages = []
         if system_prompt:
@@ -105,7 +222,7 @@ def _build_messages(
     async def chat_completion(
         self,
-        messages: List[Dict[str, str]],
+        messages: List[Dict[str, Any]],
         model: Optional[str] = None,
         max_tokens: int = 1000,
         temperature: float = 0.7,
@@ -145,7 +262,9 @@ async def chat_completion(
             "stream": stream,
         }
 
-        if model != "gpt-5-mini":
+        # gpt-5 deployments may reject temperature/max_tokens in some previews.
+        # We gate this by the *deployment* name because the UI key can differ.
+        if deployment != "gpt-5-mini":
             request_kwargs["max_tokens"] = max_tokens
             request_kwargs["temperature"] = temperature
 
@@ -203,6 +322,51 @@ async def simple_chat(
             print(f"Error in simple chat: {e}")
             return f"I apologize, but I encountered an error while processing your request. Please try again later. (Error: {str(e)})"
 
+    async def multimodal_chat(
+        self,
+        user_text: str,
+        images: List[Tuple[bytes, str]],
+        system_prompt: Optional[str] = None,
+        model: Optional[str] = None,
+        max_tokens: int = 1000,
+        temperature: float = 0.0,
+    ) -> str:
+        """Send a single user message with multiple attached images.
+
+        `images` items are (bytes, mime_type) where mime_type is e.g. "image/png".
+        """
+        try:
+            parts: List[Dict[str, Any]] = [{"type": "text", "text": user_text}]
+            for b, mime in images or []:
+                if not b:
+                    continue
+                b64 = base64.b64encode(b).decode("utf-8")
+                parts.append(
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:{mime};base64,{b64}"},
+                    }
+                )
+
+            messages: List[Dict[str, Any]] = []
+            if system_prompt:
+                messages.append({"role": "system", "content": system_prompt})
+            messages.append({"role": "user", "content": parts})
+
+            response = await self.chat_completion(
+                messages=messages,
+                model=model,
+                max_tokens=max_tokens,
+                temperature=temperature,
+            )
+            return response["choices"][0]["message"]["content"]
+        except Exception as e:
+            print(f"Error in multimodal_chat: {e}")
+            return (
+                "I apologize, but I encountered an error while processing your request. "
+                f"Please try again later. 
(Error: {str(e)})" + ) + async def streaming_chat( self, user_message: str, @@ -219,16 +383,19 @@ async def streaming_chat( deployment = self._get_model_config(model)["deployment"] client = self._get_official_client(model) - response = client.chat.completions.create( - stream=True, - messages=messages, - max_tokens=max_tokens, - temperature=temperature, - top_p=top_p, - frequency_penalty=0.0, - presence_penalty=0.0, - model=deployment, - ) + request_kwargs: Dict[str, Any] = { + "stream": True, + "messages": messages, + "top_p": top_p, + "frequency_penalty": 0.0, + "presence_penalty": 0.0, + "model": deployment, + } + if deployment != "gpt-5-mini": + request_kwargs["max_tokens"] = max_tokens + request_kwargs["temperature"] = temperature + + response = client.chat.completions.create(**request_kwargs) for update in response: if update.choices: @@ -334,16 +501,40 @@ async def chat_with_context( def get_available_models(self) -> List[str]: """Get list of available models that are properly configured""" - return [ - model - for model, config in self.model_configs.items() - if config.get("endpoint") - ] + out: List[str] = [] + for model, config in self.model_configs.items(): + if not config.get("endpoint") or not config.get("deployment") or not config.get("api_version"): + continue + out.append(model) + return out def is_configured(self) -> bool: """Check if Azure OpenAI is properly configured""" - return len(self.get_available_models()) > 0 + if self._config_error: + return False + if not self.get_available_models(): + return False + + if self._auth_type == "key": + return bool(self._endpoint and self._api_key) + if self._auth_type == "entra": + return bool(self._endpoint and self._token_provider) + return False # Global Azure OpenAI client instance -azure_openai_client = AzureOpenAIClient() +# NOTE: This is used by routers. We intentionally avoid raising during import +# so the API can start up and report configuration issues as 503s. +try: + azure_openai_client = AzureOpenAIClient() +except Exception as e: # pragma: no cover + logger.exception("Failed to initialize AzureOpenAIClient: %s", e) + # Provide a stub that reports not-configured. + class _DisabledAzureOpenAIClient: # type: ignore + def is_configured(self) -> bool: + return False + + def get_available_models(self) -> List[str]: + return [] + + azure_openai_client = _DisabledAzureOpenAIClient() # type: ignore diff --git a/backend/api/services/cit_db_service.py b/backend/api/services/cit_db_service.py index b130b688..141e6e0e 100644 --- a/backend/api/services/cit_db_service.py +++ b/backend/api/services/cit_db_service.py @@ -138,7 +138,7 @@ def __init__(self): # ----------------------- # Generic column ops # ----------------------- - def create_column(self, db_conn_str: str, col: str, col_type: str, table_name: str = "citations") -> None: + def create_column(self, col: str, col_type: str, table_name: str = "citations") -> None: """ Create column on citations table if it doesn't already exist. col should be the exact column name to use (caller may pass snake_case(col)). 
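With `db_conn_str` removed from every helper in cit_db_service.py, call sites pass only identifiers and values. A minimal call-site sketch of the new signatures; the module-level `cit_db_service` instance, its import path, and the `extracted_text` column are assumptions for illustration only, mirroring the singleton pattern the other services in this patch use:

    # Sketch only: exercises the refactored create_column / update_text_column signatures.
    from backend.api.services.cit_db_service import cit_db_service  # assumed singleton export; adjust to the real one

    def store_extracted_text(citation_id: int, text: str) -> int:
        # The column is created on demand; no connection string is threaded through anymore.
        cit_db_service.create_column("extracted_text", "TEXT", table_name="citations")
        return cit_db_service.update_text_column(citation_id, "extracted_text", text)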
@@ -165,7 +165,6 @@ def create_column(self, db_conn_str: str, col: str, col_type: str, table_name: s def update_jsonb_column( self, - db_conn_str: str, citation_id: int, col: str, data: Any, @@ -197,7 +196,6 @@ def update_jsonb_column( def update_text_column( self, - db_conn_str: str, citation_id: int, col: str, text_value: str, @@ -230,7 +228,7 @@ def update_text_column( # ----------------------- # Citation row helpers # ----------------------- - def dump_citations_csv(self, db_conn_str: str, table_name: str = "citations") -> bytes: + def dump_citations_csv(self, table_name: str = "citations") -> bytes: """Dump the entire `citations` table as CSV bytes. Intended to be called from async FastAPI routes via @@ -259,7 +257,7 @@ def dump_citations_csv(self, db_conn_str: str, table_name: str = "citations") -> if conn: pass - def get_citation_by_id(self, db_conn_str: str, citation_id: int, table_name: str = "citations") -> Optional[Dict[str, Any]]: + def get_citation_by_id(self, citation_id: int, table_name: str = "citations") -> Optional[Dict[str, Any]]: """ Return a dict mapping column -> value for the citation row, or None. """ @@ -286,7 +284,7 @@ def get_citation_by_id(self, db_conn_str: str, citation_id: int, table_name: str if conn: pass - def list_citation_ids(self, db_conn_str: str, filter_step=None, table_name: str = "citations") -> List[int]: + def list_citation_ids(self, filter_step=None, table_name: str = "citations") -> List[int]: """ Return list of integer primary keys (id) from citations table ordered by id. """ @@ -333,7 +331,7 @@ def list_citation_ids(self, db_conn_str: str, filter_step=None, table_name: str if conn: pass - def list_fulltext_urls(self, db_conn_str: str, table_name: str = "citations") -> List[str]: + def list_fulltext_urls(self, table_name: str = "citations") -> List[str]: """ Return list of fulltext_url values (non-null) from citations table. """ @@ -350,18 +348,17 @@ def list_fulltext_urls(self, db_conn_str: str, table_name: str = "citations") -> if conn: pass - def update_citation_fulltext(self, db_conn_str: str, citation_id: int, fulltext_path: str) -> int: + def update_citation_fulltext(self, citation_id: int, fulltext_path: str) -> int: """ Backwards-compatible helper used by some routers. Sets `fulltext_url`. 
""" - return self.update_text_column(db_conn_str, citation_id, "fulltext_url", fulltext_path) + return self.update_text_column(citation_id, "fulltext_url", fulltext_path) # ----------------------- # Upload fulltext and compute md5 # ----------------------- def attach_fulltext( self, - db_conn_str: str, citation_id: int, azure_path: str, file_bytes: bytes, @@ -373,7 +370,7 @@ def attach_fulltext( """ table_name = _validate_ident(table_name, kind="table_name") # create columns if missing - self.create_column(db_conn_str, "fulltext_url", "TEXT", table_name=table_name) + self.create_column("fulltext_url", "TEXT", table_name=table_name) # compute md5 md5 = hashlib.md5(file_bytes).hexdigest() if file_bytes is not None else "" @@ -383,15 +380,13 @@ def attach_fulltext( cur.execute(f'UPDATE "{table_name}" SET "fulltext_url" = %s WHERE id = %s', (azure_path, int(citation_id))) rows = cur.rowcount conn.commit() - - return rows # ----------------------- # Column get/set helpers # ----------------------- - def get_column_value(self, db_conn_str: str, citation_id: int, column: str, table_name: str = "citations") -> Any: + def get_column_value(self, citation_id: int, column: str, table_name: str = "citations") -> Any: """ Return the value stored in `column` for the citation row (or None). """ @@ -418,18 +413,18 @@ def get_column_value(self, db_conn_str: str, citation_id: int, column: str, tabl if conn: pass - def set_column_value(self, db_conn_str: str, citation_id: int, column: str, value: Any, table_name: str = "citations") -> int: + def set_column_value(self, citation_id: int, column: str, value: Any, table_name: str = "citations") -> int: """ Generic setter for a citation row column. Will create a TEXT column if it doesn't exist. """ # For simplicity, create a TEXT column. Callers that need JSONB should use update_jsonb_column. - self.create_column(db_conn_str, column, "TEXT", table_name=table_name) - return self.update_text_column(db_conn_str, citation_id, column, value if value is not None else None, table_name=table_name) + self.create_column(column, "TEXT", table_name=table_name) + return self.update_text_column(citation_id, column, value if value is not None else None, table_name=table_name) # ----------------------- # Per-upload table lifecycle helpers # ----------------------- - def drop_table(self, db_conn_str: str, table_name: str, cascade: bool = True) -> None: + def drop_table(self, table_name: str, cascade: bool = True) -> None: """Drop a screening table in the shared database.""" table_name = _validate_ident(table_name, kind="table_name") conn = None @@ -445,7 +440,6 @@ def drop_table(self, db_conn_str: str, table_name: str, cascade: bool = True) -> def create_table_and_insert_sync( self, - db_conn_str: str, table_name: str, columns: List[str], rows: List[Dict[str, Any]], diff --git a/backend/api/services/document_service.py b/backend/api/services/document_service.py index da32dbb5..571ed429 100644 --- a/backend/api/services/document_service.py +++ b/backend/api/services/document_service.py @@ -1,275 +1,42 @@ -""" -Main Document Processing Service - -This service orchestrates different document processors (Docling, Azure Document Intelligence, etc.) -and provides a unified interface for document conversion. It can automatically choose the best -processor for a given document or allow explicit processor selection. 
-""" +"""backend.api.services.document_service -import asyncio -from typing import Dict, Any, Optional, List +NOTE: CAN-SR previously carried a multi-processor document service (Docling + Azure). +We are **removing Docling** and keeping only Azure Document Intelligence. -from schemas.enums import ProcessorType -from .processors.docling import DoclingService -from .processors.azure_doc_intelligence.azure_doc_intelligence_service import ( - AzureDocIntelligenceService, -) +This module remains as a small compatibility wrapper for any older code paths, +but new code should prefer importing `azure_docint_client` directly. +""" +from __future__ import annotations -class DocumentService: - """Main service for document processing with multiple processor support""" +from typing import Any, Dict, Optional - def __init__(self): - self.docling_service = DoclingService() - self.azure_doc_intelligence_service = AzureDocIntelligenceService() +from .azure_docint_client import azure_docint_client - self.available_processors = self._check_processor_availability() - def _check_processor_availability(self) -> Dict[str, bool]: - """Check which processors are available""" - return { - ProcessorType.DOCLING.value: True, - ProcessorType.AZURE_DOC_INTELLIGENCE.value: self.azure_doc_intelligence_service.is_available(), - } +class DocumentService: + """Compatibility wrapper around Azure Document Intelligence only.""" async def convert_document_to_markdown( self, source: str, source_type: str = "file", - processor: ProcessorType = ProcessorType.AUTO, - **kwargs + **kwargs: Any, ) -> Dict[str, Any]: - """ - Convert document to markdown using specified or auto-selected processor - - Args: - source: File path or URL - source_type: "file" or "url" - processor: Which processor to use (auto, docling, azure_doc_intelligence) - **kwargs: Additional processor-specific arguments - - Returns: - Dict with conversion results including processor used - """ - - if processor == ProcessorType.AUTO: - processor = self._auto_select_processor(source, source_type) - - if processor == ProcessorType.AZURE_DOC_INTELLIGENCE: - if not self.available_processors[ - ProcessorType.AZURE_DOC_INTELLIGENCE.value - ]: - processor = ProcessorType.DOCLING - result = await self.docling_service.convert_document_to_markdown( - source, source_type, **kwargs - ) - result["processor_used"] = ProcessorType.DOCLING.value - result["processor_fallback"] = True - result["fallback_reason"] = "Azure Document Intelligence not available" - else: - result = await self.azure_doc_intelligence_service.convert_document_to_markdown( - source, source_type, **kwargs - ) - result["processor_used"] = ProcessorType.AZURE_DOC_INTELLIGENCE.value - - else: - result = await self.docling_service.convert_document_to_markdown( - source, source_type, **kwargs - ) - result["processor_used"] = ProcessorType.DOCLING.value - - return result - - def _auto_select_processor(self, source: str, source_type: str) -> ProcessorType: - """ - Automatically select the best processor for the document - - Logic: - - Default to Azure Document Intelligence for all documents - - For specific document types that work better with Docling, use Docling - - If Azure is not available, fall back to Docling - """ - - if not self.available_processors[ProcessorType.AZURE_DOC_INTELLIGENCE.value]: - return ProcessorType.DOCLING - - if source_type == "file": - source_lower = source.lower() - - return ProcessorType.AZURE_DOC_INTELLIGENCE - - async def get_processor_capabilities(self) -> Dict[str, Any]: - """Get 
information about available processors and their capabilities""" - return { - "available_processors": self.available_processors, - "processors": { - ProcessorType.AZURE_DOC_INTELLIGENCE.value: { - "name": "Azure Document Intelligence", - "description": "Primary processor with superior accuracy for all document types, especially forms, tables, and structured documents", - "strengths": [ - "Table extraction", - "Form fields", - "Key-value pairs", - "Handwriting", - "Figure/chart detection", - "General documents", - "Complex layouts", - ], - "features": [ - "Markdown output", - "Table extraction", - "Figure extraction with captions", - "Downloadable figure images", - "Bounding regions", - "Key-value pairs", - ], - "available": self.available_processors[ - ProcessorType.AZURE_DOC_INTELLIGENCE.value - ], - }, - ProcessorType.DOCLING.value: { - "name": "Docling", - "description": "Fast and reliable fallback processor for general documents when Azure is unavailable", - "strengths": [ - "Academic papers", - "Mixed content", - "Fast processing", - "Always available", - ], - "available": self.available_processors[ProcessorType.DOCLING.value], - }, - }, - "auto_selection": { - "description": "Automatically chooses the best processor based on document characteristics", - "default_processor": ProcessorType.AZURE_DOC_INTELLIGENCE.value, - "fallback_processor": ProcessorType.DOCLING.value, - }, - } - - async def get_conversion_by_id( - self, conversion_id: str - ) -> Optional[Dict[str, Any]]: - """Get conversion info by ID from any processor""" - - result = await self.docling_service.get_conversion_by_id(conversion_id) - if result: - return result - - result = await self.azure_doc_intelligence_service.get_conversion_by_id( - conversion_id - ) - if result: - return result - - return None - - async def get_markdown_content( - self, conversion_id: str, processor_used: Optional[str] = None - ) -> Optional[str]: - """ - Get markdown content by conversion ID from the specific processor that was used - - Args: - conversion_id: The conversion ID to retrieve content for - processor_used: The processor that was used (if known) - improves efficiency - """ - - # If we know which processor was used, check that one first - if processor_used: - if processor_used == ProcessorType.AZURE_DOC_INTELLIGENCE.value: - content = ( - await self.azure_doc_intelligence_service.get_markdown_content( - conversion_id - ) - ) - if content: - return content - elif processor_used == ProcessorType.DOCLING.value: - content = await self.docling_service.get_markdown_content(conversion_id) - if content: - return content - - # Fallback: Try both processors (for backward compatibility or if processor_used is unknown) - # Check Azure first since it's our default - content = await self.azure_doc_intelligence_service.get_markdown_content( - conversion_id + if not azure_docint_client or not azure_docint_client.is_available(): + return { + "success": False, + "error": "Azure Document Intelligence is not configured", + "processor_used": "azure_doc_intelligence", + } + + result = await azure_docint_client.convert_document_to_markdown( + source, source_type=source_type, **kwargs ) - if content: - return content - - content = await self.docling_service.get_markdown_content(conversion_id) - if content: - return content - - return None - - async def get_figures_for_conversion( - self, conversion_id: str - ) -> Optional[List[Dict[str, Any]]]: - """ - Get all figures metadata for a specific conversion - - This method works for conversions processed with 
either Docling or Azure Document Intelligence - - Args: - conversion_id: The conversion ID - - Returns: - List of figure metadata dictionaries or None if not found - """ - # Try Docling first - figures = await self.docling_service.get_figures_for_conversion(conversion_id) - if figures is not None: - return figures - - # Try Azure Document Intelligence - figures = await self.azure_doc_intelligence_service.get_figures_for_conversion( - conversion_id - ) - if figures is not None: - return figures - - return None - - async def get_raw_analysis_result( - self, conversion_id: str - ) -> Optional[Dict[str, Any]]: - """ - Get the complete raw analysis result with ALL bounding boxes - - This is available for documents processed with either Azure Document Intelligence or Docling. - The raw analysis includes all detected elements with their bounding box coordinates: - - Azure DI provides: - - Pages with words, lines, selection marks - - Paragraphs with roles and bounding regions - - Tables with cells and bounding boxes - - Figures with bounding regions and captions - - Sections and structural information - - Docling provides: - - Pages with dimensions - - Text items (paragraphs) with bounding regions and roles - - Tables with cells and bounding boxes - - Pictures/figures with bounding regions - - Document structure (body, furniture, groups) - - Args: - conversion_id: The conversion ID - - Returns: - Complete analysis result dictionary or None if not found - """ - # Try Azure Document Intelligence first - result = await self.azure_doc_intelligence_service.get_raw_analysis_result( - conversion_id - ) - if result: - return result - - # Try Docling - result = await self.docling_service.get_raw_analysis_result(conversion_id) - if result: - return result + result["processor_used"] = "azure_doc_intelligence" + return result - return None + async def get_raw_analysis_result(self, conversion_id: str) -> Optional[Dict[str, Any]]: + if not azure_docint_client: + return None + return await azure_docint_client.get_raw_analysis_result(conversion_id) diff --git a/backend/api/services/grobid_service.py b/backend/api/services/grobid_service.py index 9e2b767d..6bed76ff 100644 --- a/backend/api/services/grobid_service.py +++ b/backend/api/services/grobid_service.py @@ -39,18 +39,45 @@ def __init__(self): "GROBID_SERVICE_URL", "http://grobid-service:8000" ) - grobid_client = GrobidClient( - grobid_server=self.base_service_url, - batch_size=1000, - coordinates=["p", "s", "persName", "biblStruct", "figure", "formula", "head", "note", "title", "ref", - "affiliation"], - sleep_time=5, - timeout=240, - check_server=True - ) - self.grobid_client = grobid_client + # IMPORTANT: + # Do not check server availability at import time. This repo is often imported + # in environments where the grobid container is not running (dev/test), and + # failing hard on import breaks unrelated endpoints. 
+ try: + grobid_client = GrobidClient( + grobid_server=self.base_service_url, + batch_size=1000, + coordinates=[ + "p", + "s", + "persName", + "biblStruct", + "figure", + "formula", + "head", + "note", + "title", + "ref", + "affiliation", + ], + sleep_time=5, + timeout=240, + check_server=False, + ) + self.grobid_client = grobid_client + except Exception as e: + logger.error( + "Failed to initialize GrobidClient (service may be down): %s", + e, + ) + self.grobid_client = None + + def is_available(self) -> bool: + return self.grobid_client is not None async def process_structure(self, input_path) -> (dict, list): + if not self.grobid_client: + raise RuntimeError("GROBID client is not available (service not configured or down)") pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument", input_path, consolidate_header=True, @@ -119,4 +146,9 @@ async def get_pages(self, text): return pages # Global instance -grobid_service = GrobidService() +try: + grobid_service = GrobidService() +except Exception as e: # pragma: no cover + logger.error("Failed to initialize GrobidService: %s", e) + grobid_service = GrobidService.__new__(GrobidService) # type: ignore + grobid_service.grobid_client = None # type: ignore diff --git a/backend/api/services/postgres_auth.py b/backend/api/services/postgres_auth.py index 9334c520..dc676cbd 100644 --- a/backend/api/services/postgres_auth.py +++ b/backend/api/services/postgres_auth.py @@ -1,17 +1,38 @@ -""" -PostgreSQL authentication helper using Azure Entra ID (DefaultAzureCredential). +"""backend.api.services.postgres_auth + +PostgreSQL connection helper supporting three runtime modes. + +Configuration model: +* POSTGRES_MODE selects behavior: docker | local | azure +* Connection settings are provided via a single set of env vars: + - POSTGRES_HOST + - POSTGRES_DATABASE + - POSTGRES_USER + - POSTGRES_PASSWORD + +Auth behavior: +* docker/local: password auth (POSTGRES_PASSWORD required) +* azure: Entra token auth via DefaultAzureCredential (password ignored) + +Behavior: +* Only try the configured POSTGRES_MODE (no fallback). -This module provides a centralized way to connect to Azure Database for PostgreSQL -using Entra ID authentication, with fallback to connection string for local development. +Notes: +* POSTGRES_URI is deprecated and intentionally not used. 
""" -from typing import Optional +import os +from typing import Optional, Dict, Any import psycopg2 from ..core.config import settings import logging import datetime -from azure.identity import DefaultAzureCredential + +try: + from azure.identity import DefaultAzureCredential +except Exception: # pragma: no cover + DefaultAzureCredential = None # type: ignore logger = logging.getLogger(__name__) @@ -24,7 +45,7 @@ class PostgresServer: def __init__(self): self._verify_config() - self._credential = DefaultAzureCredential() if settings.AZURE_DB else None + self._credential = DefaultAzureCredential() if self._mode() == "azure" else None self._token: Optional[str] = None self._token_expiration: int = 0 self._conn = None @@ -33,14 +54,11 @@ def __init__(self): def conn(self): """Return an open connection, reconnecting only when necessary.""" if self._conn is None or self._conn.closed: - print("local database") self._conn = self._connect() - elif settings.AZURE_DB and self._is_token_expired(): + elif self._mode() == "azure" and self._is_token_expired(): logger.info("Azure token expired — reconnecting to PostgreSQL") - print("cloud database") self.close() self._conn = self._connect() - print(self._conn) return self._conn def close(self): @@ -55,11 +73,22 @@ def close(self): @staticmethod def _verify_config(): """Validate that all required PostgreSQL settings are present.""" - required = [settings.POSTGRES_HOST, settings.POSTGRES_DATABASE, settings.POSTGRES_USER] + mode = (settings.POSTGRES_MODE or "").lower().strip() + if mode not in {"local", "docker", "azure"}: + raise RuntimeError("POSTGRES_MODE must be one of: local, docker, azure") + + # Validate selected profile minimally; the rest is validated when building kwargs. + try: + prof = settings.postgres_profile(mode) + except Exception as e: + raise RuntimeError(str(e)) + + required = [prof.get("host"), prof.get("database"), prof.get("user")] if not all(required): - raise RuntimeError("POSTGRES_HOST, POSTGRES_DATABASE, and POSTGRES_USER are required") - if not settings.AZURE_DB and not settings.POSTGRES_PASSWORD: - raise RuntimeError("POSTGRES_PASSWORD is required when AZURE_DB is False") + raise RuntimeError(f"{mode} profile requires host, database and user") + + if mode in {"docker", "local"} and not prof.get("password"): + raise RuntimeError(f"{mode} mode requires POSTGRES_PASSWORD") def _is_token_expired(self) -> bool: """Check whether the cached Azure token needs refreshing.""" @@ -70,36 +99,76 @@ def _refresh_azure_token(self) -> str: """Return a valid Azure token, fetching a new one only if expired.""" if self._is_token_expired(): logger.info("Fetching fresh Azure PostgreSQL token") + if not self._credential: + raise RuntimeError( + "Azure credential is not configured. Ensure POSTGRES_MODE=azure and that " + "DefaultAzureCredential can authenticate in this environment." 
+ ) token = self._credential.get_token(self._AZURE_POSTGRES_SCOPE) self._token = token.token self._token_expiration = token.expires_on - self._TOKEN_REFRESH_BUFFER_SECONDS return self._token - def _build_connect_kwargs(self) -> dict: - """Assemble psycopg2.connect() keyword arguments from settings.""" - kwargs = { - "host": settings.POSTGRES_HOST, - "database": settings.POSTGRES_DATABASE, - "user": settings.POSTGRES_USER, - "port": settings.POSTGRES_PORT, + @staticmethod + def _mode() -> str: + return (settings.POSTGRES_MODE or "docker").lower().strip() + + @staticmethod + def _has_local_fallback() -> bool: + return False + + def _candidate_kwargs(self, mode: str) -> Dict[str, Any]: + """Build connect kwargs for a given mode based on POSTGRES_* env vars.""" + prof = settings.postgres_profile(mode) + + kwargs: Dict[str, Any] = { + "host": prof.get("host"), + "database": prof.get("database"), + "user": prof.get("user"), + "port": int(prof.get("port") or 5432), + # Fail fast so connection errors surface quickly + "connect_timeout": int(os.getenv("POSTGRES_CONNECT_TIMEOUT", "3")), } - if settings.POSTGRES_SSL_MODE: - kwargs["sslmode"] = settings.POSTGRES_SSL_MODE - if settings.AZURE_DB: + + sslmode = prof.get("sslmode") + if sslmode: + kwargs["sslmode"] = sslmode + + if prof.get("mode") == "azure": kwargs["password"] = self._refresh_azure_token() - elif settings.POSTGRES_PASSWORD: - kwargs["password"] = settings.POSTGRES_PASSWORD + else: + if not prof.get("password"): + raise RuntimeError(f"{mode} profile requires password") + kwargs["password"] = prof.get("password") + + # Sanity checks + required = [kwargs.get("host"), kwargs.get("database"), kwargs.get("user"), kwargs.get("port")] + if not all(required): + raise RuntimeError(f"Incomplete Postgres config for mode={mode}") + return kwargs + def _connect_with_mode(self, mode: str): + kwargs = self._candidate_kwargs(mode) + safe_kwargs = {k: ("***" if k == "password" else v) for k, v in kwargs.items()} + logger.info("Connecting to Postgres (mode=%s) %s", mode, safe_kwargs) + return psycopg2.connect(**kwargs) + def _connect(self): """Create a new psycopg2 connection.""" - return psycopg2.connect(**self._build_connect_kwargs()) + primary_mode = self._mode() + try: + return self._connect_with_mode(primary_mode) + except Exception as e: + logger.error("Postgres connect failed (mode=%s): %s", primary_mode, e, exc_info=True) + raise psycopg2.OperationalError( + f"Could not connect to Postgres for mode={primary_mode}" + ) def __repr__(self) -> str: status = "open" if self._conn and not self._conn.closed else "closed" return ( - f"" + f"" ) postgres_server = PostgresServer() \ No newline at end of file diff --git a/backend/api/services/sr_db_service.py b/backend/api/services/sr_db_service.py index 1dfaedce..4b4ccf14 100644 --- a/backend/api/services/sr_db_service.py +++ b/backend/api/services/sr_db_service.py @@ -27,7 +27,7 @@ def __init__(self): # Service is stateless; connection strings passed per-call pass - def ensure_table_exists(self, db_conn_str: str) -> None: + def ensure_table_exists(self) -> None: """ Ensure the systematic_reviews table exists in PostgreSQL. Creates the table if it doesn't exist. 
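Since postgres_auth.py now owns connection building, callers such as sr_db_service never see POSTGRES_MODE directly; they just use the shared `postgres_server`, whose `conn` property reconnects lazily and refreshes the Entra token in azure mode as described above. A minimal usage sketch, with the absolute import path and the query chosen for illustration (adjust the path to the actual package layout):

    from backend.api.services.postgres_auth import postgres_server  # import path assumed

    def count_rows(table_name: str = "systematic_reviews") -> int:
        # postgres_server.conn handles reconnects and azure-mode token refresh transparently.
        with postgres_server.conn.cursor() as cur:
            cur.execute(f'SELECT COUNT(*) FROM "{table_name}"')
            return int(cur.fetchone()[0])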
@@ -57,8 +57,6 @@ def ensure_table_exists(self, db_conn_str: str) -> None: """ cur.execute(create_table_sql) conn.commit() - - logger.info("Ensured systematic_reviews table exists") except Exception as e: @@ -157,7 +155,6 @@ def build_criteria_parsed(self, criteria_obj: Optional[Dict[str, Any]]) -> Dict[ def create_systematic_review( self, - db_conn_str: str, name: str, description: Optional[str], criteria_str: Optional[str], @@ -236,7 +233,7 @@ def create_systematic_review( if conn: pass - def add_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]: + def add_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]: """ Add a user id to the SR's users list. Enforces that the SR exists and is visible; requester must be a member or owner. @@ -244,12 +241,12 @@ def add_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_ """ - sr = self.get_systematic_review(db_conn_str, sr_id) + sr = self.get_systematic_review(sr_id) if not sr or not sr.get("visible", True): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") # Check permission - has_perm = self.user_has_sr_permission(db_conn_str, sr_id, requester_id) + has_perm = self.user_has_sr_permission(sr_id, requester_id) if not has_perm: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review") @@ -294,19 +291,19 @@ def add_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_ if conn: pass - def remove_user(self, db_conn_str: str, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]: + def remove_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]: """ Remove a user id from the SR's users list. Owner cannot be removed. Enforces requester permissions (must be a member or owner). """ - sr = self.get_systematic_review(db_conn_str, sr_id) + sr = self.get_systematic_review(sr_id) if not sr or not sr.get("visible", True): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") # Check permission - has_perm = self.user_has_sr_permission(db_conn_str, sr_id, requester_id) + has_perm = self.user_has_sr_permission(sr_id, requester_id) if not has_perm: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review") @@ -354,7 +351,7 @@ def remove_user(self, db_conn_str: str, sr_id: str, target_user_id: str, request if conn: pass - def user_has_sr_permission(self, db_conn_str: str, sr_id: str, user_id: str) -> bool: + def user_has_sr_permission(self, sr_id: str, user_id: str) -> bool: """ Check whether the given user_id is a member (in 'users') or the owner of the SR. Returns True if the SR exists and the user is present in the SR's users list or is the owner. 
@@ -363,7 +360,7 @@ def user_has_sr_permission(self, db_conn_str: str, sr_id: str, user_id: str) -> """ - doc = self.get_systematic_review(db_conn_str, sr_id, ignore_visibility=True) + doc = self.get_systematic_review(sr_id, ignore_visibility=True) if not doc: return False @@ -372,7 +369,7 @@ def user_has_sr_permission(self, db_conn_str: str, sr_id: str, user_id: str) -> return True return False - def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str, Any], criteria_str: str, requester_id: str) -> Dict[str, Any]: + def update_criteria(self, sr_id: str, criteria_obj: Dict[str, Any], criteria_str: str, requester_id: str) -> Dict[str, Any]: """ Update the criteria fields (criteria, criteria_yaml, criteria_parsed, updated_at). The requester must be a member or owner. @@ -380,12 +377,12 @@ def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str, """ - sr = self.get_systematic_review(db_conn_str, sr_id) + sr = self.get_systematic_review(sr_id) if not sr or not sr.get("visible", True): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") # Check permission - has_perm = self.user_has_sr_permission(db_conn_str, sr_id, requester_id) + has_perm = self.user_has_sr_permission(sr_id, requester_id) if not has_perm: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review") @@ -419,7 +416,7 @@ def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str, # Return fresh doc - doc = self.get_systematic_review(db_conn_str, sr_id) + doc = self.get_systematic_review(sr_id) return doc except HTTPException: @@ -431,7 +428,7 @@ def update_criteria(self, db_conn_str: str, sr_id: str, criteria_obj: Dict[str, if conn: pass - def list_systematic_reviews_for_user(self, db_conn_str: str, user_email: str) -> List[Dict[str, Any]]: + def list_systematic_reviews_for_user(self, user_email: str) -> List[Dict[str, Any]]: """ Return all SR documents where the user is a member (regardless of visible flag). """ @@ -484,7 +481,7 @@ def list_systematic_reviews_for_user(self, db_conn_str: str, user_email: str) -> if conn: pass - def get_systematic_review(self, db_conn_str: str, sr_id: str, ignore_visibility: bool = False) -> Optional[Dict[str, Any]]: + def get_systematic_review(self, sr_id: str, ignore_visibility: bool = False) -> Optional[Dict[str, Any]]: """ Return SR document by id. Returns None if not found. If ignore_visibility is False, only returns visible SRs. @@ -535,14 +532,14 @@ def get_systematic_review(self, db_conn_str: str, sr_id: str, ignore_visibility: if conn: pass - def set_visibility(self, db_conn_str: str, sr_id: str, visible: bool, requester_id: str) -> Dict[str, Any]: + def set_visibility(self, sr_id: str, visible: bool, requester_id: str) -> Dict[str, Any]: """ Set the visible flag on the SR. Only owner is allowed to change visibility. Returns update metadata. 
""" - sr = self.get_systematic_review(db_conn_str, sr_id, ignore_visibility=True) + sr = self.get_systematic_review(sr_id, ignore_visibility=True) if not sr: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") @@ -573,26 +570,26 @@ def set_visibility(self, db_conn_str: str, sr_id: str, visible: bool, requester_ if conn: pass - def soft_delete_systematic_review(self, db_conn_str: str, sr_id: str, requester_id: str) -> Dict[str, Any]: + def soft_delete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]: """ Soft-delete (set visible=False). Only owner may delete. """ - return self.set_visibility(db_conn_str, sr_id, False, requester_id) + return self.set_visibility(sr_id, False, requester_id) - def undelete_systematic_review(self, db_conn_str: str, sr_id: str, requester_id: str) -> Dict[str, Any]: + def undelete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]: """ Undelete (set visible=True). Only owner may undelete. """ - return self.set_visibility(db_conn_str, sr_id, True, requester_id) + return self.set_visibility(sr_id, True, requester_id) - def hard_delete_systematic_review(self, db_conn_str: str, sr_id: str, requester_id: str) -> Dict[str, Any]: + def hard_delete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]: """ Permanently remove the SR document. Only owner may hard delete. Returns deletion metadata (deleted_count). """ - sr = self.get_systematic_review(db_conn_str, sr_id, ignore_visibility=True) + sr = self.get_systematic_review(sr_id, ignore_visibility=True) if not sr: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") @@ -618,7 +615,7 @@ def hard_delete_systematic_review(self, db_conn_str: str, sr_id: str, requester_ pass - def update_screening_db_info(self, db_conn_str: str, sr_id: str, screening_db: Dict[str, Any]) -> None: + def update_screening_db_info(self, sr_id: str, screening_db: Dict[str, Any]) -> None: """ Update the screening_db field in the SR document with screening database metadata. """ @@ -645,7 +642,7 @@ def update_screening_db_info(self, db_conn_str: str, sr_id: str, screening_db: D if conn: pass - def clear_screening_db_info(self, db_conn_str: str, sr_id: str) -> None: + def clear_screening_db_info(self, sr_id: str) -> None: """ Remove the screening_db field from the SR document. """ diff --git a/backend/api/services/storage.py b/backend/api/services/storage.py index b7595416..b5d4f2b5 100644 --- a/backend/api/services/storage.py +++ b/backend/api/services/storage.py @@ -1,53 +1,96 @@ -"""Azure Blob Storage service for user data management""" +"""backend.api.services.storage + +Storage abstraction for CAN-SR. + +Supported backends (selected via STORAGE_MODE): +* local - Local filesystem storage (backed by docker compose volume) +* azure - Azure Blob Storage via **account name + key** (strict) +* entra - Azure Blob Storage via **DefaultAzureCredential** (Entra/Managed Identity) (strict) + +Routers should not access Azure SDK objects directly. 
+""" + +from __future__ import annotations import json import logging +import os import uuid from datetime import datetime, timezone -from typing import Dict, List, Optional, Any - -from azure.identity import DefaultAzureCredential -from azure.storage.blob import BlobServiceClient -from azure.core.exceptions import ResourceNotFoundError +from pathlib import Path +from typing import Any, Dict, List, Optional, Protocol, Tuple + +try: + from azure.core.exceptions import ResourceNotFoundError + from azure.identity import DefaultAzureCredential + from azure.storage.blob import BlobServiceClient +except Exception: # pragma: no cover + # Allow local-storage deployments/environments to import without azure packages. + ResourceNotFoundError = Exception # type: ignore + DefaultAzureCredential = None # type: ignore + BlobServiceClient = None # type: ignore from ..core.config import settings -from ..utils.file_hash import calculate_file_hash, create_file_metadata +from ..utils.file_hash import create_file_metadata logger = logging.getLogger(__name__) +class StorageService(Protocol): + """Common API that both Azure and local storage must implement.""" + + container_name: str + + async def create_user_directory(self, user_id: str) -> bool: ... + async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool: ... + async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]: ... + async def upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]: ... + async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]: ... + async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: ... + async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool: ... + async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool: ... + async def get_bytes_by_path(self, path: str) -> Tuple[bytes, str]: ... + async def delete_by_path(self, path: str) -> bool: ... + + +# ============================================================================= +# Azure Blob Storage +# ============================================================================= + + class AzureStorageService: - """Service for managing user data in Azure Blob Storage""" + """Service for managing user data in Azure Blob Storage.""" - def __init__(self): - if not settings.AZURE_STORAGE_ACCOUNT_NAME and not settings.AZURE_STORAGE_CONNECTION_STRING: - raise ValueError("AZURE_STORAGE_ACCOUNT_NAME or AZURE_STORAGE_CONNECTION_STRING must be configured") + def __init__(self, *, account_url: str | None = None, connection_string: str | None = None, container_name: str): + if not BlobServiceClient: + raise RuntimeError( + "Azure storage libraries are not installed. Install azure-identity and azure-storage-blob, or use STORAGE_MODE=local." + ) - if settings.AZURE_STORAGE_ACCOUNT_NAME: - account_url = f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net" + if bool(account_url) == bool(connection_string): + raise ValueError("Exactly one of account_url or connection_string must be provided") + + if connection_string: + self.blob_service_client = BlobServiceClient.from_connection_string(connection_string) + else: + if not DefaultAzureCredential: + raise RuntimeError( + "azure-identity is not installed. Install azure-identity, or use STORAGE_MODE=azure (connection string) or local." 
+ ) credential = DefaultAzureCredential() - self.blob_service_client = BlobServiceClient( - account_url=account_url, credential=credential - ) - elif settings.AZURE_STORAGE_CONNECTION_STRING: - self.blob_service_client = BlobServiceClient.from_connection_string( - settings.AZURE_STORAGE_CONNECTION_STRING - ) - self.container_name = settings.AZURE_STORAGE_CONTAINER_NAME + self.blob_service_client = BlobServiceClient(account_url=account_url, credential=credential) + self.container_name = container_name self._ensure_container_exists() def _ensure_container_exists(self): - """Ensure the storage container exists""" try: self.blob_service_client.create_container(self.container_name) except Exception: pass async def create_user_directory(self, user_id: str) -> bool: - """Create directory structure for a new user""" try: - # Create user profile profile_data = { "user_id": user_id, "created_at": datetime.now(timezone.utc).isoformat(), @@ -59,70 +102,43 @@ async def create_user_directory(self, user_id: str) -> bool: await self.save_user_profile(user_id, profile_data) # Create placeholder file to establish directory structure - directories = [f"users/{user_id}/documents/"] - - for directory in directories: - blob_name = f"{directory}.placeholder" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - blob_client.upload_blob("", overwrite=True) - + blob_name = f"users/{user_id}/documents/.placeholder" + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client.upload_blob(b"", overwrite=True) return True except Exception as e: - logger.error(f"Error creating user directory for {user_id}: {e}") + logger.error("Error creating user directory for %s: %s", user_id, e) return False - async def save_user_profile( - self, user_id: str, profile_data: Dict[str, Any] - ) -> bool: - """Save user profile data""" + async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool: try: blob_name = f"users/{user_id}/profile.json" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - - profile_json = json.dumps(profile_data, indent=2) - blob_client.upload_blob(profile_json, overwrite=True) + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client.upload_blob(json.dumps(profile_data, indent=2), overwrite=True) return True except Exception as e: - logger.error(f"Error saving user profile for {user_id}: {e}") + logger.error("Error saving user profile for %s: %s", user_id, e) return False async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]: - """Get user profile data""" try: blob_name = f"users/{user_id}/profile.json" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) blob_data = blob_client.download_blob().readall() return json.loads(blob_data.decode("utf-8")) except ResourceNotFoundError: return None except Exception as e: - logger.error(f"Error getting user profile for {user_id}: {e}") + logger.error("Error getting user profile for %s: %s", user_id, e) return None - async def upload_user_document( - self, user_id: str, filename: str, file_content: bytes - ) -> Optional[str]: - """Upload a document for a user with hash metadata for duplicate detection""" + async def 
upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]: try: - # Generate unique document ID doc_id = str(uuid.uuid4()) blob_name = f"users/{user_id}/documents/{doc_id}_{filename}" - - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - - # Upload the file + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) blob_client.upload_blob(file_content, overwrite=True) - # Create and save file metadata with hash file_metadata = create_file_metadata( filename, file_content, @@ -132,235 +148,447 @@ async def upload_user_document( "upload_date": datetime.now(timezone.utc).isoformat(), }, ) - await self.save_file_hash_metadata(user_id, doc_id, file_metadata) - # Update user profile profile = await self.get_user_profile(user_id) if profile: - profile["document_count"] += 1 - profile["storage_used"] += len(file_content) + profile["document_count"] = int(profile.get("document_count", 0)) + 1 + profile["storage_used"] = int(profile.get("storage_used", 0)) + len(file_content) profile["last_updated"] = datetime.now(timezone.utc).isoformat() await self.save_user_profile(user_id, profile) return doc_id except Exception as e: - logger.error(f"Error uploading document {filename} for user {user_id}: {e}") + logger.error("Error uploading document %s for user %s: %s", filename, user_id, e) return None - async def get_user_document( - self, user_id: str, doc_id: str, filename: str - ) -> Optional[bytes]: - """Get a user's document""" + async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]: try: blob_name = f"users/{user_id}/documents/{doc_id}_{filename}" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) return blob_client.download_blob().readall() except ResourceNotFoundError: return None except Exception as e: - logger.error(f"Error getting document {doc_id} for user {user_id}: {e}") + logger.error("Error getting document %s for user %s: %s", doc_id, user_id, e) return None async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: - """List all documents for a user - SIMPLIFIED""" try: prefix = f"users/{user_id}/documents/" - blobs = self.blob_service_client.get_container_client( - self.container_name - ).list_blobs(name_starts_with=prefix) + blobs = self.blob_service_client.get_container_client(self.container_name).list_blobs(name_starts_with=prefix) - documents = [] + documents: List[Dict[str, Any]] = [] for blob in blobs: - if not blob.name.endswith(".placeholder"): - # Extract doc_id and filename from blob name - blob_filename = blob.name.replace(prefix, "") - if "_" in blob_filename: - doc_id, filename = blob_filename.split("_", 1) - - # Get hash metadata if available - hash_metadata = await self.get_file_hash_metadata( - user_id, doc_id - ) - - document_info = { - "document_id": doc_id, - "filename": filename, - "file_size": blob.size, - "upload_date": blob.last_modified.isoformat(), - "last_modified": blob.last_modified.isoformat(), - } - - # Add hash information if available - if hash_metadata: - document_info["file_hash"] = hash_metadata.get("file_hash") - document_info["signature"] = hash_metadata.get("signature") - - documents.append(document_info) + if blob.name.endswith(".placeholder"): + continue + blob_filename = 
blob.name.replace(prefix, "") + if "_" not in blob_filename: + continue + doc_id, filename = blob_filename.split("_", 1) + + hash_metadata = await self.get_file_hash_metadata(user_id, doc_id) + document_info: Dict[str, Any] = { + "document_id": doc_id, + "filename": filename, + "file_size": blob.size, + "upload_date": blob.last_modified.isoformat(), + "last_modified": blob.last_modified.isoformat(), + } + if hash_metadata: + document_info["file_hash"] = hash_metadata.get("file_hash") + document_info["signature"] = hash_metadata.get("signature") + documents.append(document_info) return documents except Exception as e: - logger.error(f"Error listing user documents for {user_id}: {e}") + logger.error("Error listing user documents for %s: %s", user_id, e) return [] - async def delete_user_document( - self, user_id: str, doc_id: str, filename: str - ) -> bool: - """Delete a user's document""" + async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool: try: - # Get document size before deletion for profile update doc_blob_name = f"users/{user_id}/documents/{doc_id}_{filename}" - doc_blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=doc_blob_name - ) + doc_blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=doc_blob_name) - # Get blob properties to determine size try: - blob_properties = doc_blob_client.get_blob_properties() - doc_size = blob_properties.size + doc_size = doc_blob_client.get_blob_properties().size except ResourceNotFoundError: doc_size = 0 - # Delete the document doc_blob_client.delete_blob() - - # Delete the associated hash metadata await self.delete_file_hash_metadata(user_id, doc_id) - # Update user profile profile = await self.get_user_profile(user_id) if profile: - profile["document_count"] = max(0, profile["document_count"] - 1) - profile["storage_used"] = max(0, profile["storage_used"] - doc_size) + profile["document_count"] = max(0, int(profile.get("document_count", 0)) - 1) + profile["storage_used"] = max(0, int(profile.get("storage_used", 0)) - int(doc_size)) profile["last_updated"] = datetime.now(timezone.utc).isoformat() await self.save_user_profile(user_id, profile) return True except Exception as e: - logger.error(f"Error deleting document {doc_id} for user {user_id}: {e}") + logger.error("Error deleting document %s for user %s: %s", doc_id, user_id, e) return False - async def calculate_user_storage_usage(self, user_id: str) -> int: - """Calculate actual storage usage for a user""" + async def save_file_hash_metadata(self, user_id: str, document_id: str, file_metadata: Dict[str, Any]) -> bool: try: - prefix = f"users/{user_id}/documents/" - blobs = self.blob_service_client.get_container_client( - self.container_name - ).list_blobs(name_starts_with=prefix) + blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client.upload_blob(json.dumps(file_metadata, indent=2), overwrite=True) + return True + except Exception as e: + logger.error("Error saving file metadata for %s: %s", document_id, e) + return False - total_size = 0 - for blob in blobs: - if not blob.name.endswith(".placeholder"): - total_size += blob.size or 0 + async def get_file_hash_metadata(self, user_id: str, document_id: str) -> Optional[Dict[str, Any]]: + try: + blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" + blob_client = 
self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + metadata_json = blob_client.download_blob().readall().decode("utf-8") + return json.loads(metadata_json) + except ResourceNotFoundError: + return None + except Exception as e: + logger.error("Error getting file metadata for %s: %s", document_id, e) + return None - return total_size + async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> bool: + try: + blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" + blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client.delete_blob() + return True + except ResourceNotFoundError: + return True except Exception as e: - logger.error(f"Error calculating storage usage for user {user_id}: {e}") - return 0 + logger.error("Error deleting file metadata for %s: %s", document_id, e) + return False + + async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool: + """Write blob by storage path 'container/blob'.""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob) + blob_client.upload_blob(content, overwrite=True, content_type=content_type) + return True + + async def get_bytes_by_path(self, path: str) -> Tuple[bytes, str]: + """Read blob by storage path 'container/blob'. Returns (bytes, filename).""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + + blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob) + content = blob_client.download_blob().readall() + filename = os.path.basename(blob) or "download" + return content, filename + + async def delete_by_path(self, path: str) -> bool: + """Delete blob by storage path 'container/blob'.""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob) + blob_client.delete_blob() + return True - async def sync_user_profile_stats(self, user_id: str) -> bool: - """Synchronize user profile statistics with actual storage""" + +# ============================================================================= +# Local filesystem storage +# ============================================================================= + + +class LocalStorageService: + """Local filesystem storage implementation. + + Layout: + {LOCAL_STORAGE_BASE_PATH}/{STORAGE_CONTAINER_NAME}/users/{user_id}/... 
+ """ + + def __init__(self): + self.base_path = Path(settings.LOCAL_STORAGE_BASE_PATH).resolve() + self.container_name = settings.STORAGE_CONTAINER_NAME + (self.base_path / self.container_name).mkdir(parents=True, exist_ok=True) + + def _container_root(self) -> Path: + return self.base_path / self.container_name + + def _user_root(self, user_id: str) -> Path: + return self._container_root() / "users" / str(user_id) + + def _profile_path(self, user_id: str) -> Path: + return self._user_root(user_id) / "profile.json" + + def _doc_path(self, user_id: str, doc_id: str, filename: str) -> Path: + return self._user_root(user_id) / "documents" / f"{doc_id}_{filename}" + + def _metadata_path(self, user_id: str, doc_id: str) -> Path: + return self._user_root(user_id) / "metadata" / f"{doc_id}_metadata.json" + + async def create_user_directory(self, user_id: str) -> bool: try: - documents = await self.list_user_documents(user_id) - actual_storage = await self.calculate_user_storage_usage(user_id) + (self._user_root(user_id) / "documents").mkdir(parents=True, exist_ok=True) + (self._user_root(user_id) / "metadata").mkdir(parents=True, exist_ok=True) - profile = await self.get_user_profile(user_id) - if profile: - profile["document_count"] = len(documents) - profile["storage_used"] = actual_storage - profile["last_updated"] = datetime.now(timezone.utc).isoformat() - await self.save_user_profile(user_id, profile) - return True - return False + if not self._profile_path(user_id).exists(): + profile_data = { + "user_id": user_id, + "created_at": datetime.now(timezone.utc).isoformat(), + "last_updated": datetime.now(timezone.utc).isoformat(), + "document_count": 0, + "storage_used": 0, + } + await self.save_user_profile(user_id, profile_data) + return True except Exception as e: - logger.error(f"Error syncing profile stats for user {user_id}: {e}") + logger.error("Error creating user directory for %s: %s", user_id, e) return False - async def save_file_hash_metadata( - self, user_id: str, document_id: str, file_metadata: Dict[str, Any] - ) -> bool: - """Save file hash metadata for duplicate detection""" + async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool: try: - blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - - metadata_json = json.dumps(file_metadata, indent=2) - blob_client.upload_blob(metadata_json, overwrite=True) + self._profile_path(user_id).parent.mkdir(parents=True, exist_ok=True) + self._profile_path(user_id).write_text(json.dumps(profile_data, indent=2), encoding="utf-8") return True except Exception as e: - logger.error(f"Error saving file metadata for {document_id}: {e}") + logger.error("Error saving user profile for %s: %s", user_id, e) return False - async def get_file_hash_metadata( - self, user_id: str, document_id: str - ) -> Optional[Dict[str, Any]]: - """Get file hash metadata for a specific document""" + async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]: + try: + p = self._profile_path(user_id) + if not p.exists(): + return None + return json.loads(p.read_text(encoding="utf-8")) + except Exception as e: + logger.error("Error getting user profile for %s: %s", user_id, e) + return None + + async def upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]: try: - blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" - blob_client = 
self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name + await self.create_user_directory(user_id) + doc_id = str(uuid.uuid4()) + doc_path = self._doc_path(user_id, doc_id, filename) + doc_path.parent.mkdir(parents=True, exist_ok=True) + doc_path.write_bytes(file_content) + + file_metadata = create_file_metadata( + filename, + file_content, + { + "document_id": doc_id, + "user_id": user_id, + "upload_date": datetime.now(timezone.utc).isoformat(), + }, ) + await self.save_file_hash_metadata(user_id, doc_id, file_metadata) - metadata_json = blob_client.download_blob().readall().decode("utf-8") - return json.loads(metadata_json) - except ResourceNotFoundError: - return None + profile = await self.get_user_profile(user_id) + if profile: + profile["document_count"] = int(profile.get("document_count", 0)) + 1 + profile["storage_used"] = int(profile.get("storage_used", 0)) + len(file_content) + profile["last_updated"] = datetime.now(timezone.utc).isoformat() + await self.save_user_profile(user_id, profile) + + return doc_id except Exception as e: - logger.error(f"Error getting file metadata for {document_id}: {e}") + logger.error("Error uploading document %s for user %s: %s", filename, user_id, e) return None - async def get_all_user_file_hashes(self, user_id: str) -> List[Dict[str, Any]]: - """Get all file hash metadata for a user for duplicate detection""" + async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]: try: - prefix = f"users/{user_id}/metadata/" - blobs = self.blob_service_client.get_container_client( - self.container_name - ).list_blobs(name_starts_with=prefix) + p = self._doc_path(user_id, doc_id, filename) + if not p.exists(): + return None + return p.read_bytes() + except Exception as e: + logger.error("Error getting document %s for user %s: %s", doc_id, user_id, e) + return None - all_metadata = [] - for blob in blobs: - if blob.name.endswith("_metadata.json"): - try: - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob.name - ) - metadata_json = ( - blob_client.download_blob().readall().decode("utf-8") - ) - metadata = json.loads(metadata_json) - all_metadata.append(metadata) - except Exception as e: - logger.warning(f"Error reading metadata from {blob.name}: {e}") - continue - - return all_metadata + async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: + try: + docs_dir = self._user_root(user_id) / "documents" + if not docs_dir.exists(): + return [] + + documents: List[Dict[str, Any]] = [] + for p in docs_dir.iterdir(): + if not p.is_file(): + continue + name = p.name + if "_" not in name: + continue + doc_id, filename = name.split("_", 1) + stat = p.stat() + + hash_metadata = await self.get_file_hash_metadata(user_id, doc_id) + doc_info: Dict[str, Any] = { + "document_id": doc_id, + "filename": filename, + "file_size": stat.st_size, + "upload_date": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(), + "last_modified": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(), + } + if hash_metadata: + doc_info["file_hash"] = hash_metadata.get("file_hash") + doc_info["signature"] = hash_metadata.get("signature") + documents.append(doc_info) + + # stable ordering + documents.sort(key=lambda d: d.get("upload_date", ""), reverse=True) + return documents except Exception as e: - logger.error(f"Error getting all file hashes for user {user_id}: {e}") + logger.error("Error listing user documents for %s: %s", 
user_id, e) return [] - async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> bool: - """Delete file hash metadata when document is deleted""" + async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool: try: - blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - blob_client.delete_blob() + p = self._doc_path(user_id, doc_id, filename) + doc_size = p.stat().st_size if p.exists() else 0 + if p.exists(): + p.unlink() + await self.delete_file_hash_metadata(user_id, doc_id) + + profile = await self.get_user_profile(user_id) + if profile: + profile["document_count"] = max(0, int(profile.get("document_count", 0)) - 1) + profile["storage_used"] = max(0, int(profile.get("storage_used", 0)) - int(doc_size)) + profile["last_updated"] = datetime.now(timezone.utc).isoformat() + await self.save_user_profile(user_id, profile) return True - except ResourceNotFoundError: - # Metadata doesn't exist, which is fine + except Exception as e: + logger.error("Error deleting document %s for user %s: %s", doc_id, user_id, e) + return False + + async def save_file_hash_metadata(self, user_id: str, document_id: str, file_metadata: Dict[str, Any]) -> bool: + try: + p = self._metadata_path(user_id, document_id) + p.parent.mkdir(parents=True, exist_ok=True) + p.write_text(json.dumps(file_metadata, indent=2), encoding="utf-8") return True except Exception as e: - logger.error(f"Error deleting file metadata for {document_id}: {e}") + logger.error("Error saving file metadata for %s: %s", document_id, e) return False + async def get_file_hash_metadata(self, user_id: str, document_id: str) -> Optional[Dict[str, Any]]: + try: + p = self._metadata_path(user_id, document_id) + if not p.exists(): + return None + return json.loads(p.read_text(encoding="utf-8")) + except Exception as e: + logger.error("Error getting file metadata for %s: %s", document_id, e) + return None + + async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> bool: + try: + p = self._metadata_path(user_id, document_id) + if p.exists(): + p.unlink() + return True + except Exception as e: + logger.error("Error deleting file metadata for %s: %s", document_id, e) + return False + + async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool: + """Write file by storage path 'container/blob'.""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + + if container != self.container_name: + raise FileNotFoundError("Container not found") + + p = (self.base_path / container / blob).resolve() + # Prevent path traversal + if not str(p).startswith(str((self.base_path / container).resolve())): + raise FileNotFoundError("Invalid path") + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(content) + return True + + async def get_bytes_by_path(self, path: str) -> Tuple[bytes, str]: + """Read file by storage path 'container/blob'. Returns (bytes, filename).""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + + # Only allow access to our configured local container. 
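+        # The resolve()/startswith() check further below rejects blobs such as
+        # "../../etc/passwd" that would otherwise escape the container directory.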
+ if container != self.container_name: + raise FileNotFoundError("Container not found") + + p = (self.base_path / container / blob).resolve() + # Prevent path traversal + if not str(p).startswith(str((self.base_path / container).resolve())): + raise FileNotFoundError("Invalid path") + if not p.exists() or not p.is_file(): + raise FileNotFoundError("File not found") + + return p.read_bytes(), (p.name or "download") + + async def delete_by_path(self, path: str) -> bool: + """Delete file by storage path 'container/blob'.""" + if not path or "/" not in path: + raise ValueError("Invalid storage path") + container, blob = path.split("/", 1) + + if container != self.container_name: + raise FileNotFoundError("Container not found") + + p = (self.base_path / container / blob).resolve() + if not str(p).startswith(str((self.base_path / container).resolve())): + raise FileNotFoundError("Invalid path") + if not p.exists() or not p.is_file(): + raise FileNotFoundError("File not found") + p.unlink() + return True + + +# ============================================================================= +# Factory +# ============================================================================= + + +def _build_storage_service() -> Optional[StorageService]: + stype = (settings.STORAGE_MODE or "azure").lower().strip() + if stype == "local": + try: + return LocalStorageService() + except Exception as e: + logger.exception("Failed to initialize LocalStorageService: %s", e) + return None + if stype == "azure": + try: + if not settings.AZURE_STORAGE_ACCOUNT_NAME or not settings.AZURE_STORAGE_ACCOUNT_KEY: + raise ValueError("STORAGE_MODE=azure requires AZURE_STORAGE_ACCOUNT_NAME and AZURE_STORAGE_ACCOUNT_KEY") + connection_string = ( + "DefaultEndpointsProtocol=https;" + f"AccountName={settings.AZURE_STORAGE_ACCOUNT_NAME};" + f"AccountKey={settings.AZURE_STORAGE_ACCOUNT_KEY};" + "EndpointSuffix=core.windows.net" + ) + return AzureStorageService( + connection_string=connection_string, + container_name=settings.STORAGE_CONTAINER_NAME, + ) + except Exception as e: + logger.exception("Failed to initialize AzureStorageService (connection string): %s", e) + return None + if stype == "entra": + try: + if not settings.AZURE_STORAGE_ACCOUNT_NAME: + raise ValueError("STORAGE_MODE=entra requires AZURE_STORAGE_ACCOUNT_NAME") + account_url = f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net" + return AzureStorageService( + account_url=account_url, + container_name=settings.STORAGE_CONTAINER_NAME, + ) + except Exception as e: + logger.exception("Failed to initialize AzureStorageService (Entra): %s", e) + return None + + logger.warning("Unsupported STORAGE_MODE=%s; storage disabled", stype) + return None + -# Global storage service instance -storage_service = ( - AzureStorageService() if settings.AZURE_STORAGE_ACCOUNT_NAME else None -) +storage_service: Optional[StorageService] = _build_storage_service() diff --git a/backend/api/services/user_db.py b/backend/api/services/user_db.py index 00300523..0624b084 100644 --- a/backend/api/services/user_db.py +++ b/backend/api/services/user_db.py @@ -1,102 +1,81 @@ -"""User database service using Azure Blob Storage""" +"""backend.api.services.user_db + +User database service. + +Historically this project stored users inside Azure Blob Storage directly. +To support multiple storage backends (local / azure / entra) we now build the +user DB on top of the common `storage_service` abstraction. 
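+
+The whole registry is persisted as a single JSON document. An illustrative
+shape (field names taken from the code below; values are placeholders):
+
+    {"users": {"<user_id>": {"id": "...", "email": "...", "full_name": "...",
+                             "hashed_password": "...", "is_active": true,
+                             "is_superuser": false, "created_at": "...",
+                             "last_login": null}},
+     "email_index": {"<email>": "<user_id>"}}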
+ +Storage keys used: + system/user_registry.json + +This file intentionally avoids importing Azure SDK packages so that local +deployments can run without them. +""" + +from __future__ import annotations import json import uuid from datetime import datetime -from typing import Dict, List, Optional, Any +from typing import Any, Dict, List, Optional -from azure.identity import DefaultAzureCredential -from azure.storage.blob import BlobServiceClient -from azure.core.exceptions import ResourceNotFoundError from passlib.context import CryptContext -from ..core.config import settings from ..models.auth import UserCreate, UserRead +from .storage import storage_service class UserDatabaseService: - """Service for managing user data in Azure Blob Storage""" + """Service for managing user data in the configured storage backend.""" def __init__(self): - if not settings.AZURE_STORAGE_ACCOUNT_NAME and not settings.AZURE_STORAGE_CONNECTION_STRING: - raise ValueError("AZURE_STORAGE_ACCOUNT_NAME or AZURE_STORAGE_CONNECTION_STRING must be configured") - - if settings.AZURE_STORAGE_ACCOUNT_NAME: - account_url = f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net" - credential = DefaultAzureCredential() - self.blob_service_client = BlobServiceClient( - account_url=account_url, credential=credential - ) - elif settings.AZURE_STORAGE_CONNECTION_STRING: - self.blob_service_client = BlobServiceClient.from_connection_string( - settings.AZURE_STORAGE_CONNECTION_STRING + if not storage_service: + raise RuntimeError( + "Storage is not configured. User database is unavailable." ) - - self.container_name = settings.AZURE_STORAGE_CONTAINER_NAME + self.storage = storage_service self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") - self._ensure_container_exists() - - def _ensure_container_exists(self): - """Ensure the storage container exists""" - try: - self.blob_service_client.create_container(self.container_name) - except Exception: - pass - def _get_password_hash(self, password: str) -> str: - """Get password hash""" return self.pwd_context.hash(password) def _verify_password(self, plain_password: str, hashed_password: str) -> bool: - """Verify a password against a hash""" return self.pwd_context.verify(plain_password, hashed_password) + def _registry_path(self) -> str: + # Keep legacy layout: /system/user_registry.json + return f"{self.storage.container_name}/system/user_registry.json" + async def _load_user_registry(self) -> Dict[str, Any]: - """Load the user registry from blob storage""" + """Load the user registry from storage.""" try: - blob_name = "system/user_registry.json" - blob_client = self.blob_service_client.get_blob_client( - container=self.container_name, blob=blob_name - ) - - blob_data = blob_client.download_blob().readall() - return json.loads(blob_data.decode("utf-8")) - except ResourceNotFoundError: - # Create empty registry if it doesn't exist - return {"users": {}, "email_index": {}} - except Exception as e: - print(f"Error loading user registry: {e}") + content, _filename = await self.storage.get_bytes_by_path(self._registry_path()) + return json.loads(content.decode("utf-8")) + except Exception: + # Create empty registry if it doesn't exist / cannot be read return {"users": {}, "email_index": {}} async def _save_user_registry(self, registry: Dict[str, Any]) -> bool: - """Save the user registry to blob storage""" + """Save the user registry to storage.""" try: - blob_name = "system/user_registry.json" - blob_client = self.blob_service_client.get_blob_client( 
- container=self.container_name, blob=blob_name + payload = json.dumps(registry, indent=2).encode("utf-8") + return await self.storage.put_bytes_by_path( + self._registry_path(), + payload, + content_type="application/json", ) - - registry_json = json.dumps(registry, indent=2) - blob_client.upload_blob(registry_json, overwrite=True) - return True - except Exception as e: - print(f"Error saving user registry: {e}") + except Exception: return False async def create_user(self, user_data: UserCreate) -> Optional[UserRead]: - """Create a new user""" try: registry = await self._load_user_registry() - # Check if user already exists if user_data.email in registry["email_index"]: - return None # User already exists + return None - # Generate unique user ID user_id = str(uuid.uuid4()) - - # Create user record user_record = { "id": user_id, "email": user_data.email, @@ -109,89 +88,59 @@ async def create_user(self, user_data: UserCreate) -> Optional[UserRead]: "last_login": None, } - # Add to registry registry["users"][user_id] = user_record registry["email_index"][user_data.email] = user_id - # Save registry - if await self._save_user_registry(registry): - # Create user directory structure in storage - from .storage import storage_service - # from .dual_milvus_manager import dual_milvus_manager - # from .base_knowledge_manager import base_knowledge_manager - - if storage_service: - await storage_service.create_user_directory(user_id) - - # User collections will be created on first login via auth router - - # Note: Base knowledge is initialized once at startup, not per user + if not await self._save_user_registry(registry): + return None - return UserRead( - id=user_id, - email=user_record["email"], - full_name=user_record["full_name"], - is_active=user_record["is_active"], - is_superuser=user_record["is_superuser"], - created_at=user_record["created_at"], - ) + await self.storage.create_user_directory(user_id) - return None - except Exception as e: - print(f"Error creating user: {e}") + return UserRead( + id=user_id, + email=user_record["email"], + full_name=user_record["full_name"], + is_active=user_record["is_active"], + is_superuser=user_record["is_superuser"], + created_at=user_record["created_at"], + ) + except Exception: return None async def get_user_by_email(self, email: str) -> Optional[Dict[str, Any]]: - """Get a user by email""" try: registry = await self._load_user_registry() - if email not in registry["email_index"]: return None - user_id = registry["email_index"][email] return registry["users"].get(user_id) - except Exception as e: - print(f"Error getting user by email: {e}") + except Exception: return None async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]: - """Get a user by ID""" try: registry = await self._load_user_registry() return registry["users"].get(user_id) - except Exception as e: - print(f"Error getting user by ID: {e}") + except Exception: return None - async def authenticate_user( - self, email: str, password: str, sso: bool - ) -> Optional[Dict[str, Any]]: - """Authenticate a user""" + async def authenticate_user(self, email: str, password: str, sso: bool) -> Optional[Dict[str, Any]]: try: user = await self.get_user_by_email(email) if not user: return None - if not sso and not self._verify_password(password, user["hashed_password"]): return None - # Update last login user["last_login"] = datetime.utcnow().isoformat() await self.update_user(user["id"], {"last_login": user["last_login"]}) - return user - except Exception as e: - print(f"Error 
authenticating user: {e}") + except Exception: return None - async def update_user( - self, user_id: str, update_data: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: - """Update user data""" + async def update_user(self, user_id: str, update_data: Dict[str, Any]) -> Optional[Dict[str, Any]]: try: registry = await self._load_user_registry() - if user_id not in registry["users"]: return None @@ -201,61 +150,35 @@ async def update_user( if await self._save_user_registry(registry): return user_record - return None - except Exception as e: - print(f"Error updating user: {e}") + except Exception: return None async def deactivate_user(self, user_id: str) -> bool: - """Deactivate a user""" - try: - result = await self.update_user(user_id, {"is_active": False}) - return result is not None - except Exception as e: - print(f"Error deactivating user: {e}") - return False + result = await self.update_user(user_id, {"is_active": False}) + return result is not None async def list_users(self, skip: int = 0, limit: int = 100) -> List[Dict[str, Any]]: - """List all users (for admin purposes)""" - try: - registry = await self._load_user_registry() - users = list(registry["users"].values()) - - users.sort(key=lambda x: x.get("created_at", ""), reverse=True) - - return users[skip : skip + limit] - except Exception as e: - print(f"Error listing users: {e}") - return [] + registry = await self._load_user_registry() + users = list(registry.get("users", {}).values()) + users.sort(key=lambda x: x.get("created_at", ""), reverse=True) + return users[skip : skip + limit] async def get_all_users(self) -> List[Dict[str, Any]]: - """Get all users (for admin purposes)""" - try: - registry = await self._load_user_registry() - users = list(registry["users"].values()) - - users.sort(key=lambda x: x.get("created_at", ""), reverse=True) - - return users - except Exception as e: - print(f"Error getting all users: {e}") - return [] + registry = await self._load_user_registry() + users = list(registry.get("users", {}).values()) + users.sort(key=lambda x: x.get("created_at", ""), reverse=True) + return users async def get_user_count(self) -> int: - """Get total number of users""" - try: - registry = await self._load_user_registry() - return len(registry["users"]) - except Exception as e: - print(f"Error getting user count: {e}") - return 0 + registry = await self._load_user_registry() + return len(registry.get("users", {})) -# Global user database service instance -user_db_service = ( - UserDatabaseService() if settings.AZURE_STORAGE_ACCOUNT_NAME else None -) +# Global instance +try: + user_db_service: Optional[UserDatabaseService] = UserDatabaseService() +except Exception: + user_db_service = None -# Alias for backward compatibility user_db = user_db_service diff --git a/backend/api/sr/router.py b/backend/api/sr/router.py index 1db2b27e..e36b095d 100644 --- a/backend/api/sr/router.py +++ b/backend/api/sr/router.py @@ -22,34 +22,12 @@ from ..core.config import settings from ..core.security import get_current_active_user -from ..services.user_db import user_db as user_db_service +from ..services.user_db import user_db_service from ..services.sr_db_service import srdb_service from ..core.cit_utils import load_sr_and_check router = APIRouter() -# Helper to get database connection string -def _get_db_conn_str() -> Optional[str]: - """ - Get database connection string for PostgreSQL. - - If POSTGRES_URI is set, returns it directly (local development). 
- If Entra ID env variables are configured (POSTGRES_HOST, POSTGRES_DATABASE, POSTGRES_USER), - returns None to signal that connect_postgres() should use Entra ID authentication. - """ - if settings.POSTGRES_URI: - return settings.POSTGRES_URI - - # If Entra ID config is available, return None to let connect_postgres use token auth - if settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER: - return None - - # No configuration available - raise ValueError( - "PostgreSQL not configured. Set POSTGRES_URI for local development, " - "or POSTGRES_HOST, POSTGRES_DATABASE, and POSTGRES_USER for Entra ID authentication." - ) - class SystematicReviewCreate(BaseModel): name: str description: Optional[str] = None @@ -98,7 +76,6 @@ async def create_systematic_review( One of criteria_file or criteria_yaml may be provided. If both are provided, criteria_file takes precedence. The created SR is stored in PostgreSQL and the creating user is added as the first member. """ - db_conn_str = _get_db_conn_str() if not name or not name.strip(): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="name is required") @@ -134,7 +111,6 @@ async def create_systematic_review( try: sr_doc = await run_in_threadpool( srdb_service.create_systematic_review, - db_conn_str, name, description, criteria_str, @@ -187,9 +163,8 @@ async def add_user_to_systematic_review( The endpoint checks that the requester is a member of the SR. """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -203,7 +178,7 @@ async def add_user_to_systematic_review( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing data user_email") try: - res = await run_in_threadpool(srdb_service.add_user, db_conn_str, sr_id, target_user_id, current_user.get("id")) + res = await run_in_threadpool(srdb_service.add_user, sr_id, target_user_id, current_user.get("id")) except HTTPException: raise except Exception as e: @@ -225,10 +200,8 @@ async def remove_user_from_systematic_review( The endpoint checks that the requester is a member of the SR (or owner). The owner cannot be removed via this endpoint. """ - - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -246,7 +219,7 @@ async def remove_user_from_systematic_review( raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot remove the owner from the systematic review") try: - res = await run_in_threadpool(srdb_service.remove_user, db_conn_str, sr_id, target_user_id, current_user.get("id")) + res = await run_in_threadpool(srdb_service.remove_user, sr_id, target_user_id, current_user.get("id")) except HTTPException: raise except Exception as e: @@ -263,12 +236,11 @@ async def list_systematic_reviews_for_user( List all systematic reviews the current user has access to (is a member of). Hidden/deleted SRs (visible == False) are excluded. 
""" - db_conn_str = _get_db_conn_str() user_id = current_user.get("email") results = [] try: - docs = await run_in_threadpool(srdb_service.list_systematic_reviews_for_user, db_conn_str, user_id) + docs = await run_in_threadpool(srdb_service.list_systematic_reviews_for_user, user_id) except HTTPException: raise except Exception as e: @@ -301,9 +273,8 @@ async def get_systematic_review(sr_id: str, current_user: Dict[str, Any] = Depen Get a single systematic review by id. User must be a member to view. """ - db_conn_str = _get_db_conn_str() try: - doc, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + doc, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -336,9 +307,8 @@ async def get_systematic_review_criteria_parsed( Returns an empty dict if no parsed criteria are available. """ - db_conn_str = _get_db_conn_str() try: - doc, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + doc, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -363,9 +333,8 @@ async def update_systematic_review_criteria( The parsed criteria (dict) and the raw YAML are both saved to the SR document. """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -402,7 +371,7 @@ async def update_systematic_review_criteria( # perform update try: - doc = await run_in_threadpool(srdb_service.update_criteria, db_conn_str, sr_id, criteria_obj, criteria_str, current_user.get("id")) + doc = await run_in_threadpool(srdb_service.update_criteria, sr_id, criteria_obj, criteria_str, current_user.get("id")) except HTTPException: raise except Exception as e: @@ -432,9 +401,8 @@ async def delete_systematic_review(sr_id: str, current_user: Dict[str, Any] = De Only the owner may delete a systematic review. """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) except HTTPException: raise except Exception as e: @@ -445,7 +413,7 @@ async def delete_systematic_review(sr_id: str, current_user: Dict[str, Any] = De raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may delete this systematic review") try: - res = await run_in_threadpool(srdb_service.soft_delete_systematic_review, db_conn_str, sr_id, requester_id) + res = await run_in_threadpool(srdb_service.soft_delete_systematic_review, sr_id, requester_id) except HTTPException: raise except Exception as e: @@ -462,9 +430,8 @@ async def undelete_systematic_review(sr_id: str, current_user: Dict[str, Any] = Only the owner may undelete a systematic review. 
""" - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False, require_visible=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False, require_visible=False) except HTTPException: raise except Exception as e: @@ -475,7 +442,7 @@ async def undelete_systematic_review(sr_id: str, current_user: Dict[str, Any] = raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may undelete this systematic review") try: - res = await run_in_threadpool(srdb_service.undelete_systematic_review, db_conn_str, sr_id, requester_id) + res = await run_in_threadpool(srdb_service.undelete_systematic_review, sr_id, requester_id) except HTTPException: raise except Exception as e: @@ -495,9 +462,8 @@ async def hard_delete_systematic_review(sr_id: str, current_user: Dict[str, Any] Only the owner may perform a hard delete. """ - db_conn_str = _get_db_conn_str() try: - sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False, require_visible=False) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False, require_visible=False) except HTTPException: raise except Exception as e: @@ -526,7 +492,7 @@ async def hard_delete_systematic_review(sr_id: str, current_user: Dict[str, Any] cleanup_result = {"status": "cleanup_import_failed", "error": str(e)} try: - res = await run_in_threadpool(srdb_service.hard_delete_systematic_review, db_conn_str, sr_id, requester_id) + res = await run_in_threadpool(srdb_service.hard_delete_systematic_review, sr_id, requester_id) deleted_count = res.get("deleted_count") if not deleted_count: # If backend reported zero deletions, raise NotFound to match prior behavior diff --git a/backend/configs/models.yaml b/backend/configs/models.yaml new file mode 100644 index 00000000..75953bc9 --- /dev/null +++ b/backend/configs/models.yaml @@ -0,0 +1,7 @@ +GPT-4.1-Mini: + deployment: gpt-4.1-mini + api_version: 2025-01-01-preview + +GPT-5-Mini: + deployment: gpt-5-mini + api_version: 2025-04-01-preview diff --git a/backend/deploy.sh b/backend/deploy.sh index 6de15442..1f769e6d 100755 --- a/backend/deploy.sh +++ b/backend/deploy.sh @@ -95,6 +95,7 @@ fi # Create necessary directories echo -e "${BLUE}📁 Creating volume directories...${NC}" mkdir -p volumes/{postgres-cits} +mkdir -p uploads/users echo -e "${GREEN}🚀 Starting services...${NC}" diff --git a/backend/docker-compose.yml b/backend/docker-compose.yml index 76b9fb6a..3cca1cbd 100644 --- a/backend/docker-compose.yml +++ b/backend/docker-compose.yml @@ -11,7 +11,6 @@ services: - "8000:8000" environment: - GROBID_SERVICE_URL=http://grobid-service:8070 - - POSTGRES_URI=postgres://admin:password@pgdb-service:5432/postgres env_file: - .env depends_on: @@ -19,6 +18,7 @@ services: - pgdb-service volumes: - ./api:/app/api + - ./configs:/app/configs - ./uploads:/app/uploads healthcheck: test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"] @@ -38,10 +38,10 @@ services: restart: unless-stopped healthcheck: test: ["CMD", "curl", "-f", "http://localhost:8070/api/isalive"] - interval: 30s + interval: 10s timeout: 10s retries: 5 - start_period: 30s + start_period: 120s # ============================================================================= # POSTGRESQL - Database (Citations & Systematic Reviews) diff --git 
a/backend/main.py b/backend/main.py index e89158e2..4173c9f1 100644 --- a/backend/main.py +++ b/backend/main.py @@ -31,17 +31,10 @@ async def startup_event(): print("📚 Initializing systematic review database...", flush=True) # Ensure systematic review table exists in PostgreSQL try: - # Check if Entra ID or POSTGRES_URI is configured - has_entra_config = settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER - has_uri_config = settings.POSTGRES_URI - - if has_entra_config or has_uri_config: - # Pass connection string if available, otherwise None for Entra ID auth - db_conn_str = settings.POSTGRES_URI if has_uri_config else None - await run_in_threadpool(srdb_service.ensure_table_exists, db_conn_str) - print("✓ Systematic review table initialized", flush=True) - else: - print("⚠️ PostgreSQL not configured - skipping SR table initialization", flush=True) + # POSTGRES_URI is deprecated; the DB connection is handled by postgres_server + # using POSTGRES_MODE/POSTGRES_* settings. + await run_in_threadpool(srdb_service.ensure_table_exists) + print("✓ Systematic review table initialized", flush=True) except Exception as e: print(f"⚠️ Failed to ensure SR table exists: {e}", flush=True) print("🎯 CAN-SR Backend ready!", flush=True) @@ -88,8 +81,9 @@ async def health_check(): "status": "ok", "service": settings.PROJECT_NAME, "version": settings.VERSION, - "storage_type": settings.STORAGE_TYPE, - "azure_storage_configured": bool(settings.AZURE_STORAGE_CONNECTION_STRING), + "azure_openai_mode": settings.AZURE_OPENAI_MODE, + "postgres_mode": settings.POSTGRES_MODE, + "storage_mode": settings.STORAGE_MODE, "azure_openai_configured": azure_openai_client.is_configured(), "default_chat_model": settings.DEFAULT_CHAT_MODEL, "available_models": azure_openai_client.get_available_models(), diff --git a/backend/requirements.txt b/backend/requirements.txt index 517a0efa..39a7ab02 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -3,6 +3,7 @@ fastapi>=0.104.1 uvicorn>=0.23.2 pydantic>=2.7.0 pydantic-settings>=2.3.0 +PyYAML==6.0.2 email-validator>=2.0.0 python-multipart==0.0.22 python-dotenv>=1.0.1 @@ -16,6 +17,8 @@ bcrypt==4.0.1 # Azure Services azure-storage-blob==12.25.1 azure-identity==1.23.0 +azure-ai-documentintelligence==1.0.2 +azure-core==1.32.0 openai==2.15.0 # Databases diff --git a/frontend/app/can-sr/extract/view/page.tsx b/frontend/app/can-sr/extract/view/page.tsx index 88fd7a69..2beb1563 100644 --- a/frontend/app/can-sr/extract/view/page.tsx +++ b/frontend/app/can-sr/extract/view/page.tsx @@ -66,6 +66,22 @@ export default function CanSrL2ScreenPage() { const [fulltextStr, setFulltextStr] = useState(null) const viewerRef = useRef(null) + // Table/Figure artifacts (for evidence chips -> click to highlight) + const [fulltextTables, setFulltextTables] = useState(null) + const [fulltextFigures, setFulltextFigures] = useState(null) + + const scrollToArtifact = (kind: 'table' | 'figure', idx: number) => { + const list = kind === 'table' ? (fulltextTables || []) : (fulltextFigures || []) + const item = list.find((x: any) => Number(x?.index) === Number(idx)) + console.log('[artifact-click]', { kind, idx, hasViewer: !!viewerRef.current, item }) + if (!item || !viewerRef.current) return + const bbox = item?.bounding_box + const first = Array.isArray(bbox) ? 
bbox[0] : null + console.log('[artifact-bbox]', { kind, idx, bbox, first }) + if (!first) return + viewerRef.current.scrollToCoord(first) + } + const [runningAllAI, setRunningAllAI] = useState(false) const [runAllProgress, setRunAllProgress] = useState<{ done: number; total: number } | null>(null) @@ -297,10 +313,16 @@ export default function CanSrL2ScreenPage() { setAiPanels(prev => ({ ...prev, ...nextAIPanels })) } - // extract coords/pages/fulltext for PDF overlay + // extract coords/pages/fulltext and artifacts for PDF overlay const ft = typeof (row as any).fulltext === 'string' ? (row as any).fulltext : null if (ft) setFulltextStr(ft) + const tablesAny = parseJson((row as any).fulltext_tables) ?? (row as any).fulltext_tables + if (tablesAny && Array.isArray(tablesAny)) setFulltextTables(tablesAny) + + const figsAny = parseJson((row as any).fulltext_figures) ?? (row as any).fulltext_figures + if (figsAny && Array.isArray(figsAny)) setFulltextFigures(figsAny) + const coordsAny = parseJson((row as any).fulltext_coords) ?? (row as any).fulltext_coords if (coordsAny && Array.isArray(coordsAny)) setFulltextCoords(coordsAny) @@ -339,6 +361,12 @@ export default function CanSrL2ScreenPage() { const ft = typeof (row as any).fulltext === 'string' ? (row as any).fulltext : null if (ft) setFulltextStr(ft) + const tablesAny = parseJson((row as any).fulltext_tables) ?? (row as any).fulltext_tables + if (tablesAny && Array.isArray(tablesAny)) setFulltextTables(tablesAny) + + const figsAny = parseJson((row as any).fulltext_figures) ?? (row as any).fulltext_figures + if (figsAny && Array.isArray(figsAny)) setFulltextFigures(figsAny) + const coordsAny = parseJson((row as any).fulltext_coords) ?? (row as any).fulltext_coords if (coordsAny && Array.isArray(coordsAny)) setFulltextCoords(coordsAny) @@ -658,6 +686,50 @@ export default function CanSrL2ScreenPage() { ) : null} + +{Array.isArray(aiPanels[paramName]?.evidence_tables) && aiPanels[paramName].evidence_tables.length > 0 ? ( +
+ Evidence tables: +
+ {aiPanels[paramName].evidence_tables.map((t: any, k: number) => { + const label = `Table T${String(t)}` + return ( + + ) + })} +
+
+) : null} + +{Array.isArray(aiPanels[paramName]?.evidence_figures) && aiPanels[paramName].evidence_figures.length > 0 ? ( +
+ Evidence figures: +
+ {aiPanels[paramName].evidence_figures.map((f: any, k: number) => { + const label = `Figure F${String(f)}` + return ( + + ) + })} +
+
+) : null} ) : null} diff --git a/frontend/app/can-sr/l2-screen/view/page.tsx b/frontend/app/can-sr/l2-screen/view/page.tsx index 9209b136..f6615dad 100644 --- a/frontend/app/can-sr/l2-screen/view/page.tsx +++ b/frontend/app/can-sr/l2-screen/view/page.tsx @@ -508,6 +508,44 @@ export default function CanSrL2ScreenViewPage() { return { panels: mappedPanels, open: mappedOpen } }, [criteriaData, aiPanels, panelOpen]) + // Helpers for table/figure evidence -> viewer highlight + const parsedTables = useMemo(() => { + if (!citation) return [] as any[] + let v: any = (citation as any).fulltext_tables + if (!v) return [] + try { + if (typeof v === 'string') v = JSON.parse(v) + } catch { + // ignore + } + return Array.isArray(v) ? v : [] + }, [citation]) + + const parsedFigures = useMemo(() => { + if (!citation) return [] as any[] + let v: any = (citation as any).fulltext_figures + if (!v) return [] + try { + if (typeof v === 'string') v = JSON.parse(v) + } catch { + // ignore + } + return Array.isArray(v) ? v : [] + }, [citation]) + + const scrollToArtifact = (kind: 'table' | 'figure', idx: number) => { + const list = kind === 'table' ? parsedTables : parsedFigures + const item = list.find((x: any) => Number(x?.index) === Number(idx)) + console.log('[artifact-click]', { kind, idx, hasViewer: !!viewerRef.current, item }) + if (!item || !viewerRef.current) return + const bbox = item?.bounding_box + // We store normalized boxes as an array of {page,x,y,width,height} + const first = Array.isArray(bbox) ? bbox[0] : null + console.log('[artifact-bbox]', { kind, idx, bbox, first }) + if (!first) return + viewerRef.current.scrollToCoord(first) + } + const workspace = useMemo(() => { if (loadingCitation) return
Loading citation...
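The next hunk renders clickable evidence chips for these artifacts. For reference, a chip only needs the label and an onClick that forwards to `scrollToArtifact`; the sketch below is a minimal, self-contained version. The component name, props, and markup are assumptions; only the handler signature, the `Table T<n>` label format, and the `{ index, bounding_box: [{ page, x, y, width, height }] }` artifact shape come from the page code above. Figures follow the same pattern with `'figure'` and a `Figure F<n>` label.

```tsx
import React from 'react'

// Hypothetical chip list. Only the handler signature and the label format are
// taken from the page code; the component name, props, and markup are assumed.
type ScrollToArtifact = (kind: 'table' | 'figure', idx: number) => void

export function EvidenceTableChips(props: {
  tables: Array<number | string>
  scrollToArtifact: ScrollToArtifact
}) {
  return (
    <div>
      <span>Evidence tables:</span>
      {props.tables.map((t, k) => (
        // Clicking a chip scrolls the PDF viewer to the table's first bounding
        // box via scrollToArtifact -> viewerRef.current.scrollToCoord(...).
        <button key={k} type="button" onClick={() => props.scrollToArtifact('table', Number(t))}>
          {`Table T${String(t)}`}
        </button>
      ))}
    </div>
  )
}
```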
@@ -707,6 +745,50 @@ export default function CanSrL2ScreenViewPage() { ) : null} + + {Array.isArray(aiData?.evidence_tables) && aiData.evidence_tables.length > 0 ? ( +
+ Evidence tables: +
+ {aiData.evidence_tables.map((t: any, k: number) => { + const label = `Table T${String(t)}` + return ( + + ) + })} +
+
+ ) : null} + + {Array.isArray(aiData?.evidence_figures) && aiData.evidence_figures.length > 0 ? ( +
+ Evidence figures: +
+ {aiData.evidence_figures.map((f: any, k: number) => { + const label = `Figure F${String(f)}` + return ( + + ) + })} +
+
+ ) : null} ) : null} diff --git a/frontend/components/can-sr/PDFBoundingBoxViewer.tsx b/frontend/components/can-sr/PDFBoundingBoxViewer.tsx index beb6c206..94d52856 100644 --- a/frontend/components/can-sr/PDFBoundingBoxViewer.tsx +++ b/frontend/components/can-sr/PDFBoundingBoxViewer.tsx @@ -123,6 +123,9 @@ const wrapperRefs = useRef>({}) const renderTasksRef = useRef>({}) const renderTokenRef = useRef(0) const [hoverInfo, setHoverInfo] = useState<{ page: number; left: number; top: number; content: string } | null>(null) + // Explicitly selected box (e.g., when user clicks an evidence chip). This is drawn + // regardless of whether the LLM evidence panels contain that coordinate. + const [selectedCoord, setSelectedCoord] = useState(null) const sentenceTexts = extractSentenceArray(fulltext) @@ -138,6 +141,7 @@ useImperativeHandle(ref, () => ({ }, scrollToCoord: (coord: any) => { try { + setSelectedCoord(coord) const pageNum = Number(coord?.page ?? coord?.page_number ?? coord?.pageNum ?? 1) const vp = pageViewports[pageNum] const dims = pages?.[pageNum - 1] @@ -166,6 +170,7 @@ useImperativeHandle(ref, () => ({ const firstCoord = Array.isArray(coords) ? coords.find((c: any) => String(c?.text || '').trim() === trimmed) : null if (!firstCoord) return + setSelectedCoord(firstCoord) const pageNum = Number(firstCoord?.page ?? firstCoord?.page_number ?? firstCoord?.pageNum ?? 1) const vp = pageViewports[pageNum] const dims = pages?.[pageNum - 1] @@ -595,6 +600,18 @@ useImperativeHandle(ref, () => ({ }) : [] + // Add an explicitly selected coordinate (e.g., table/figure chip click) + if (selectedCoord) { + try { + const p = Number(selectedCoord?.page ?? selectedCoord?.page_number ?? selectedCoord?.pageNum ?? 0) + if (p === pageNum) { + filtered.unshift({ ...selectedCoord, __selected: true }) + } + } catch { + // ignore + } + } + const elements = filtered.map((c: any, idx: number) => { const x = parseFloat(c?.x ?? '0') @@ -651,10 +668,25 @@ useImperativeHandle(ref, () => ({ (isOpen ? (paramsOpenHere[0] || coordsParamsOpen[0]) : undefined) ?? (isClosed ? (paramsClosedHere[0] || coordsParamsClosed[0]) : undefined) - const alpha = isOpen ? 0.2 : 0.05 - const fill = chosenParam ? colorForParam(chosenParam, alpha) : `rgba(255, 229, 100, ${alpha})` - const borderColor = chosenParam ? solidForParam(chosenParam) : 'rgba(255, 196, 0, 0.95)' - const border = isOpen ? `2px solid ${borderColor}` : `1px dashed ${borderColor}` + const isSelected = !!c?.__selected + const isArtifact = c?.type === 'table' || c?.type === 'figure' + // Keep tables/figures highlights more transparent so users can still read/see + // the underlying content. + const alpha = isArtifact + ? (isSelected ? 0.5 : (isOpen ? 0.5 : 0.5)) + : (isSelected ? 0.0 : (isOpen ? 0.0 : 0.0)) + const border_alpha = isArtifact + ? 0.95 + : 0.0 + const fill = isSelected + ? `rgba(59, 130, 246, ${alpha})` + : (chosenParam ? colorForParam(chosenParam, alpha) : `rgba(255, 229, 100, ${alpha})`) + const borderColor = isSelected + ? `rgba(37, 99, 235, ${border_alpha})` + : (chosenParam ? solidForParam(chosenParam) : `rgba(255, 196, 0, ${border_alpha})`) + const border = isSelected + ? `3px solid ${borderColor}` + : (isOpen ? `2px solid ${borderColor}` : `1px dashed ${borderColor}`) const title = chosenParam ? `${chosenParam}${t ? `: ${t.slice(0, 160)}` : ''}` : t ? 
t.slice(0, 160) : 'Sentence' diff --git a/frontend/components/can-sr/stacking-card.tsx b/frontend/components/can-sr/stacking-card.tsx index 3cf78543..1a73eac6 100644 --- a/frontend/components/can-sr/stacking-card.tsx +++ b/frontend/components/can-sr/stacking-card.tsx @@ -36,10 +36,10 @@ export default function StackingCard({ title, description, href, className }: St {open ? 'Minimize' : 'Expand'} */} - - + + Open - + @@ -58,5 +58,5 @@ export default function StackingCard({ title, description, href, className }: St - ) + ); } diff --git a/frontend/public/images/backgrounds/homepage.jpg b/frontend/public/images/backgrounds/homepage.jpg index 7f9b249f..3fa1aac2 100644 Binary files a/frontend/public/images/backgrounds/homepage.jpg and b/frontend/public/images/backgrounds/homepage.jpg differ diff --git a/frontend/public/images/backgrounds/homepage_hc.jpg b/frontend/public/images/backgrounds/homepage_hc.jpg new file mode 100644 index 00000000..7f9b249f Binary files /dev/null and b/frontend/public/images/backgrounds/homepage_hc.jpg differ diff --git a/frontend/tsconfig.json b/frontend/tsconfig.json index e7ff3a26..74f65609 100644 --- a/frontend/tsconfig.json +++ b/frontend/tsconfig.json @@ -15,7 +15,7 @@ "moduleResolution": "bundler", "resolveJsonModule": true, "isolatedModules": true, - "jsx": "react-jsx", + "jsx": "preserve", "incremental": true, "plugins": [ {