diff --git a/api/main.py b/api/main.py index d0b8c79..bbebbc6 100644 --- a/api/main.py +++ b/api/main.py @@ -1,7 +1,9 @@ from fastapi import FastAPI from api.routes import templates, forms +from api.routes import transcribe app = FastAPI() app.include_router(templates.router) -app.include_router(forms.router) \ No newline at end of file +app.include_router(forms.router) +app.include_router(transcribe.router) \ No newline at end of file diff --git a/api/routes/transcribe.py b/api/routes/transcribe.py new file mode 100644 index 0000000..54cf027 --- /dev/null +++ b/api/routes/transcribe.py @@ -0,0 +1,53 @@ +from pathlib import Path + +from fastapi import APIRouter, File, HTTPException, UploadFile + +from api.schemas.transcribe import TranscribeResponse +from src.transcriber import SUPPORTED_FORMATS, Transcriber + +router = APIRouter(prefix="/transcribe", tags=["transcription"]) + +# Module-level singleton — Whisper model is lazy-loaded on first request. +_transcriber: Transcriber | None = None + + +def _get_transcriber() -> Transcriber: + global _transcriber + if _transcriber is None: + _transcriber = Transcriber() + return _transcriber + + +@router.post("", response_model=TranscribeResponse) +async def transcribe_audio(file: UploadFile = File(...)): + """ + Upload an audio file and receive a plain-text transcription. + + - Accepted formats: WAV, MP3, M4A, MP4, OGG, FLAC + - All transcription runs locally via Whisper — no data leaves the machine. + - Model size is configured via the `WHISPER_MODEL` environment variable + (default: `base`). Valid values: `tiny`, `base`, `small`, `medium`, `large`. + """ + suffix = Path(file.filename).suffix.lower() + if suffix not in SUPPORTED_FORMATS: + raise HTTPException( + status_code=415, + detail=( + f"Unsupported audio format {suffix!r}. 
" + f"Accepted: {sorted(SUPPORTED_FORMATS)}" + ), + ) + + contents = await file.read() + transcriber = _get_transcriber() + + try: + text = transcriber.transcribe_bytes(contents, suffix=suffix) + except Exception as exc: + raise HTTPException(status_code=500, detail=str(exc)) from exc + + return TranscribeResponse( + text=text, + model_used=transcriber.model_size, + audio_filename=file.filename, + ) diff --git a/api/schemas/transcribe.py b/api/schemas/transcribe.py new file mode 100644 index 0000000..b6bd259 --- /dev/null +++ b/api/schemas/transcribe.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class TranscribeResponse(BaseModel): + text: str + model_used: str + audio_filename: str diff --git a/requirements.txt b/requirements.txt index eaa6c81..c1aa9d6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,7 @@ sqlmodel pytest httpx numpy<2 -ollama \ No newline at end of file +ollama +pyyaml +openai-whisper +python-multipart \ No newline at end of file diff --git a/src/file_manipulator.py b/src/file_manipulator.py index b7815cc..74788a3 100644 --- a/src/file_manipulator.py +++ b/src/file_manipulator.py @@ -1,7 +1,13 @@ +import logging import os + +from commonforms import prepare_form + from src.filler import Filler from src.llm import LLM -from commonforms import prepare_form +from src.template_mapper import TemplateMapper + +logger = logging.getLogger(__name__) class FileManipulator: @@ -9,39 +15,97 @@ def __init__(self): self.filler = Filler() self.llm = LLM() - def create_template(self, pdf_path: str): + def create_template(self, pdf_path: str) -> str: """ - By using commonforms, we create an editable .pdf template and we store it. + Prepare a fillable PDF template using commonforms and return its path. 
""" template_path = pdf_path[:-4] + "_template.pdf" prepare_form(pdf_path, template_path) return template_path - def fill_form(self, user_input: str, fields: list, pdf_form_path: str): + def fill_form( + self, + user_input: str, + fields: dict, + pdf_form_path: str, + yaml_path: str | None = None, + ) -> str: """ - It receives the raw data, runs the PDF filling logic, - and returns the path to the newly created file. + Extract data from user_input and fill pdf_form_path. + + When yaml_path is provided and the file exists, the new pipeline is used: + LLM → IncidentReport → TemplateMapper → Filler (named fields) + + When yaml_path is absent, falls back to the legacy pipeline: + LLM → raw dict → Filler (positional fields) + + Returns the path to the filled PDF. """ - print("[1] Received request from frontend.") - print(f"[2] PDF template path: {pdf_form_path}") + logger.info("Received fill request. PDF: %s YAML: %s", pdf_form_path, yaml_path) if not os.path.exists(pdf_form_path): - print(f"Error: PDF template not found at {pdf_form_path}") - return None # Or raise an exception + raise FileNotFoundError(f"PDF template not found at {pdf_form_path}") + + self.llm._transcript_text = user_input + + if yaml_path and os.path.exists(yaml_path): + return self._fill_with_mapper(yaml_path) + + logger.warning( + "No YAML template provided or found at %r — using legacy positional mapping.", + yaml_path, + ) + return self._fill_legacy(fields, pdf_form_path) + + # ------------------------------------------------------------------------- + # New pipeline: LLM → IncidentReport → TemplateMapper → Filler + # ------------------------------------------------------------------------- + + def _fill_with_mapper(self, yaml_path: str) -> str: + mapper = TemplateMapper(yaml_path) + + self.llm.main_loop() + report = self.llm.get_report() + + if report and report.requires_review: + logger.warning( + "Extraction incomplete — the following fields require manual review: %s", + 
report.requires_review, + ) + + field_values = mapper.resolve(report) + return self.filler.fill_form(pdf_form=mapper.pdf_path, field_values=field_values) + + # ------------------------------------------------------------------------- + # Legacy pipeline: kept for backward compatibility until all templates have + # YAML mappings (Phase 2, Week 5). + # ------------------------------------------------------------------------- + + def _fill_legacy(self, fields: dict, pdf_form_path: str) -> str: + self.llm._target_fields = fields + self.llm.main_loop() + data = self.llm.get_data() - print("[3] Starting extraction and PDF filling process...") - try: - self.llm._target_fields = fields - self.llm._transcript_text = user_input - output_name = self.filler.fill_form(pdf_form=pdf_form_path, llm=self.llm) + # Build a positional {field_name: value} dict from the PDF's own field names + # and the extracted values in visual order — brittle, but preserved until + # YAML templates cover all forms. + from pdfrw import PdfReader - print("\n----------------------------------") - print("✅ Process Complete.") - print(f"Output saved to: {output_name}") + pdf = PdfReader(pdf_form_path) + pdf_fields = [] + for page in pdf.pages: + if page.Annots: + sorted_annots = sorted( + page.Annots, key=lambda a: (-float(a.Rect[1]), float(a.Rect[0])) + ) + for annot in sorted_annots: + if annot.Subtype == "/Widget" and annot.T: + pdf_fields.append(self.filler._field_name(annot.T)) - return output_name + values = [v for v in data.values() if v is not None] + field_values = { + pdf_fields[i]: str(values[i]) + for i in range(min(len(pdf_fields), len(values))) + } - except Exception as e: - print(f"An error occurred during PDF generation: {e}") - # Re-raise the exception so the frontend can handle it - raise e + return self.filler.fill_form(pdf_form=pdf_form_path, field_values=field_values) diff --git a/src/filler.py b/src/filler.py index e31e535..623ddc8 100644 --- a/src/filler.py +++ b/src/filler.py @@ 
-1,16 +1,35 @@ -from pdfrw import PdfReader, PdfWriter -from src.llm import LLM +from __future__ import annotations + +import logging from datetime import datetime +from typing import Any + +from pdfrw import PdfReader, PdfWriter + +logger = logging.getLogger(__name__) class Filler: - def __init__(self): - pass + """ + Fills a PDF form using a named field mapping produced by TemplateMapper. + + Replaces the old positional approach (answers_list[i]) with an explicit + {pdf_field_name: value} dict so every value lands in the correct field + regardless of visual order or page layout. + """ - def fill_form(self, pdf_form: str, llm: LLM): + def fill_form(self, pdf_form: str, field_values: dict[str, Any]) -> str: """ - Fill a PDF form with values from user_input using LLM. - Fields are filled in the visual order (top-to-bottom, left-to-right). + Write field_values into the PDF at pdf_form and save to a timestamped path. + + Parameters + ---------- + pdf_form: Path to the fillable PDF template. + field_values: {pdf_field_name: value} produced by TemplateMapper.resolve(). + + Returns + ------- + Path to the newly written filled PDF. 
""" output_pdf = ( pdf_form[:-4] @@ -19,34 +38,37 @@ def fill_form(self, pdf_form: str, llm: LLM): + "_filled.pdf" ) - # Generate dictionary of answers from your original function - t2j = llm.main_loop() - textbox_answers = t2j.get_data() # This is a dictionary - - answers_list = list(textbox_answers.values()) - - # Read PDF pdf = PdfReader(pdf_form) + filled_count = 0 - # Loop through pages for page in pdf.pages: - if page.Annots: - sorted_annots = sorted( - page.Annots, key=lambda a: (-float(a.Rect[1]), float(a.Rect[0])) - ) - - i = 0 - for annot in sorted_annots: - if annot.Subtype == "/Widget" and annot.T: - if i < len(answers_list): - annot.V = f"{answers_list[i]}" - annot.AP = None - i += 1 - else: - # Stop if we run out of answers - break + if not page.Annots: + continue + for annot in page.Annots: + if annot.Subtype != "/Widget" or not annot.T: + continue - PdfWriter().write(output_pdf, pdf) + field_name = self._field_name(annot.T) + if field_name in field_values: + annot.V = str(field_values[field_name]) + annot.AP = None + filled_count += 1 + else: + logger.debug("PDF field %r has no mapped value — left blank", field_name) - # Your main.py expects this function to return the path + logger.info("Filled %d / %d mapped fields in %s", filled_count, len(field_values), pdf_form) + PdfWriter().write(output_pdf, pdf) return output_pdf + + # ------------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------------- + + @staticmethod + def _field_name(annot_t) -> str: + """ + Extract the plain field name string from a pdfrw PdfString. + pdfrw wraps PDF literal strings in parentheses, e.g. '(FieldName)'. 
+ """ + raw = str(annot_t) + return raw[1:-1] if raw.startswith("(") and raw.endswith(")") else raw diff --git a/src/llm.py b/src/llm.py index 70937f9..d9ff3ce 100644 --- a/src/llm.py +++ b/src/llm.py @@ -1,135 +1,113 @@ import json +import logging import os -import requests +import requests -class LLM: - def __init__(self, transcript_text=None, target_fields=None, json=None): - if json is None: - json = {} - self._transcript_text = transcript_text # str - self._target_fields = target_fields # List, contains the template field. - self._json = json # dictionary - - def type_check_all(self): - if type(self._transcript_text) is not str: - raise TypeError( - f"ERROR in LLM() attributes ->\ - Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}" - ) - elif type(self._target_fields) is not list: - raise TypeError( - f"ERROR in LLM() attributes ->\ - Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}" - ) +from src.schemas.incident_report import IncidentReport - def build_prompt(self, current_field): - """ - This method is in charge of the prompt engineering. It creates a specific prompt for each target field. - @params: current_field -> represents the current element of the json that is being prompted. - """ - prompt = f""" - SYSTEM PROMPT: - You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings. - You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return - only a single string containing the identified value for the JSON field. - If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";". - If you don't identify the value in the provided text, return "-1". 
- --- - DATA: - Target JSON field to find in text: {current_field} - - TEXT: {self._transcript_text} - """ - - return prompt - - def main_loop(self): - # self.type_check_all() - for field in self._target_fields.keys(): - prompt = self.build_prompt(field) - # print(prompt) - # ollama_url = "http://localhost:11434/api/generate" - ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") - ollama_url = f"{ollama_host}/api/generate" - - payload = { - "model": "mistral", - "prompt": prompt, - "stream": False, # don't really know why --> look into this later. - } - - try: - response = requests.post(ollama_url, json=payload) - response.raise_for_status() - except requests.exceptions.ConnectionError: - raise ConnectionError( - f"Could not connect to Ollama at {ollama_url}. " - "Please ensure Ollama is running and accessible." - ) - except requests.exceptions.HTTPError as e: - raise RuntimeError(f"Ollama returned an error: {e}") - - # parse response - json_data = response.json() - parsed_response = json_data["response"] - # print(parsed_response) - self.add_response_to_json(field, parsed_response) - - print("----------------------------------") - print("\t[LOG] Resulting JSON created from the input text:") - print(json.dumps(self._json, indent=2)) - print("--------- extracted data ---------") +logger = logging.getLogger(__name__) - return self - def add_response_to_json(self, field, value): - """ - this method adds the following value under the specified field, - or under a new field if the field doesn't exist, to the json dict +class LLM: + def __init__(self, transcript_text=None, target_fields=None): + self._transcript_text = transcript_text + # target_fields kept for backward compatibility with FileManipulator; + # the new extraction uses IncidentReport.llm_schema_hint() instead. 
+ self._target_fields = target_fields + self._result: IncidentReport | None = None + + def build_prompt(self) -> str: """ - value = value.strip().replace('"', "") - parsed_value = None + Build a single structured prompt for Ollama's /api/generate endpoint. - if value != "-1": - parsed_value = value + Uses IncidentReport.llm_schema_hint() as the schema constraint so the + model always returns a JSON object with the correct field names. + Mistral instruction format ([INST] … [/INST]) is used so the system + context is respected. + """ + schema_hint = json.dumps(IncidentReport.llm_schema_hint(), indent=2) + + system = ( + "You are an AI assistant that extracts structured incident report data " + "from transcribed voice recordings made by first responders.\n\n" + "Return ONLY a valid JSON object that matches the schema below. " + "Do not include any explanation, markdown code fences, or extra text.\n\n" + f"Schema:\n{schema_hint}\n\n" + "Rules:\n" + " - Set a field to null if the information is not present in the transcript.\n" + " - For list fields (e.g. unit_ids, personnel), return a JSON array.\n" + " - For timestamps, reproduce the exact phrasing from the transcript " + "(e.g. '14:35', '1435 hours', '2:35 PM').\n" + " - Do not invent or infer values that are not explicitly stated." + ) - if ";" in value: - parsed_value = self.handle_plural_values(value) + user = ( + "Extract all incident report fields from the following transcript:\n\n" + f"{self._transcript_text}" + ) - if field in self._json.keys(): - self._json[field].append(parsed_value) - else: - self._json[field] = parsed_value + return f"[INST] {system}\n\n{user} [/INST]" - return + # Ollama request + def _ollama_url(self) -> str: + host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/") + return f"{host}/api/generate" - def handle_plural_values(self, plural_value): + def main_loop(self) -> "LLM": """ - This method handles plural values. 
- Takes in strings of the form 'value1; value2; value3; ...; valueN' - returns a list with the respective values -> [value1, value2, value3, ..., valueN] + Send a single structured request to Ollama and parse the response as an + IncidentReport. Replaces the old per-field loop. + + On success, self._result holds the validated IncidentReport. + Fields that could not be extracted are listed in result.requires_review. """ - if ";" not in plural_value: - raise ValueError( - f"Value is not plural, doesn't have ; separator, Value: {plural_value}" + prompt = self.build_prompt() + model = os.getenv("OLLAMA_MODEL", "mistral") + + payload = { + "model": model, + "prompt": prompt, + "format": "json", # constrains Ollama output to valid JSON + "stream": False, + } + + try: + response = requests.post(self._ollama_url(), json=payload) + response.raise_for_status() + except requests.exceptions.ConnectionError: + raise ConnectionError( + f"Could not connect to Ollama at {self._ollama_url()}. " + "Please ensure Ollama is running and accessible." ) + except requests.exceptions.HTTPError as e: + raise RuntimeError(f"Ollama returned an error: {e}") from e - print( - f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..." - ) - values = plural_value.split(";") + raw = response.json()["response"] - # Remove trailing leading whitespace - for i in range(len(values)): - current = i + 1 - if current < len(values): - clean_value = values[current].lstrip() - values[current] = clean_value + try: + extracted = json.loads(raw) + except json.JSONDecodeError as e: + logger.warning("Ollama returned invalid JSON: %s. Raw (first 200 chars): %.200s", e, raw) + extracted = {} - print(f"\t[LOG]: Resulting formatted list of values: {values}") + # Pydantic validates and auto-populates requires_review for missing fields + self._result = IncidentReport(**extracted) - return values + logger.info("Extraction complete. 
requires_review: %s", self._result.requires_review) + + return self + + # Data accessors + def get_data(self) -> dict: + """ + Return extracted data as a plain dict for use by Filler. + requires_review is excluded here — use get_report() to access it. + """ + if self._result is None: + return {} + return self._result.model_dump(exclude={"requires_review"}) - def get_data(self): - return self._json + def get_report(self) -> IncidentReport | None: + """Return the full IncidentReport including requires_review.""" + return self._result diff --git a/src/schemas/__init__.py b/src/schemas/__init__.py new file mode 100644 index 0000000..165dd7a --- /dev/null +++ b/src/schemas/__init__.py @@ -0,0 +1,3 @@ +from src.schemas.incident_report import IncidentReport + +__all__ = ["IncidentReport"] diff --git a/src/schemas/incident_report.py b/src/schemas/incident_report.py new file mode 100644 index 0000000..6cec1c6 --- /dev/null +++ b/src/schemas/incident_report.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +from pydantic import BaseModel, Field +from pydantic import model_validator +from typing import Optional + + +class IncidentReport(BaseModel): + """ + Canonical schema for a FireForm incident report. + + All fields are Optional. Fields that could not be reliably extracted by the + LLM are recorded in `requires_review` so responders can spot what needs + manual completion before the PDF is submitted. + + This schema serves three roles: + 1. LLM target — the JSON structure the extraction prompt asks for. + 2. Validator — Pydantic validates the LLM response against this model. + 3. Mapper source — TemplateMapper resolves json_path values from this model + to populate PDF form fields. + + Note: `requires_review` is populated by the extraction pipeline, never by + the LLM itself. Exclude it when building the schema hint sent to the model. 
+ """ + + # ------------------------------------------------------------------------- + # Identity + # ------------------------------------------------------------------------- + + incident_id: Optional[str] = Field( + default=None, + description="Unique identifier for this incident assigned by the agency.", + examples=["CAL-2024-001", "INC-20240101-005"], + ) + agency_id: Optional[str] = Field( + default=None, + description="Name or identifier of the responding agency.", + examples=["CAL FIRE", "SFFD", "LAFD"], + ) + unit_ids: Optional[list[str]] = Field( + default=None, + description=( + "List of unit or apparatus identifiers responding to the incident. " + "If multiple units are named, list all of them." + ), + examples=[["Engine 12", "Truck 4"], ["Unit 7", "Medic 3"]], + ) + + # ------------------------------------------------------------------------- + # Incident classification + # ------------------------------------------------------------------------- + + incident_type: Optional[str] = Field( + default=None, + description=( + "Type or category of the incident. Use a short lowercase label. " + "Common values: wildfire, structure_fire, vehicle_accident, EMS, hazmat, rescue." 
+ ), + examples=["wildfire", "structure_fire", "EMS", "vehicle_accident", "hazmat"], + ) + incident_severity: Optional[str] = Field( + default=None, + description="Severity or alarm level of the incident.", + examples=["low", "moderate", "high", "major"], + ) + + # Location + + location_address: Optional[str] = Field( + default=None, + description="Street address or nearest landmark address of the incident.", + examples=["1234 Oak Street", "Highway 1 near Mile Marker 42"], + ) + location_city: Optional[str] = Field( + default=None, + description="City where the incident occurred.", + examples=["San Francisco", "Half Moon Bay"], + ) + location_county: Optional[str] = Field( + default=None, + description="County where the incident occurred.", + examples=["San Mateo County", "Los Angeles County"], + ) + location_state: Optional[str] = Field( + default=None, + description="State abbreviation or full name where the incident occurred.", + examples=["CA", "California"], + ) + location_coordinates: Optional[str] = Field( + default=None, + description="GPS coordinates in decimal degrees format: 'latitude, longitude'.", + examples=["37.7749, -122.4194"], + ) + + # Timestamps (stored as strings — the LLM extracts them as spoken/written) + alarm_time: Optional[str] = Field( + default=None, + description=( + "Time the alarm was received or the incident was reported. " + "Use HH:MM (24-hour) if possible, or reproduce the exact phrasing." 
+ ), + examples=["14:35", "2:35 PM", "1435 hours"], + ) + dispatch_time: Optional[str] = Field( + default=None, + description="Time units were dispatched to the incident.", + examples=["14:37", "1437 hours"], + ) + arrival_time: Optional[str] = Field( + default=None, + description="Time the first unit arrived on scene.", + examples=["14:45", "1445 hours"], + ) + controlled_time: Optional[str] = Field( + default=None, + description="Time the incident was declared under control.", + examples=["17:20", "1720 hours"], + ) + clear_time: Optional[str] = Field( + default=None, + description="Time all units cleared the scene.", + examples=["18:05", "1805 hours"], + ) + incident_date: Optional[str] = Field( + default=None, + description="Date the incident occurred.", + examples=["2024-01-01", "January 1, 2024", "01/01/2024"], + ) + + # Personnel + supervisor: Optional[str] = Field( + default=None, + description=( + "Name or identifier of the incident commander or supervising officer " + "in charge at the scene." + ), + examples=["Battalion Chief Johnson", "Captain Maria Torres"], + ) + personnel: Optional[list[str]] = Field( + default=None, + description=( + "Names or identifiers of all personnel assigned to the incident. " + "Include rank or role if mentioned." 
+ ), + examples=[["FF Smith", "FF Jones", "Paramedic Lee"]], + ) + personnel_count: Optional[int] = Field( + default=None, + description="Total number of personnel responding to the incident.", + examples=[6, 12], + ) + + # Casualties & medical (EMS / law enforcement forms) + casualties: Optional[int] = Field( + default=None, + description="Total number of injured persons (civilians and/or responders).", + examples=[2, 0], + ) + fatalities: Optional[int] = Field( + default=None, + description="Total number of fatalities resulting from the incident.", + examples=[0, 1], + ) + patients_transported: Optional[int] = Field( + default=None, + description="Number of patients transported to a medical facility.", + examples=[1, 3], + ) + hospital_destination: Optional[str] = Field( + default=None, + description="Name or identifier of the hospital patients were transported to.", + examples=["SF General Hospital", "UCLA Medical Center"], + ) + patient_condition: Optional[str] = Field( + default=None, + description="Reported condition of the patient(s) at time of transport or care.", + examples=["stable", "critical", "deceased"], + ) + + # Narrative + narrative: Optional[str] = Field( + default=None, + description=( + "Free-text description of the incident: what happened, actions taken " + "by responders, and the outcome. Reproduce faithfully from the transcript." + ), + examples=[ + "Crews arrived to find a one-story structure with heavy smoke showing " + "from the roof. A primary search revealed no occupants. Fire was " + "knocked down within 20 minutes. Cause determined to be electrical." + ], + ) + + # Wildfire-specific (Cal Fire FIRESCOPE) + area_burned_acres: Optional[float] = Field( + default=None, + description="Estimated area burned in acres. 
Wildfire incidents only.", + examples=[12.5, 0.25], + ) + structures_threatened: Optional[int] = Field( + default=None, + description="Number of structures threatened by the incident.", + examples=[15, 0], + ) + structures_destroyed: Optional[int] = Field( + default=None, + description="Number of structures destroyed.", + examples=[3, 0], + ) + containment_percent: Optional[int] = Field( + default=None, + ge=0, + le=100, + description="Percentage of wildfire perimeter that is contained (0–100).", + examples=[25, 100], + ) + fuel_type: Optional[str] = Field( + default=None, + description="Predominant fuel type involved in the wildfire.", + examples=["grass", "chaparral", "timber", "brush"], + ) + + # Law enforcement (incident/crime report forms) + case_number: Optional[str] = Field( + default=None, + description="Law enforcement case or report number.", + examples=["SFPD-2024-00123"], + ) + officer_id: Optional[str] = Field( + default=None, + description="Badge number or identifier of the reporting officer.", + examples=["Badge #4821"], + ) + suspect_description: Optional[str] = Field( + default=None, + description="Physical description or identifying information of any suspect.", + examples=["Male, approximately 30 years old, 6ft, wearing a red jacket"], + ) + + # Pipeline metadata — populated by the extraction pipeline, NOT the LLM + requires_review: list[str] = Field( + default_factory=list, + description=( + "Field names that could not be reliably extracted and require human " + "review before the PDF is submitted. Populated by the retry loop, " + "never by the LLM." + ), + ) + + @model_validator(mode="after") + def collect_missing_fields(self) -> IncidentReport: + """ + After model construction, populate `requires_review` with the names of + any core fields that are still None. This runs on every instantiation, + so it also catches fields that were validly set to None after a retry + loop exhausts all attempts. 
+ + Only core operational fields are checked — metadata-only or + incident-type-specific fields (wildfire, law enforcement) are skipped + here; the TemplateMapper handles conditional field requirements. + """ + CORE_FIELDS = { + "incident_type", + "location_address", + "location_city", + "alarm_time", + "incident_date", + "supervisor", + "narrative", + } + missing = [ + field + for field in CORE_FIELDS + if getattr(self, field) is None + ] + # Only overwrite if the caller hasn't already populated requires_review + if not self.requires_review: + self.requires_review = missing + return self + + @classmethod + def llm_schema_hint(cls) -> dict: + """ + Return a JSON-serialisable dict describing each field that the LLM + should extract. Excludes `requires_review` (pipeline-only) and + incident-type-specific fields that may not apply to every call. + + Used to build the structured system prompt in LLM.build_prompt(). + """ + schema = cls.model_json_schema() + excluded = {"requires_review"} + properties = { + k: v + for k, v in schema.get("properties", {}).items() + if k not in excluded + } + return { + "type": "object", + "properties": properties, + "required": [], # all fields are optional — missing → requires_review + } diff --git a/src/template_mapper.py b/src/template_mapper.py new file mode 100644 index 0000000..9c5d7c3 --- /dev/null +++ b/src/template_mapper.py @@ -0,0 +1,219 @@ +""" +TemplateMapper — YAML-driven PDF field mapping engine. + +YAML Mapping Spec +----------------- +Each agency form is described by a YAML file placed in the `templates/` directory. +The format is: + + name: "Cal Fire FIRESCOPE" + pdf_path: "src/inputs/cal_fire_firescope_template.pdf" + fields: + - pdf_field_name: "IncidentID" + json_path: "incident_id" + + - pdf_field_name: "WildfireAcres" + json_path: "area_burned_acres" + condition: "incident_type == 'wildfire'" + +Field keys: + pdf_field_name (required) — exact field name as it appears in the PDF. 
+ json_path (required) — dot-separated path into the IncidentReport model. + Supports top-level keys only for now (e.g. "supervisor"). + condition (optional) — boolean expression evaluated against the IncidentReport + dict. Field is skipped (left blank) when False. + +Supported condition operators: == != in not in and or not +No arbitrary Python is permitted — only Name, Constant, Compare, BoolOp, UnaryOp nodes +are allowed by the evaluator. Any other node raises ValueError. +""" + +from __future__ import annotations + +import ast +import logging +from pathlib import Path +from typing import Any + +import yaml + +from src.schemas.incident_report import IncidentReport + +logger = logging.getLogger(__name__) + + +class TemplateMapper: + """ + Loads a YAML agency mapping file and resolves IncidentReport field values + to the corresponding PDF form field names. + + Usage: + mapper = TemplateMapper("templates/cal_fire_firescope.yaml") + field_values = mapper.resolve(incident_report) + # field_values == {"IncidentID": "CAL-2024-001", "Location": "...", ...} + """ + + def __init__(self, yaml_path: str | Path) -> None: + self._yaml_path = Path(yaml_path) + self._config = self._load(self._yaml_path) + + # ------------------------------------------------------------------------- + # Public interface + # ------------------------------------------------------------------------- + + @property + def name(self) -> str: + return self._config["name"] + + @property + def pdf_path(self) -> str: + return self._config["pdf_path"] + + def resolve(self, report: IncidentReport) -> dict[str, Any]: + """ + Walk the YAML field list and resolve each entry against the report. + + Returns a dict of {pdf_field_name: value} for all fields whose + condition (if any) evaluates to True. Fields with a null/None value + are mapped to an empty string so the PDF writer does not raise. 
+ """ + data = report.model_dump(exclude={"requires_review"}) + result: dict[str, Any] = {} + + for field_def in self._config.get("fields", []): + pdf_field = field_def["pdf_field_name"] + json_path = field_def["json_path"] + condition = field_def.get("condition") + + if condition: + try: + if not self._evaluate_condition(condition, data): + logger.debug("Skipping field %r — condition %r is False", pdf_field, condition) + continue + except ValueError as exc: + logger.warning("Bad condition on field %r: %s — skipping", pdf_field, exc) + continue + + value = self._resolve_path(json_path, data) + result[pdf_field] = self._to_string(value) + + return result + + # ------------------------------------------------------------------------- + # YAML loading + # ------------------------------------------------------------------------- + + def _load(self, path: Path) -> dict: + if not path.exists(): + raise FileNotFoundError(f"Template mapping not found: {path}") + with path.open("r", encoding="utf-8") as fh: + config = yaml.safe_load(fh) + self._validate_config(config, path) + return config + + def _validate_config(self, config: dict, path: Path) -> None: + for required in ("name", "pdf_path", "fields"): + if required not in config: + raise ValueError(f"YAML template {path} is missing required key: '{required}'") + for i, field in enumerate(config["fields"]): + for key in ("pdf_field_name", "json_path"): + if key not in field: + raise ValueError( + f"YAML template {path}, field[{i}] is missing required key: '{key}'" + ) + + # ------------------------------------------------------------------------- + # JSON path resolution + # ------------------------------------------------------------------------- + + def _resolve_path(self, json_path: str, data: dict) -> Any: + """ + Resolve a dot-separated path against the data dict. + e.g. "location_city" → data["location_city"] + Supports top-level keys only for now; returns None for missing paths. 
+ """ + keys = json_path.split(".") + current = data + for key in keys: + if not isinstance(current, dict): + return None + current = current.get(key) + if current is None: + return None + return current + + @staticmethod + def _to_string(value: Any) -> str: + if value is None: + return "" + if isinstance(value, list): + return ", ".join(str(v) for v in value) + return str(value) + + # ------------------------------------------------------------------------- + # Safe condition evaluator (P2-4) + # ------------------------------------------------------------------------- + + def _evaluate_condition(self, condition: str, context: dict) -> bool: + """ + Safely evaluate a condition expression string against context. + + Only the following AST node types are permitted: + Compare, BoolOp (and/or), UnaryOp (not), Name, Constant. + + Any other node — including function calls, attribute access, or + subscripts — raises ValueError, preventing arbitrary code execution. + """ + try: + tree = ast.parse(condition, mode="eval") + except SyntaxError as exc: + raise ValueError(f"Invalid condition syntax: {condition!r}") from exc + + return self._eval_node(tree.body, context) + + def _eval_node(self, node: ast.expr, context: dict) -> Any: + if isinstance(node, ast.Compare): + left = self._eval_node(node.left, context) + for op, comparator in zip(node.ops, node.comparators): + right = self._eval_node(comparator, context) + if isinstance(op, ast.Eq): + result = left == right + elif isinstance(op, ast.NotEq): + result = left != right + elif isinstance(op, ast.In): + result = left in right + elif isinstance(op, ast.NotIn): + result = left not in right + elif isinstance(op, ast.Lt): + result = left < right + elif isinstance(op, ast.LtE): + result = left <= right + elif isinstance(op, ast.Gt): + result = left > right + elif isinstance(op, ast.GtE): + result = left >= right + else: + raise ValueError(f"Unsupported operator: {type(op).__name__}") + if not result: + return False + return 
True + + if isinstance(node, ast.BoolOp): + if isinstance(node.op, ast.And): + return all(self._eval_node(v, context) for v in node.values) + if isinstance(node.op, ast.Or): + return any(self._eval_node(v, context) for v in node.values) + + if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.Not): + return not self._eval_node(node.operand, context) + + if isinstance(node, ast.Name): + return context.get(node.id) + + if isinstance(node, ast.Constant): + return node.value + + raise ValueError( + f"Expression node type {type(node).__name__!r} is not permitted in conditions. " + "Only comparisons, boolean operators, field names, and literal values are allowed." + ) diff --git a/src/transcriber.py b/src/transcriber.py new file mode 100644 index 0000000..4e63008 --- /dev/null +++ b/src/transcriber.py @@ -0,0 +1,124 @@ +""" +Transcriber — offline voice-to-text using OpenAI Whisper. + +Model sizes (set via WHISPER_MODEL env var): + tiny — fastest, lowest accuracy (~39 MB) + base — default, good balance (~74 MB) + small — better accuracy (~244 MB) + medium — high accuracy (~769 MB) + large — best accuracy (~1550 MB) + +FireForm is privacy-first: all transcription runs locally. +No audio data leaves the machine. +""" + +from __future__ import annotations + +import logging +import os +import tempfile +from pathlib import Path + +logger = logging.getLogger(__name__) + +SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a", ".mp4", ".ogg", ".flac"} +_VALID_MODEL_SIZES = {"tiny", "base", "small", "medium", "large"} +DEFAULT_MODEL = "base" + + +class Transcriber: + """ + Wraps OpenAI Whisper for local, offline audio transcription. + + The Whisper model is lazy-loaded on the first call to transcribe() + so startup time is not penalised when the transcription feature + is not used. 
logger = logging.getLogger(__name__)

