diff --git a/src/llm.py b/src/llm.py index 70937f9..d9ff3ce 100644 --- a/src/llm.py +++ b/src/llm.py @@ -1,135 +1,113 @@ import json +import logging import os -import requests +import requests -class LLM: - def __init__(self, transcript_text=None, target_fields=None, json=None): - if json is None: - json = {} - self._transcript_text = transcript_text # str - self._target_fields = target_fields # List, contains the template field. - self._json = json # dictionary - - def type_check_all(self): - if type(self._transcript_text) is not str: - raise TypeError( - f"ERROR in LLM() attributes ->\ - Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}" - ) - elif type(self._target_fields) is not list: - raise TypeError( - f"ERROR in LLM() attributes ->\ - Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}" - ) +from src.schemas.incident_report import IncidentReport - def build_prompt(self, current_field): - """ - This method is in charge of the prompt engineering. It creates a specific prompt for each target field. - @params: current_field -> represents the current element of the json that is being prompted. - """ - prompt = f""" - SYSTEM PROMPT: - You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings. - You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return - only a single string containing the identified value for the JSON field. - If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";". - If you don't identify the value in the provided text, return "-1". - --- - DATA: - Target JSON field to find in text: {current_field} - - TEXT: {self._transcript_text} - """ - - return prompt - - def main_loop(self): - # self.type_check_all() - for field in self._target_fields.keys(): - prompt = self.build_prompt(field) - # print(prompt) - # ollama_url = "http://localhost:11434/api/generate" - ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") - ollama_url = f"{ollama_host}/api/generate" - - payload = { - "model": "mistral", - "prompt": prompt, - "stream": False, # don't really know why --> look into this later. - } - - try: - response = requests.post(ollama_url, json=payload) - response.raise_for_status() - except requests.exceptions.ConnectionError: - raise ConnectionError( - f"Could not connect to Ollama at {ollama_url}. " - "Please ensure Ollama is running and accessible." - ) - except requests.exceptions.HTTPError as e: - raise RuntimeError(f"Ollama returned an error: {e}") - - # parse response - json_data = response.json() - parsed_response = json_data["response"] - # print(parsed_response) - self.add_response_to_json(field, parsed_response) - - print("----------------------------------") - print("\t[LOG] Resulting JSON created from the input text:") - print(json.dumps(self._json, indent=2)) - print("--------- extracted data ---------") +logger = logging.getLogger(__name__) - return self - def add_response_to_json(self, field, value): - """ - this method adds the following value under the specified field, - or under a new field if the field doesn't exist, to the json dict +class LLM: + def __init__(self, transcript_text=None, target_fields=None): + self._transcript_text = transcript_text + # target_fields kept for backward compatibility with FileManipulator; + # the new extraction uses IncidentReport.llm_schema_hint() instead. + self._target_fields = target_fields + self._result: IncidentReport | None = None + + def build_prompt(self) -> str: """ - value = value.strip().replace('"', "") - parsed_value = None + Build a single structured prompt for Ollama's /api/generate endpoint. - if value != "-1": - parsed_value = value + Uses IncidentReport.llm_schema_hint() as the schema constraint so the + model always returns a JSON object with the correct field names. + Mistral instruction format ([INST] … [/INST]) is used so the system + context is respected. + """ + schema_hint = json.dumps(IncidentReport.llm_schema_hint(), indent=2) + + system = ( + "You are an AI assistant that extracts structured incident report data " + "from transcribed voice recordings made by first responders.\n\n" + "Return ONLY a valid JSON object that matches the schema below. " + "Do not include any explanation, markdown code fences, or extra text.\n\n" + f"Schema:\n{schema_hint}\n\n" + "Rules:\n" + " - Set a field to null if the information is not present in the transcript.\n" + " - For list fields (e.g. unit_ids, personnel), return a JSON array.\n" + " - For timestamps, reproduce the exact phrasing from the transcript " + "(e.g. '14:35', '1435 hours', '2:35 PM').\n" + " - Do not invent or infer values that are not explicitly stated." + ) - if ";" in value: - parsed_value = self.handle_plural_values(value) + user = ( + "Extract all incident report fields from the following transcript:\n\n" + f"{self._transcript_text}" + ) - if field in self._json.keys(): - self._json[field].append(parsed_value) - else: - self._json[field] = parsed_value + return f"[INST] {system}\n\n{user} [/INST]" - return + # Ollama request + def _ollama_url(self) -> str: + host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") + return f"{host}/api/generate" - def handle_plural_values(self, plural_value): + def main_loop(self) -> "LLM": """ - This method handles plural values. - Takes in strings of the form 'value1; value2; value3; ...; valueN' - returns a list with the respective values -> [value1, value2, value3, ..., valueN] + Send a single structured request to Ollama and parse the response as an + IncidentReport. Replaces the old per-field loop. + + On success, self._result holds the validated IncidentReport. + Fields that could not be extracted are listed in result.requires_review. """ - if ";" not in plural_value: - raise ValueError( - f"Value is not plural, doesn't have ; separator, Value: {plural_value}" + prompt = self.build_prompt() + model = os.getenv("OLLAMA_MODEL", "mistral") + + payload = { + "model": model, + "prompt": prompt, + "format": "json", # constrains Ollama output to valid JSON + "stream": False, + } + + try: + response = requests.post(self._ollama_url(), json=payload) + response.raise_for_status() + except requests.exceptions.ConnectionError: + raise ConnectionError( + f"Could not connect to Ollama at {self._ollama_url()}. " + "Please ensure Ollama is running and accessible." ) + except requests.exceptions.HTTPError as e: + raise RuntimeError(f"Ollama returned an error: {e}") from e - print( - f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..." - ) - values = plural_value.split(";") + raw = response.json()["response"] - # Remove trailing leading whitespace - for i in range(len(values)): - current = i + 1 - if current < len(values): - clean_value = values[current].lstrip() - values[current] = clean_value + try: + extracted = json.loads(raw) + except json.JSONDecodeError as e: + logger.warning("Ollama returned invalid JSON: %s. Raw (first 200 chars): %.200s", e, raw) + extracted = {} - print(f"\t[LOG]: Resulting formatted list of values: {values}") + # Pydantic validates and auto-populates requires_review for missing fields + self._result = IncidentReport(**extracted) - return values + logger.info("Extraction complete. requires_review: %s", self._result.requires_review) + + return self + + # Data accessors + def get_data(self) -> dict: + """ + Return extracted data as a plain dict for use by Filler. + requires_review is excluded here — use get_report() to access it. + """ + if self._result is None: + return {} + return self._result.model_dump(exclude={"requires_review"}) - def get_data(self): - return self._json + def get_report(self) -> IncidentReport | None: + """Return the full IncidentReport including requires_review.""" + return self._result diff --git a/src/schemas/__init__.py b/src/schemas/__init__.py new file mode 100644 index 0000000..165dd7a --- /dev/null +++ b/src/schemas/__init__.py @@ -0,0 +1,3 @@ +from src.schemas.incident_report import IncidentReport + +__all__ = ["IncidentReport"] diff --git a/src/schemas/incident_report.py b/src/schemas/incident_report.py new file mode 100644 index 0000000..6cec1c6 --- /dev/null +++ b/src/schemas/incident_report.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +from pydantic import BaseModel, Field +from pydantic import model_validator +from typing import Optional + + +class IncidentReport(BaseModel): + """ + Canonical schema for a FireForm incident report. + + All fields are Optional. Fields that could not be reliably extracted by the + LLM are recorded in `requires_review` so responders can spot what needs + manual completion before the PDF is submitted. + + This schema serves three roles: + 1. LLM target — the JSON structure the extraction prompt asks for. + 2. Validator — Pydantic validates the LLM response against this model. + 3. Mapper source — TemplateMapper resolves json_path values from this model + to populate PDF form fields. + + Note: `requires_review` is populated by the extraction pipeline, never by + the LLM itself. Exclude it when building the schema hint sent to the model. + """ + + # ------------------------------------------------------------------------- + # Identity + # ------------------------------------------------------------------------- + + incident_id: Optional[str] = Field( + default=None, + description="Unique identifier for this incident assigned by the agency.", + examples=["CAL-2024-001", "INC-20240101-005"], + ) + agency_id: Optional[str] = Field( + default=None, + description="Name or identifier of the responding agency.", + examples=["CAL FIRE", "SFFD", "LAFD"], + ) + unit_ids: Optional[list[str]] = Field( + default=None, + description=( + "List of unit or apparatus identifiers responding to the incident. " + "If multiple units are named, list all of them." + ), + examples=[["Engine 12", "Truck 4"], ["Unit 7", "Medic 3"]], + ) + + # ------------------------------------------------------------------------- + # Incident classification + # ------------------------------------------------------------------------- + + incident_type: Optional[str] = Field( + default=None, + description=( + "Type or category of the incident. Use a short lowercase label. " + "Common values: wildfire, structure_fire, vehicle_accident, EMS, hazmat, rescue." + ), + examples=["wildfire", "structure_fire", "EMS", "vehicle_accident", "hazmat"], + ) + incident_severity: Optional[str] = Field( + default=None, + description="Severity or alarm level of the incident.", + examples=["low", "moderate", "high", "major"], + ) + + # Location + + location_address: Optional[str] = Field( + default=None, + description="Street address or nearest landmark address of the incident.", + examples=["1234 Oak Street", "Highway 1 near Mile Marker 42"], + ) + location_city: Optional[str] = Field( + default=None, + description="City where the incident occurred.", + examples=["San Francisco", "Half Moon Bay"], + ) + location_county: Optional[str] = Field( + default=None, + description="County where the incident occurred.", + examples=["San Mateo County", "Los Angeles County"], + ) + location_state: Optional[str] = Field( + default=None, + description="State abbreviation or full name where the incident occurred.", + examples=["CA", "California"], + ) + location_coordinates: Optional[str] = Field( + default=None, + description="GPS coordinates in decimal degrees format: 'latitude, longitude'.", + examples=["37.7749, -122.4194"], + ) + + # Timestamps (stored as strings — the LLM extracts them as spoken/written) + alarm_time: Optional[str] = Field( + default=None, + description=( + "Time the alarm was received or the incident was reported. " + "Use HH:MM (24-hour) if possible, or reproduce the exact phrasing." + ), + examples=["14:35", "2:35 PM", "1435 hours"], + ) + dispatch_time: Optional[str] = Field( + default=None, + description="Time units were dispatched to the incident.", + examples=["14:37", "1437 hours"], + ) + arrival_time: Optional[str] = Field( + default=None, + description="Time the first unit arrived on scene.", + examples=["14:45", "1445 hours"], + ) + controlled_time: Optional[str] = Field( + default=None, + description="Time the incident was declared under control.", + examples=["17:20", "1720 hours"], + ) + clear_time: Optional[str] = Field( + default=None, + description="Time all units cleared the scene.", + examples=["18:05", "1805 hours"], + ) + incident_date: Optional[str] = Field( + default=None, + description="Date the incident occurred.", + examples=["2024-01-01", "January 1, 2024", "01/01/2024"], + ) + + # Personnel + supervisor: Optional[str] = Field( + default=None, + description=( + "Name or identifier of the incident commander or supervising officer " + "in charge at the scene." + ), + examples=["Battalion Chief Johnson", "Captain Maria Torres"], + ) + personnel: Optional[list[str]] = Field( + default=None, + description=( + "Names or identifiers of all personnel assigned to the incident. " + "Include rank or role if mentioned." + ), + examples=[["FF Smith", "FF Jones", "Paramedic Lee"]], + ) + personnel_count: Optional[int] = Field( + default=None, + description="Total number of personnel responding to the incident.", + examples=[6, 12], + ) + + # Casualties & medical (EMS / law enforcement forms) + casualties: Optional[int] = Field( + default=None, + description="Total number of injured persons (civilians and/or responders).", + examples=[2, 0], + ) + fatalities: Optional[int] = Field( + default=None, + description="Total number of fatalities resulting from the incident.", + examples=[0, 1], + ) + patients_transported: Optional[int] = Field( + default=None, + description="Number of patients transported to a medical facility.", + examples=[1, 3], + ) + hospital_destination: Optional[str] = Field( + default=None, + description="Name or identifier of the hospital patients were transported to.", + examples=["SF General Hospital", "UCLA Medical Center"], + ) + patient_condition: Optional[str] = Field( + default=None, + description="Reported condition of the patient(s) at time of transport or care.", + examples=["stable", "critical", "deceased"], + ) + + # Narrative + narrative: Optional[str] = Field( + default=None, + description=( + "Free-text description of the incident: what happened, actions taken " + "by responders, and the outcome. Reproduce faithfully from the transcript." + ), + examples=[ + "Crews arrived to find a one-story structure with heavy smoke showing " + "from the roof. A primary search revealed no occupants. Fire was " + "knocked down within 20 minutes. Cause determined to be electrical." + ], + ) + + # Wildfire-specific (Cal Fire FIRESCOPE) + area_burned_acres: Optional[float] = Field( + default=None, + description="Estimated area burned in acres. Wildfire incidents only.", + examples=[12.5, 0.25], + ) + structures_threatened: Optional[int] = Field( + default=None, + description="Number of structures threatened by the incident.", + examples=[15, 0], + ) + structures_destroyed: Optional[int] = Field( + default=None, + description="Number of structures destroyed.", + examples=[3, 0], + ) + containment_percent: Optional[int] = Field( + default=None, + ge=0, + le=100, + description="Percentage of wildfire perimeter that is contained (0–100).", + examples=[25, 100], + ) + fuel_type: Optional[str] = Field( + default=None, + description="Predominant fuel type involved in the wildfire.", + examples=["grass", "chaparral", "timber", "brush"], + ) + + # Law enforcement (incident/crime report forms) + case_number: Optional[str] = Field( + default=None, + description="Law enforcement case or report number.", + examples=["SFPD-2024-00123"], + ) + officer_id: Optional[str] = Field( + default=None, + description="Badge number or identifier of the reporting officer.", + examples=["Badge #4821"], + ) + suspect_description: Optional[str] = Field( + default=None, + description="Physical description or identifying information of any suspect.", + examples=["Male, approximately 30 years old, 6ft, wearing a red jacket"], + ) + + # Pipeline metadata — populated by the extraction pipeline, NOT the LLM + requires_review: list[str] = Field( + default_factory=list, + description=( + "Field names that could not be reliably extracted and require human " + "review before the PDF is submitted. Populated by the retry loop, " + "never by the LLM." + ), + ) + + @model_validator(mode="after") + def collect_missing_fields(self) -> IncidentReport: + """ + After model construction, populate `requires_review` with the names of + any core fields that are still None. This runs on every instantiation, + so it also catches fields that were validly set to None after a retry + loop exhausts all attempts. + + Only core operational fields are checked — metadata-only or + incident-type-specific fields (wildfire, law enforcement) are skipped + here; the TemplateMapper handles conditional field requirements. + """ + CORE_FIELDS = { + "incident_type", + "location_address", + "location_city", + "alarm_time", + "incident_date", + "supervisor", + "narrative", + } + missing = [ + field + for field in CORE_FIELDS + if getattr(self, field) is None + ] + # Only overwrite if the caller hasn't already populated requires_review + if not self.requires_review: + self.requires_review = missing + return self + + @classmethod + def llm_schema_hint(cls) -> dict: + """ + Return a JSON-serialisable dict describing each field that the LLM + should extract. Excludes `requires_review` (pipeline-only) and + incident-type-specific fields that may not apply to every call. + + Used to build the structured system prompt in LLM.build_prompt(). + """ + schema = cls.model_json_schema() + excluded = {"requires_review"} + properties = { + k: v + for k, v in schema.get("properties", {}).items() + if k not in excluded + } + return { + "type": "object", + "properties": properties, + "required": [], # all fields are optional — missing → requires_review + } diff --git a/tests/test_forms.py b/tests/test_forms.py index 8f432bf..21de801 100644 --- a/tests/test_forms.py +++ b/tests/test_forms.py @@ -1,25 +1,36 @@ def test_submit_form(client): - pass - # First create a template - # form_payload = { - # "template_id": 3, - # "input_text": "Hi. The employee's name is John Doe. His job title is managing director. His department supervisor is Jane Doe. His phone number is 123456. His email is jdoe@ucsc.edu. The signature is , and the date is 01/02/2005", - # } + from unittest.mock import patch, MagicMock - # template_res = client.post("/templates/", json=template_payload) - # template_id = template_res.json()["id"] + # Step 1: create a template to get a valid template_id + template_payload = { + "name": "Test Form", + "pdf_path": "src/inputs/file.pdf", + "fields": { + "name": "string", + "date": "string", + }, + } + template_res = client.post("/templates/create", json=template_payload) + assert template_res.status_code == 200 + template_id = template_res.json()["id"] - # # Submit a form - # form_payload = { - # "template_id": template_id, - # "data": {"rating": 5, "comment": "Great service"}, - # } + # Step 2: submit a form fill request — mock Controller to avoid hitting Ollama + fake_output_path = "src/outputs/file_20240101_120000_filled.pdf" + form_payload = { + "template_id": template_id, + "input_text": "Employee name is John Doe. Date is 01/01/2024.", + } - # response = client.post("/forms/", json=form_payload) + with patch("api.routes.forms.Controller") as MockController: + mock_ctrl = MagicMock() + mock_ctrl.fill_form.return_value = fake_output_path + MockController.return_value = mock_ctrl - # assert response.status_code == 200 + response = client.post("/forms/fill", json=form_payload) - # data = response.json() - # assert data["id"] is not None - # assert data["template_id"] == template_id - # assert data["data"] == form_payload["data"] + assert response.status_code == 200 + data = response.json() + assert data["id"] is not None + assert data["template_id"] == template_id + assert data["input_text"] == form_payload["input_text"] + assert data["output_pdf_path"] == fake_output_path