diff --git a/backend/db/database.py b/backend/db/database.py index 800a096..ff21dc7 100644 --- a/backend/db/database.py +++ b/backend/db/database.py @@ -137,10 +137,13 @@ async def get_db(): await session.close() -async def get_latest_assessments(db: AsyncSession, control_ids: list[str] = None): +async def get_latest_assessments( + db: AsyncSession, control_ids: list[str] = None, columns: list = None +): """ Shared helper to fetch the latest AssessmentRecord for each control. Optionally filtered by a list of control_ids for better performance. + Supports selective column fetching via the 'columns' parameter. """ sub_q = select( AssessmentRecord.control_id, @@ -152,11 +155,22 @@ async def get_latest_assessments(db: AsyncSession, control_ids: list[str] = None sub_q = sub_q.subquery() - query = select(AssessmentRecord).join( + if columns: + # Ensure control_id is always included for the dictionary mapping + fetch_columns = list(columns) + if AssessmentRecord.control_id not in fetch_columns: + fetch_columns.append(AssessmentRecord.control_id) + query = select(*fetch_columns) + else: + query = select(AssessmentRecord) + + query = query.join( sub_q, (AssessmentRecord.control_id == sub_q.c.control_id) & (AssessmentRecord.assessment_date == sub_q.c.max_date), ) result = await db.execute(query) + if columns: + return {a.control_id: a for a in result.all()} return {a.control_id: a for a in result.scalars().all()} diff --git a/backend/routers/assessment.py b/backend/routers/assessment.py index ff7fbe0..bd7cb2a 100644 --- a/backend/routers/assessment.py +++ b/backend/routers/assessment.py @@ -78,10 +78,21 @@ class SPRSResult(BaseModel): description="Get overall CMMC compliance posture summary including implementation percentages, SPRS score, and breakdown by domain and level.", ) async def get_compliance_dashboard(db: AsyncSession = Depends(get_db)): - result = await db.execute(select(ControlRecord)) - controls = result.scalars().all() + # Optimized: Select only required columns + result = await db.execute( + select( + ControlRecord.id, + ControlRecord.domain, + ControlRecord.level, + ControlRecord.score_value, + ) + ) + controls = result.all() - assessments_map = await get_latest_assessments(db) + # Optimized: Select only required columns from assessments + assessments_map = await get_latest_assessments( + db, columns=[AssessmentRecord.control_id, AssessmentRecord.status] + ) by_domain = {} by_level = { @@ -157,10 +168,19 @@ async def get_compliance_dashboard(db: AsyncSession = Depends(get_db)): description="Calculate the DoD Supplier Performance Risk System (SPRS) score based on current control implementation status. Score ranges from -203 to 110.", ) async def calculate_sprs_score(db: AsyncSession = Depends(get_db)): - result = await db.execute(select(ControlRecord)) - controls = result.scalars().all() + # Optimized: Select only required columns + result = await db.execute( + select( + ControlRecord.id, + ControlRecord.score_value, + ) + ) + controls = result.all() - assessments_map = await get_latest_assessments(db) + # Optimized: Select only required columns from assessments + assessments_map = await get_latest_assessments( + db, columns=[AssessmentRecord.control_id, AssessmentRecord.status] + ) sprs = 110 deductions_list = [] diff --git a/backend/routers/reports.py b/backend/routers/reports.py index 5f2660b..f256051 100644 --- a/backend/routers/reports.py +++ b/backend/routers/reports.py @@ -62,11 +62,22 @@ async def generate_ssp( Generate a NIST SP 800-171 / CMMC 2.0 SSP in Markdown format. Includes: system overview, control family summaries, implementation status. """ - # Fetch latest assessments - assessments_dict = await get_latest_assessments(db) + # Optimized: Fetch only required columns for SSP summary + assessments_dict = await get_latest_assessments( + db, + columns=[ + AssessmentRecord.control_id, + AssessmentRecord.status, + AssessmentRecord.confidence, + AssessmentRecord.notes, + AssessmentRecord.evidence_ids, + ], + ) assessments = list(assessments_dict.values()) - controls_result = await db.execute(select(ControlRecord)) - controls = {c.id: c for c in controls_result.scalars().all()} + controls_result = await db.execute( + select(ControlRecord.id, ControlRecord.title, ControlRecord.score_value) + ) + controls = {c.id: c for c in controls_result.all()} # Count by status status_counts = { @@ -196,10 +207,23 @@ async def generate_poam( Generate a Plan of Action & Milestones (POA&M) as CSV. Includes all partial and not_implemented controls. """ - assessments_dict = await get_latest_assessments(db) + # Optimized: Fetch only required columns for POAM + assessments_dict = await get_latest_assessments( + db, + columns=[ + AssessmentRecord.control_id, + AssessmentRecord.status, + AssessmentRecord.confidence, + AssessmentRecord.next_review, + AssessmentRecord.assessor, + AssessmentRecord.notes, + ], + ) assessments = list(assessments_dict.values()) - controls_result = await db.execute(select(ControlRecord)) - controls = {c.id: c for c in controls_result.scalars().all()} + controls_result = await db.execute( + select(ControlRecord.id, ControlRecord.title, ControlRecord.zt_pillar) + ) + controls = {c.id: c for c in controls_result.all()} output = io.StringIO() writer = csv.writer(output) @@ -259,7 +283,10 @@ async def get_dashboard( db: AsyncSession = Depends(get_db), ): """Return compliance posture summary for dashboard rendering.""" - assessments_dict = await get_latest_assessments(db) + # Optimized: Fetch only required columns for dashboard counts + assessments_dict = await get_latest_assessments( + db, columns=[AssessmentRecord.control_id, AssessmentRecord.status] + ) assessments = list(assessments_dict.values()) status_counts = {