SUPPORTED_FORMATS = {".wav", ".mp3", ".m4a", ".mp4", ".ogg", ".flac"}
_VALID_MODEL_SIZES = {"tiny", "base", "small", "medium", "large"}
DEFAULT_MODEL = "base"


class Transcriber:
    """
    Local, offline audio transcription backed by OpenAI Whisper.

    The Whisper model is loaded lazily on the first transcription call, so
    merely constructing a Transcriber never pays the model-load cost.

    Usage:
        t = Transcriber()                     # WHISPER_MODEL env var or "base"
        text = t.transcribe("recording.wav")

        t = Transcriber(model_size="small")   # explicit size
        text = t.transcribe_bytes(audio_bytes, suffix=".mp3")
    """

    def __init__(self, model_size: str | None = None) -> None:
        chosen = model_size if model_size else os.getenv("WHISPER_MODEL", DEFAULT_MODEL)
        if chosen not in _VALID_MODEL_SIZES:
            raise ValueError(
                f"Invalid model size {chosen!r}. "
                f"Choose from: {sorted(_VALID_MODEL_SIZES)}"
            )
        self.model_size = chosen
        # The actual whisper model object; populated on first use.
        self._model = None

    # -------------------------------------------------------------------------
    # Public API
    # -------------------------------------------------------------------------

    def transcribe(self, audio_path: str | Path) -> str:
        """
        Transcribe the audio file at *audio_path* and return the plain text.

        Raises
        ------
        FileNotFoundError : audio file does not exist.
        ValueError : file format not supported.
        """
        source = Path(audio_path)

        if not source.exists():
            raise FileNotFoundError(f"Audio file not found: {source}")

        if source.suffix.lower() not in SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported audio format: {source.suffix!r}. "
                f"Supported formats: {sorted(SUPPORTED_FORMATS)}"
            )

        model = self._load_model()
        logger.info("Transcribing %s with Whisper/%s", source.name, self.model_size)

        transcript = model.transcribe(str(source))["text"].strip()

        logger.info("Transcription complete — %d chars extracted", len(transcript))
        return transcript

    def transcribe_bytes(self, audio_bytes: bytes, suffix: str = ".wav") -> str:
        """
        Transcribe raw audio bytes by spooling them to a temporary file.
        Used by the /transcribe endpoint which receives multipart uploads.

        The temp file is always removed after transcription, success or failure.
        """
        normalized = suffix if suffix.startswith(".") else f".{suffix}"

        with tempfile.NamedTemporaryFile(suffix=normalized, delete=False) as spool:
            spool.write(audio_bytes)
            temp_path = Path(spool.name)

        try:
            return self.transcribe(temp_path)
        finally:
            temp_path.unlink(missing_ok=True)

    # -------------------------------------------------------------------------
    # Internal
    # -------------------------------------------------------------------------

    def _load_model(self):
        """Import whisper and load the configured model on first use."""
        if self._model is not None:
            return self._model

        try:
            import whisper
        except ImportError as exc:
            raise ImportError(
                "openai-whisper is not installed. "
                "Run: pip install openai-whisper"
            ) from exc

        logger.info("Loading Whisper model '%s' (first-time load)…", self.model_size)
        self._model = whisper.load_model(self.model_size)
        logger.info("Whisper model '%s' ready.", self.model_size)

        return self._model
