Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,7 @@ frontend/app/can-sr/setup/test.yaml
backend/api/citations/test.ipynb
file1.csv
*/criteria_config_measles_updated.yaml
AGENTS_ROADMAP.md
AGENTS_ROADMAP.md

*/logfile
logfile
1 change: 1 addition & 0 deletions backend/.dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ dmypy.json
.DS_Store
deploy.sh
*.sh
!entrypoint.sh
*.pem
16 changes: 14 additions & 2 deletions backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,26 @@ RUN mkdir -p uploads

# Create non-root user for security
RUN useradd -m -u 1001 appuser && chown -R appuser:appuser /app
USER appuser


COPY sshd_config /etc/ssh/
COPY entrypoint.sh /entrypoint.sh

RUN apt-get update \
&& apt-get install -y --no-install-recommends dialog \
&& apt-get install -y --no-install-recommends openssh-server \
&& echo "root:Docker!" | chpasswd \
&& chmod u+x /entrypoint.sh

USER root

# Expose port
EXPOSE 8000
EXPOSE 8000 2222

# Health check
HEALTHCHECK --interval=30s --timeout=30s --start-period=5s --retries=3 \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"

# Run the application
ENTRYPOINT ["/entrypoint.sh"]
CMD ["python", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
65 changes: 45 additions & 20 deletions backend/api/citations/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel


from ..services.sr_db_service import srdb_service

from ..core.security import get_current_active_user
Expand All @@ -38,6 +37,34 @@
router = APIRouter()


def _get_db_conn_str() -> Optional[str]:
    """
    Return the PostgreSQL connection string to hand to the DB helpers.

    Returns:
        settings.POSTGRES_URI when it is set to a non-empty value (local
        development), otherwise None.

    None covers two situations that this function deliberately does not
    distinguish:
      * Entra ID settings (POSTGRES_HOST, POSTGRES_DATABASE, POSTGRES_USER)
        are present -- None is the signal for connect_postgres() to use
        Entra ID token authentication instead of a DSN.
      * Nothing is configured at all -- the downstream caller is expected to
        surface the error.

    Callers that need to tell these two cases apart should use
    _is_postgres_configured().
    """
    uri = settings.POSTGRES_URI
    # The previous implementation tested the Entra ID settings here but
    # returned None on both sides of that test, so the check had no effect on
    # the result; a single truthiness test expresses the same contract.
    return uri if uri else None


def _is_postgres_configured() -> bool:
    """
    Report whether any usable PostgreSQL configuration is present.

    Returns True when either of the following holds:
      * all three Entra ID settings (POSTGRES_HOST, POSTGRES_DATABASE,
        POSTGRES_USER) are set, or
      * POSTGRES_URI is set.
    Returns False otherwise.
    """
    # Entra ID path: every one of the three settings must be truthy.
    if settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER:
        return True

    # Connection-string path: a non-empty URI is enough on its own.
    return bool(settings.POSTGRES_URI)


class UploadResult(BaseModel):
sr_id: str
table_name: str
Expand Down Expand Up @@ -75,21 +102,21 @@ async def upload_screening_csv(
- The SR must exist and the user must be a member of the SR (or owner).
"""

db_conn_str = settings.POSTGRES_URI
db_conn_str = _get_db_conn_str()
try:
sr, screening, _ = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service, require_screening=False)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}")

# Shared DB connection string
db_conn = settings.POSTGRES_URI
if not db_conn:
# Check admin DSN (use centralized settings) - need either Entra ID config or POSTGRES_URI
if not _is_postgres_configured():
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="PostgreSQL connection not configured. Set POSTGRES_URI in configuration/environment.",
detail="Postgres not configured. Set POSTGRES_HOST/DATABASE/USER for Entra ID auth, or POSTGRES_URI for local dev.",
)
admin_dsn = _get_db_conn_str()

# Read CSV content
include_columns = None
Expand All @@ -114,14 +141,14 @@ async def upload_screening_csv(
try:
old = (sr.get("screening_db") or {}).get("table_name")
if old:
await run_in_threadpool(cits_dp_service.drop_table, db_conn, old)
await run_in_threadpool(cits_dp_service.drop_table, admin_dsn, old)
except Exception:
# best-effort only
pass

# Create table and insert rows in threadpool
try:
inserted = await run_in_threadpool(_create_table_and_insert_sync, db_conn, table_name, include_columns, normalized_rows)
inserted = await run_in_threadpool(_create_table_and_insert_sync, admin_dsn, table_name, include_columns, normalized_rows)
except RuntimeError as rexc:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
except Exception as e:
Expand All @@ -131,7 +158,7 @@ async def upload_screening_csv(
try:
screening_info = {
"screening_db": {
"connection_string": db_conn,
"connection_string": admin_dsn,
"table_name": table_name,
"created_at": datetime.utcnow().isoformat(),
"rows": inserted,
Expand All @@ -141,7 +168,7 @@ async def upload_screening_csv(
# Update SR document with screening DB info using PostgreSQL
await run_in_threadpool(
srdb_service.update_screening_db_info,
settings.POSTGRES_URI,
_get_db_conn_str(),
sr_id,
screening_info["screening_db"]
)
Expand Down Expand Up @@ -171,7 +198,7 @@ async def list_citation_ids(

Returns a simple list of integers (the 'id' primary key from the citations table).
"""
db_conn_str = settings.POSTGRES_URI
db_conn_str = _get_db_conn_str()
try:
sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service)
except HTTPException:
Expand Down Expand Up @@ -210,7 +237,7 @@ async def get_citation_by_id(
Returns: a JSON object representing the citation row (keys are DB column names).
"""

db_conn_str = settings.POSTGRES_URI
db_conn_str = _get_db_conn_str()
try:
sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service)
except HTTPException:
Expand Down Expand Up @@ -263,7 +290,7 @@ async def build_combined_citation(
the format "<ColumnName>: <value> \\n" for each included column, in the order provided.
"""

db_conn_str = settings.POSTGRES_URI
db_conn_str = _get_db_conn_str()
try:
sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service)
except HTTPException:
Expand Down Expand Up @@ -318,7 +345,7 @@ async def upload_citation_fulltext(
to the storage path (container/blob).
"""

db_conn_str = settings.POSTGRES_URI
db_conn_str = _get_db_conn_str()
try:
sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service)
except HTTPException:
Expand Down Expand Up @@ -414,7 +441,7 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An
- Caller must be the SR owner.
"""

db_conn_str = settings.POSTGRES_URI
db_conn_str = _get_db_conn_str()
try:
sr, screening, db_conn = await load_sr_and_check(sr_id, current_user, db_conn_str, srdb_service)
except HTTPException:
Expand All @@ -429,9 +456,7 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An
if not screening:
return {"status": "no_screening_db", "message": "No screening table configured for this SR", "deleted_table": False, "deleted_files": 0}

db_conn = screening.get("connection_string")
if not db_conn:
return {"status": "no_screening_db", "message": "Incomplete screening DB metadata", "deleted_table": False, "deleted_files": 0}
db_conn = None

table_name = screening.get("table_name")
if not table_name:
Expand Down Expand Up @@ -521,7 +546,7 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An
try:
await run_in_threadpool(
srdb_service.clear_screening_db_info,
settings.POSTGRES_URI,
_get_db_conn_str(),
sr_id
)
except Exception:
Expand Down Expand Up @@ -557,7 +582,7 @@ async def export_citations_csv(
Content-Disposition.
"""

db_conn_str = settings.POSTGRES_URI
db_conn_str = _get_db_conn_str()
try:
sr, screening, db_conn = await load_sr_and_check(
sr_id, current_user, db_conn_str, srdb_service
Expand Down
32 changes: 15 additions & 17 deletions backend/api/core/cit_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,22 @@
from fastapi import HTTPException, status
from fastapi.concurrency import run_in_threadpool

from .config import settings


def _is_postgres_configured(db_conn_str: Optional[str] = None) -> bool:
    """
    Report whether any usable PostgreSQL configuration is present.

    Args:
        db_conn_str: optional explicit connection string; when truthy it
            counts as configuration regardless of settings.POSTGRES_URI.

    Returns:
        True when all three Entra ID settings (POSTGRES_HOST,
        POSTGRES_DATABASE, POSTGRES_USER) are set, or when either
        db_conn_str or settings.POSTGRES_URI is set; False otherwise.
    """
    # Entra ID path: every one of the three settings must be truthy.
    if settings.POSTGRES_HOST and settings.POSTGRES_DATABASE and settings.POSTGRES_USER:
        return True

    # Connection-string path: the explicit argument wins, then the setting.
    if db_conn_str:
        return True
    return bool(settings.POSTGRES_URI)


async def load_sr_and_check(
sr_id: str,
current_user: Dict[str, Any],
db_conn_str: str,
db_conn_str: Optional[str],
srdb_service,
require_screening: bool = True,
require_visible: bool = True,
Expand All @@ -27,7 +39,7 @@ async def load_sr_and_check(
Args:
sr_id: SR id string
current_user: current user dict (must contain "id" and "email")
db_conn_str: PostgreSQL connection string
db_conn_str: PostgreSQL connection string (can be None if using Entra ID auth)
srdb_service: SR DB service instance (must implement get_systematic_review and user_has_sr_permission)
require_screening: if True, also ensure the SR has a configured screening_db and return its connection string
require_visible: if True, require the SR 'visible' flag to be True; set False for endpoints like hard-delete
Expand All @@ -37,18 +49,6 @@ async def load_sr_and_check(

Raises HTTPException with appropriate status codes on failure so routers can just propagate.
"""
# ensure DB helper present and call it
if not db_conn_str:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Server misconfiguration: PostgreSQL connection not available",
)
try:
await run_in_threadpool(srdb_service.ensure_db_available, db_conn_str)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(e))

# fetch SR
try:
Expand Down Expand Up @@ -81,8 +81,6 @@ async def load_sr_and_check(
if require_screening:
if not screening:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No screening database configured for this systematic review")
db_conn = screening.get("connection_string")
if not db_conn:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Screening DB connection info missing")
db_conn = None

return sr, screening, db_conn
18 changes: 16 additions & 2 deletions backend/api/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class Settings(BaseSettings):

# Storage settings
STORAGE_TYPE: str = os.getenv("STORAGE_TYPE", "azure")
AZURE_STORAGE_ACCOUNT_NAME: Optional[str] = os.getenv(
"AZURE_STORAGE_ACCOUNT_NAME"
)
AZURE_STORAGE_CONNECTION_STRING: Optional[str] = os.getenv(
"AZURE_STORAGE_CONNECTION_STRING"
)
Expand Down Expand Up @@ -108,8 +111,16 @@ def convert_max_file_size(cls, v):
DEBUG: bool = os.getenv("DEBUG", "false").lower() == "true"

# Database and external system environment variables
# Postgres DSN used for systematic reviews and screening databases
POSTGRES_URI: str = os.getenv("POSTGRES_URI")
# Postgres settings for Entra ID authentication
POSTGRES_HOST: Optional[str] = os.getenv("POSTGRES_HOST")
POSTGRES_DATABASE: Optional[str] = os.getenv("POSTGRES_DATABASE")
POSTGRES_USER: Optional[str] = os.getenv("POSTGRES_USER") # Entra ID user (e.g., user@tenant.onmicrosoft.com)
POSTGRES_PORT: int = int(os.getenv("POSTGRES_PORT", "5432"))
POSTGRES_SSL_MODE: Optional[str] = os.getenv("POSTGRES_SSL_MODE")
POSTGRES_PASSWORD: Optional[str] = os.getenv("POSTGRES_PASSWORD")
AZURE_DB: bool = os.getenv("AZURE_DB", "false").lower() == "true"
# Legacy: Postgres DSN used for systematic reviews and screening databases (fallback)
POSTGRES_URI: Optional[str] = os.getenv("POSTGRES_URI")

# Databricks settings
DATABRICKS_INSTANCE: str = os.getenv("DATABRICKS_INSTANCE")
Expand All @@ -124,6 +135,9 @@ def convert_max_file_size(cls, v):
REDIRECT_URI: str = os.getenv("REDIRECT_URI")
SSO_LOGIN_URL: str = os.getenv("SSO_LOGIN_URL")

# Entra
USE_ENTRA_AUTH: bool = os.getenv("USE_ENTRA_AUTH", "false").lower() == "true"

class Config:
case_sensitive = True
env_file = ".env"
Expand Down
Loading
Loading