diff --git a/backend/db/database.py b/backend/db/database.py index 800a096..7f5389d 100644 --- a/backend/db/database.py +++ b/backend/db/database.py @@ -137,10 +137,13 @@ async def get_db(): await session.close() -async def get_latest_assessments(db: AsyncSession, control_ids: list[str] = None): +async def get_latest_assessments( + db: AsyncSession, control_ids: list[str] = None, columns: list = None +): """ Shared helper to fetch the latest AssessmentRecord for each control. Optionally filtered by a list of control_ids for better performance. + Optionally fetch only specific columns to reduce ORM overhead. """ sub_q = select( AssessmentRecord.control_id, @@ -152,11 +155,26 @@ async def get_latest_assessments(db: AsyncSession, control_ids: list[str] = None sub_q = sub_q.subquery() - query = select(AssessmentRecord).join( + if columns: + # Clone list to avoid side effects and ensure control_id is present + target_cols = list(columns) + if AssessmentRecord.control_id not in target_cols: + target_cols.append(AssessmentRecord.control_id) + query = select(*target_cols) + else: + query = select(AssessmentRecord) + + query = query.join( sub_q, (AssessmentRecord.control_id == sub_q.c.control_id) & (AssessmentRecord.assessment_date == sub_q.c.max_date), ) result = await db.execute(query) - return {a.control_id: a for a in result.scalars().all()} + + if columns: + # When columns are specified, result.all() returns Row objects + return {a.control_id: a for a in result.all()} + else: + # Default behavior: return full AssessmentRecord objects + return {a.control_id: a for a in result.scalars().all()} diff --git a/backend/routers/assessment.py b/backend/routers/assessment.py index ff7fbe0..a2358f5 100644 --- a/backend/routers/assessment.py +++ b/backend/routers/assessment.py @@ -78,10 +78,13 @@ class SPRSResult(BaseModel): description="Get overall CMMC compliance posture summary including implementation percentages, SPRS score, and breakdown by domain and level.", ) async def get_compliance_dashboard(db: AsyncSession = Depends(get_db)): - result = await db.execute(select(ControlRecord)) - controls = result.scalars().all() + # Performance Optimization: Select only required columns to avoid ORM overhead of large fields + result = await db.execute(select(ControlRecord.id, ControlRecord.domain, ControlRecord.level)) + controls = result.all() - assessments_map = await get_latest_assessments(db) + assessments_map = await get_latest_assessments( + db, columns=[AssessmentRecord.control_id, AssessmentRecord.status] + ) by_domain = {} by_level = { @@ -157,10 +160,13 @@ async def get_compliance_dashboard(db: AsyncSession = Depends(get_db)): description="Calculate the DoD Supplier Performance Risk System (SPRS) score based on current control implementation status. Score ranges from -203 to 110.", ) async def calculate_sprs_score(db: AsyncSession = Depends(get_db)): - result = await db.execute(select(ControlRecord)) - controls = result.scalars().all() + # Performance Optimization: Select only required columns (id is used to check against deductions) + result = await db.execute(select(ControlRecord.id)) + controls = result.all() - assessments_map = await get_latest_assessments(db) + assessments_map = await get_latest_assessments( + db, columns=[AssessmentRecord.control_id, AssessmentRecord.status] + ) sprs = 110 deductions_list = [] diff --git a/backend/routers/reports.py b/backend/routers/reports.py index 5f2660b..c9ca1e7 100644 --- a/backend/routers/reports.py +++ b/backend/routers/reports.py @@ -62,11 +62,22 @@ async def generate_ssp( Generate a NIST SP 800-171 / CMMC 2.0 SSP in Markdown format. Includes: system overview, control family summaries, implementation status. """ - # Fetch latest assessments - assessments_dict = await get_latest_assessments(db) + # Performance Optimization: Select only required columns + assessments_dict = await get_latest_assessments( + db, + columns=[ + AssessmentRecord.control_id, + AssessmentRecord.status, + AssessmentRecord.confidence, + AssessmentRecord.notes, + AssessmentRecord.evidence_ids, + ], + ) assessments = list(assessments_dict.values()) - controls_result = await db.execute(select(ControlRecord)) - controls = {c.id: c for c in controls_result.scalars().all()} + controls_result = await db.execute( + select(ControlRecord.id, ControlRecord.title, ControlRecord.description) + ) + controls = {c.id: c for c in controls_result.all()} # Count by status status_counts = { @@ -196,10 +207,23 @@ async def generate_poam( Generate a Plan of Action & Milestones (POA&M) as CSV. Includes all partial and not_implemented controls. """ - assessments_dict = await get_latest_assessments(db) + # Performance Optimization: Select only required columns + assessments_dict = await get_latest_assessments( + db, + columns=[ + AssessmentRecord.control_id, + AssessmentRecord.status, + AssessmentRecord.confidence, + AssessmentRecord.notes, + AssessmentRecord.next_review, + AssessmentRecord.assessor, + ], + ) assessments = list(assessments_dict.values()) - controls_result = await db.execute(select(ControlRecord)) - controls = {c.id: c for c in controls_result.scalars().all()} + controls_result = await db.execute( + select(ControlRecord.id, ControlRecord.title, ControlRecord.zt_pillar) + ) + controls = {c.id: c for c in controls_result.all()} output = io.StringIO() writer = csv.writer(output) @@ -259,7 +283,10 @@ async def get_dashboard( db: AsyncSession = Depends(get_db), ): """Return compliance posture summary for dashboard rendering.""" - assessments_dict = await get_latest_assessments(db) + # Performance Optimization: Select only required columns + assessments_dict = await get_latest_assessments( + db, columns=[AssessmentRecord.control_id, AssessmentRecord.status] + ) assessments = list(assessments_dict.values()) status_counts = {