def test_submit_form(client):
    """End-to-end form fill: create a template, then submit with the LLM mocked."""
    from unittest.mock import MagicMock, patch

    # Step 1: create a template so the fill request has a valid template_id.
    template_payload = {
        "name": "Test Form",
        "pdf_path": "src/inputs/file.pdf",
        "fields": {"name": "string", "date": "string"},
    }
    created = client.post("/templates/create", json=template_payload)
    assert created.status_code == 200
    template_id = created.json()["id"]

    # Step 2: submit a fill request — Controller is mocked so no Ollama call happens.
    fake_output_path = "src/outputs/file_20240101_120000_filled.pdf"
    form_payload = {
        "template_id": template_id,
        "input_text": "Employee name is John Doe. Date is 01/01/2024.",
    }

    with patch("api.routes.forms.Controller") as controller_cls:
        controller_cls.return_value = MagicMock(
            fill_form=MagicMock(return_value=fake_output_path)
        )
        response = client.post("/forms/fill", json=form_payload)

    assert response.status_code == 200
    body = response.json()
    assert body["id"] is not None
    assert body["template_id"] == template_id
    assert body["input_text"] == form_payload["input_text"]
    assert body["output_pdf_path"] == fake_output_path
# ---------------------------------------------------------------------------
# Unit tests — Transcriber class
# ---------------------------------------------------------------------------


