Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lettucedetect/detectors/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,9 @@ def make_detector(method: str, **kwargs) -> BaseDetector:
from lettucedetect.detectors.llm import LLMDetector

return LLMDetector(**kwargs)
elif method == "number":
from lettucedetect.detectors.number import NumberDetector

return NumberDetector()
else:
raise ValueError(f"Unknown detector method: {method}. Use one of: transformer, llm")
84 changes: 84 additions & 0 deletions lettucedetect/detectors/number.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from lettucedetect.detectors.base import BaseDetector
import re


class NumberDetector(BaseDetector):
    """Detect hallucinations only if a number appears in the answer not mentioned in the context.

    Purely rule-based: numeric literals are extracted from the context and the
    answer with a regex, and any answer number that never occurs verbatim in
    the context is flagged as hallucinated.
    """

    # Matches comma-grouped numbers ("1,234,567.89"), decimals, and plain integers.
    _NUMBER_RE = re.compile(r"\d{1,3}(?:,\d{3})*(?:\.\d+)?|\d+\.\d+|\d+")
    # Tokenizer for token-level output. The number alternatives come first so
    # "1,234" or "3.14" survive as single tokens; a bare \b\w+\b would split
    # them and a hallucinated number could then never match any token.
    _TOKEN_RE = re.compile(r"\d{1,3}(?:,\d{3})*(?:\.\d+)?|\d+\.\d+|\d+|\b\w+\b")
    # Naive sentence splitter: any run of text terminated by ., ? or !.
    _SENTENCE_RE = re.compile(r"[^.?!]+[.?!]")

    def predict(
        self,
        context: list[str],
        answer: str,
        question=None,
        output_format: str = "tokens",
    ) -> list:
        """
        Detect number hallucinations by comparing numbers in the answer with the context.

        :param context: A list of context strings.
        :param answer: The answer string.
        :param question: Unused; accepted for interface compatibility.
        :param output_format: 'tokens' for word-level or 'spans' for sentence-level results.
        :return: List of hallucinated spans or tokens.
        :raises ValueError: If ``output_format`` is neither 'tokens' nor 'spans'.
        """
        # Lowercasing only affects letters; digit comparison is unchanged.
        context_text = " ".join(context).lower()
        hallucinated_numbers = self._detect_number_hallucinations(context_text, answer)

        if output_format == "tokens":
            return self._get_token_level_predictions(answer, hallucinated_numbers)
        elif output_format == "spans":
            return self._get_span_level_predictions(answer, hallucinated_numbers)
        else:
            raise ValueError("Invalid output_format. Use 'tokens' or 'spans'.")

    def _extract_numbers(self, text: str) -> set:
        """Return the set of numeric literals (as strings) found in ``text``."""
        return set(self._NUMBER_RE.findall(text))

    def _detect_number_hallucinations(self, context: str, answer: str) -> set:
        """Return answer numbers that never occur verbatim in the context."""
        return self._extract_numbers(answer) - self._extract_numbers(context)

    def _get_token_level_predictions(self, answer: str, hallucinated_numbers: set) -> list:
        """Label every token of ``answer``; hallucinated numbers get pred=1 / prob=1.0."""
        results = []
        for match in self._TOKEN_RE.finditer(answer):
            token = match.group()
            is_hallucinated = token in hallucinated_numbers
            results.append({
                "token": token,
                "pred": int(is_hallucinated),
                "prob": 1.0 if is_hallucinated else 0.0,
            })
        return results

    def _get_span_level_predictions(self, answer: str, hallucinated_numbers: set) -> list:
        """Return sentence spans of ``answer`` containing at least one hallucinated number."""
        spans = []
        for match in self._SENTENCE_RE.finditer(answer):
            sentence = match.group().strip()
            # Set intersection: sentence is flagged iff it holds any hallucinated number.
            if self._extract_numbers(sentence) & hallucinated_numbers:
                spans.append({
                    "text": sentence,
                    "start": match.start(),
                    "end": match.end(),
                    "confidence": 1.0,
                })
        return spans

    def predict_prompt(self, prompt, answer, output_format="tokens") -> list:
        """Predict hallucination tokens or spans from the provided prompt and answer.

        :param prompt: The prompt string.
        :param answer: The answer string.
        :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
        """
        # Bug fix: the original delegated to a non-existent ``self._predict``,
        # raising AttributeError on every call. Delegate to ``predict`` instead.
        return self.predict([prompt], answer, output_format=output_format)

    def predict_prompt_batch(self, prompts, answers, output_format="tokens") -> list:
        """Predict hallucination tokens or spans from the provided prompts and answers.

        :param prompts: List of prompt strings.
        :param answers: List of answer strings.
        :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
        """
        return [self.predict_prompt(p, a, output_format) for p, a in zip(prompts, answers)]
