Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 92 additions & 114 deletions src/llm.py
Original file line number Diff line number Diff line change
@@ -1,135 +1,113 @@
import json
import logging
import os
import requests

import requests

class LLM:
def __init__(self, transcript_text=None, target_fields=None, json=None):
if json is None:
json = {}
self._transcript_text = transcript_text # str
self._target_fields = target_fields # List, contains the template field.
self._json = json # dictionary

def type_check_all(self):
if type(self._transcript_text) is not str:
raise TypeError(
f"ERROR in LLM() attributes ->\
Transcript must be text. Input:\n\ttranscript_text: {self._transcript_text}"
)
elif type(self._target_fields) is not list:
raise TypeError(
f"ERROR in LLM() attributes ->\
Target fields must be a list. Input:\n\ttarget_fields: {self._target_fields}"
)
from src.schemas.incident_report import IncidentReport

def build_prompt(self, current_field):
"""
This method is in charge of the prompt engineering. It creates a specific prompt for each target field.
@params: current_field -> represents the current element of the json that is being prompted.
"""
prompt = f"""
SYSTEM PROMPT:
You are an AI assistant designed to help fillout json files with information extracted from transcribed voice recordings.
You will receive the transcription, and the name of the JSON field whose value you have to identify in the context. Return
only a single string containing the identified value for the JSON field.
If the field name is plural, and you identify more than one possible value in the text, return both separated by a ";".
If you don't identify the value in the provided text, return "-1".
---
DATA:
Target JSON field to find in text: {current_field}

TEXT: {self._transcript_text}
"""

return prompt

def main_loop(self):
# self.type_check_all()
for field in self._target_fields.keys():
prompt = self.build_prompt(field)
# print(prompt)
# ollama_url = "http://localhost:11434/api/generate"
ollama_host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
ollama_url = f"{ollama_host}/api/generate"

payload = {
"model": "mistral",
"prompt": prompt,
"stream": False, # don't really know why --> look into this later.
}

try:
response = requests.post(ollama_url, json=payload)
response.raise_for_status()
except requests.exceptions.ConnectionError:
raise ConnectionError(
f"Could not connect to Ollama at {ollama_url}. "
"Please ensure Ollama is running and accessible."
)
except requests.exceptions.HTTPError as e:
raise RuntimeError(f"Ollama returned an error: {e}")

# parse response
json_data = response.json()
parsed_response = json_data["response"]
# print(parsed_response)
self.add_response_to_json(field, parsed_response)

print("----------------------------------")
print("\t[LOG] Resulting JSON created from the input text:")
print(json.dumps(self._json, indent=2))
print("--------- extracted data ---------")
logger = logging.getLogger(__name__)

return self

def add_response_to_json(self, field, value):
"""
this method adds the following value under the specified field,
or under a new field if the field doesn't exist, to the json dict
class LLM:
def __init__(self, transcript_text=None, target_fields=None):
self._transcript_text = transcript_text
# target_fields kept for backward compatibility with FileManipulator;
# the new extraction uses IncidentReport.llm_schema_hint() instead.
self._target_fields = target_fields
self._result: IncidentReport | None = None

def build_prompt(self) -> str:
"""
value = value.strip().replace('"', "")
parsed_value = None
Build a single structured prompt for Ollama's /api/generate endpoint.

if value != "-1":
parsed_value = value
Uses IncidentReport.llm_schema_hint() as the schema constraint so the
model always returns a JSON object with the correct field names.
Mistral instruction format ([INST] … [/INST]) is used so the system
context is respected.
"""
schema_hint = json.dumps(IncidentReport.llm_schema_hint(), indent=2)

system = (
"You are an AI assistant that extracts structured incident report data "
"from transcribed voice recordings made by first responders.\n\n"
"Return ONLY a valid JSON object that matches the schema below. "
"Do not include any explanation, markdown code fences, or extra text.\n\n"
f"Schema:\n{schema_hint}\n\n"
"Rules:\n"
" - Set a field to null if the information is not present in the transcript.\n"
" - For list fields (e.g. unit_ids, personnel), return a JSON array.\n"
" - For timestamps, reproduce the exact phrasing from the transcript "
"(e.g. '14:35', '1435 hours', '2:35 PM').\n"
" - Do not invent or infer values that are not explicitly stated."
)

if ";" in value:
parsed_value = self.handle_plural_values(value)
user = (
"Extract all incident report fields from the following transcript:\n\n"
f"{self._transcript_text}"
)

if field in self._json.keys():
self._json[field].append(parsed_value)
else:
self._json[field] = parsed_value
return f"[INST] {system}\n\n{user} [/INST]"

return
# Ollama request
def _ollama_url(self) -> str:
host = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
return f"{host}/api/generate"

def handle_plural_values(self, plural_value):
def main_loop(self) -> "LLM":
"""
This method handles plural values.
Takes in strings of the form 'value1; value2; value3; ...; valueN'
returns a list with the respective values -> [value1, value2, value3, ..., valueN]
Send a single structured request to Ollama and parse the response as an
IncidentReport. Replaces the old per-field loop.

On success, self._result holds the validated IncidentReport.
Fields that could not be extracted are listed in result.requires_review.
"""
if ";" not in plural_value:
raise ValueError(
f"Value is not plural, doesn't have ; separator, Value: {plural_value}"
prompt = self.build_prompt()
model = os.getenv("OLLAMA_MODEL", "mistral")

payload = {
"model": model,
"prompt": prompt,
"format": "json", # constrains Ollama output to valid JSON
"stream": False,
}

try:
response = requests.post(self._ollama_url(), json=payload)
response.raise_for_status()
except requests.exceptions.ConnectionError:
raise ConnectionError(
f"Could not connect to Ollama at {self._ollama_url()}. "
"Please ensure Ollama is running and accessible."
)
except requests.exceptions.HTTPError as e:
raise RuntimeError(f"Ollama returned an error: {e}") from e

print(
f"\t[LOG]: Formating plural values for JSON, [For input {plural_value}]..."
)
values = plural_value.split(";")
raw = response.json()["response"]

# Remove trailing leading whitespace
for i in range(len(values)):
current = i + 1
if current < len(values):
clean_value = values[current].lstrip()
values[current] = clean_value
try:
extracted = json.loads(raw)
except json.JSONDecodeError as e:
logger.warning("Ollama returned invalid JSON: %s. Raw (first 200 chars): %.200s", e, raw)
extracted = {}

print(f"\t[LOG]: Resulting formatted list of values: {values}")
# Pydantic validates and auto-populates requires_review for missing fields
self._result = IncidentReport(**extracted)

return values
logger.info("Extraction complete. requires_review: %s", self._result.requires_review)

return self

# Data accessors
def get_data(self) -> dict:
"""
Return extracted data as a plain dict for use by Filler.
requires_review is excluded here — use get_report() to access it.
"""
if self._result is None:
return {}
return self._result.model_dump(exclude={"requires_review"})

def get_data(self):
return self._json
def get_report(self) -> IncidentReport | None:
"""Return the full IncidentReport including requires_review."""
return self._result
3 changes: 3 additions & 0 deletions src/schemas/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from src.schemas.incident_report import IncidentReport

__all__ = ["IncidentReport"]
Loading