def _make_mock_model(text: str = FAKE_TRANSCRIPT):
    """Build a fake Whisper model whose transcribe() returns padded text."""
    mock = MagicMock()
    # Whisper pads transcripts with whitespace; Transcriber must strip it.
    mock.transcribe.return_value = {"text": f" {text} "}
    return mock


def test_transcriber_default_model_size():
    import os

    # Neutralise any WHISPER_MODEL set in the developer's environment so the
    # documented default ("base") is what this test actually exercises.
    # patch.dict restores the environment on exit.
    with patch.dict(os.environ):
        os.environ.pop("WHISPER_MODEL", None)
        assert Transcriber().model_size == "base"


def test_transcriber_custom_model_size():
    assert Transcriber(model_size="small").model_size == "small"


def test_transcriber_invalid_model_size():
    with pytest.raises(ValueError, match="Invalid model size"):
        Transcriber(model_size="giant")


def test_transcribe_strips_whitespace(tmp_path):
    audio = tmp_path / "incident.wav"
    audio.write_bytes(b"fake audio bytes")

    with patch.object(Transcriber, "_load_model", return_value=_make_mock_model()):
        result = Transcriber().transcribe(audio)

    assert result == FAKE_TRANSCRIPT  # leading/trailing whitespace stripped


def test_transcribe_file_not_found():
    with pytest.raises(FileNotFoundError):
        Transcriber().transcribe("nonexistent_audio.wav")