41 changes: 41 additions & 0 deletions scripts/number_dataset_creation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import re
import json
from concurrent.futures import ThreadPoolExecutor

def extract_numbers(text):
    """Extract all numbers (integers and floats) from the text."""
    matches = re.finditer(r'\d+(?:\.\d+)?', text)
    return {m.group() for m in matches}

def is_hallucinated_numbers(batch):
    """Return True if batch contains a hallucinated number."""
    # A hallucination exists exactly when the answer's numbers are not a
    # subset of the context's numbers.
    context_numbers = extract_numbers(batch['prompt'])
    answer_numbers = extract_numbers(batch['answer'])
    return not answer_numbers <= context_numbers

def filter_batches_with_hallucinated_numbers(dataset):
    """Filter batches with hallucinated numbers in parallel."""
    with ThreadPoolExecutor() as pool:
        # map preserves input order, so flags line up with dataset entries.
        flags = pool.map(is_hallucinated_numbers, dataset)
        return [batch for batch, flagged in zip(dataset, flags) if flagged]

# Example usage: load a RAGTruth-style dataset, keep only entries whose answer
# contains a number absent from the prompt, and write them back out as JSON.
if __name__ == "__main__":
    # Load your dataset here (replace with actual loading if reading from file)
    # NOTE(review): the filename ".json" (empty stem) looks like a placeholder
    # left in by mistake — confirm the intended dataset file name.
    with open("../data/ragtruth/ragtruth_data/.json", "r") as f:
        dataset = json.load(f)

    # Each kept batch is a dict with at least 'prompt' and 'answer' keys
    # (the keys read by is_hallucinated_numbers).
    hallucinated_batches = filter_batches_with_hallucinated_numbers(dataset)

    # Save the filtered hallucinated-number-only dataset
    with open("hallucinated_numbers_only.json", "w") as f:
        json.dump(hallucinated_batches, f, indent=2)
39 changes: 39 additions & 0 deletions scripts/token_generation_simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import re
import json

from lettucedetect.detectors.number import NumberDetector

def simulate_llm_generation(batch, detector: NumberDetector):
    """Replay the gold answer token by token, printing a line per token and
    flagging any token the detector marks as a hallucinated number."""
    context = [batch["prompt"]]
    # Split into word tokens and lone non-space symbols.
    tokens = re.findall(r'\b\w+\b|\S', batch["answer"])

    for idx, token in enumerate(tokens):
        predictions = detector.predict(context, token, output_format="tokens")
        flagged = any(pred['pred'] == 1 for pred in predictions)
        if flagged:
            print(f"[Step {idx+1}] Token: '{token}' -> Hallucinated Number Detected!")
        else:
            print(f"[Step {idx+1}] Token: '{token}'")

if __name__ == "__main__":
    # Load data produced by scripts/number_dataset_creation.py
    with open("../data/ragtruth/hallucinated_numbers_only.json", "r") as f:
        dataset = json.load(f)

    # Instantiate the hallucination detector
    detector = NumberDetector()

    # Simulate LLM generation on one example for debugging
    # NOTE(review): hard-coded index assumes the filtered dataset has at least
    # three entries — confirm, or guard with a length check.
    simulate_llm_generation(dataset[2], detector)


    # Run over every experiment
    # Suggestion to comment else-statement in method and add break in if-statement
    # for batch in dataset:
    #     simulate_llm_generation(batch, detector)
Loading