def test_transcribe_unsupported_format(tmp_path):
    bad_file = tmp_path / "report.txt"
    bad_file.write_bytes(b"not audio")

    with pytest.raises(ValueError, match="Unsupported audio format"):
        Transcriber().transcribe(bad_file)


def test_transcribe_bytes_cleans_up_temp_file():
    import os

    created_paths = []

    def capturing_transcribe(self, path):
        # Record the temp path handed to transcribe() so we can assert that
        # transcribe_bytes() removed it afterwards.
        created_paths.append(str(path))
        return FAKE_TRANSCRIPT

    # transcribe() is replaced wholesale, so no model mock is needed here.
    with patch.object(Transcriber, "transcribe", capturing_transcribe):
        result = Transcriber().transcribe_bytes(b"fake audio bytes", suffix=".wav")

    assert result == FAKE_TRANSCRIPT
    assert created_paths, "transcribe_bytes never invoked transcribe()"
    # Temp file must have been deleted, success or failure.
    for p in created_paths:
        assert not os.path.exists(p), f"Temp file was not cleaned up: {p}"


def test_supported_formats_coverage():
    assert SUPPORTED_FORMATS == {".wav", ".mp3", ".m4a", ".mp4", ".ogg", ".flac"}


# ---------------------------------------------------------------------------
# Endpoint tests — POST /transcribe
# ---------------------------------------------------------------------------


def test_transcribe_endpoint_success(client):
    with patch("api.routes.transcribe._get_transcriber") as mock_getter:
        mock_transcriber = MagicMock()
        mock_transcriber.transcribe_bytes.return_value = FAKE_TRANSCRIPT
        mock_transcriber.model_size = "base"
        mock_getter.return_value = mock_transcriber

        response = client.post(
            "/transcribe",
            files={"file": ("incident.wav", io.BytesIO(b"fake audio"), "audio/wav")},
        )

    assert response.status_code == 200
    data = response.json()
    assert data["text"] == FAKE_TRANSCRIPT
    assert data["audio_filename"] == "incident.wav"
    assert data["model_used"] == "base"


def test_transcribe_endpoint_unsupported_format(client):
    # Format rejection happens before the transcriber is touched — no mock needed.
    response = client.post(
        "/transcribe",
        files={"file": ("report.txt", io.BytesIO(b"text"), "text/plain")},
    )
    assert response.status_code == 415
    assert "Unsupported audio format" in response.json()["detail"]


def test_transcribe_endpoint_transcriber_error(client):
    with patch("api.routes.transcribe._get_transcriber") as mock_getter:
        mock_transcriber = MagicMock()
        mock_transcriber.transcribe_bytes.side_effect = RuntimeError("Whisper failed")
        mock_getter.return_value = mock_transcriber

        response = client.post(
            "/transcribe",
            files={"file": ("incident.mp3", io.BytesIO(b"audio"), "audio/mpeg")},
        )

    assert response.status_code == 500
    assert "Whisper failed" in response.json()["detail"]


# Parametrize directly from SUPPORTED_FORMATS (sorted for stable test ids)
# so this test cannot drift out of sync with the source of truth.
@pytest.mark.parametrize("fmt", sorted(SUPPORTED_FORMATS))
def test_transcribe_endpoint_accepts_all_supported_formats(client, fmt):
    with patch("api.routes.transcribe._get_transcriber") as mock_getter:
        mock_transcriber = MagicMock()
        mock_transcriber.transcribe_bytes.return_value = FAKE_TRANSCRIPT
        mock_transcriber.model_size = "base"
        mock_getter.return_value = mock_transcriber

        response = client.post(
            "/transcribe",
            files={"file": (f"audio{fmt}", io.BytesIO(b"audio"), "audio/octet-stream")},
        )

    assert response.status_code == 200