diff --git a/lettucedetect/detectors/factory.py b/lettucedetect/detectors/factory.py
index 430e616..2b27434 100644
--- a/lettucedetect/detectors/factory.py
+++ b/lettucedetect/detectors/factory.py
@@ -23,5 +23,9 @@ def make_detector(method: str, **kwargs) -> BaseDetector:
         from lettucedetect.detectors.llm import LLMDetector
 
         return LLMDetector(**kwargs)
+    elif method == "number":
+        from lettucedetect.detectors.number import NumberDetector
+
+        return NumberDetector()
     else:
-        raise ValueError(f"Unknown detector method: {method}. Use one of: transformer, llm")
+        raise ValueError(f"Unknown detector method: {method}. Use one of: transformer, llm, number")
diff --git a/lettucedetect/detectors/number.py b/lettucedetect/detectors/number.py
new file mode 100644
index 0000000..d261c00
--- /dev/null
+++ b/lettucedetect/detectors/number.py
@@ -0,0 +1,84 @@
+import re
+
+from lettucedetect.detectors.base import BaseDetector
+
+
+class NumberDetector(BaseDetector):
+    """Detect a hallucination only when the answer contains a number that never appears in the context."""
+
+    def predict(
+        self,
+        context: list[str],
+        answer: str,
+        question: str | None = None,
+        output_format: str = "tokens",
+    ) -> list:
+        """Detect number hallucinations by comparing numbers in the answer with the context.
+
+        :param context: A list of context strings.
+        :param answer: The answer string.
+        :param question: Unused; accepted for interface compatibility.
+        :param output_format: 'tokens' for word-level or 'spans' for sentence-level results.
+        :return: List of hallucinated spans or tokens.
+        """
+        context_text = " ".join(context).lower()
+        hallucinated_numbers = self._detect_number_hallucinations(context_text, answer)
+
+        if output_format == "tokens":
+            return self._get_token_level_predictions(answer, hallucinated_numbers)
+        elif output_format == "spans":
+            return self._get_span_level_predictions(answer, hallucinated_numbers)
+        else:
+            raise ValueError("Invalid output_format. Use 'tokens' or 'spans'.")
+
+    def _extract_numbers(self, text: str) -> set:
+        # Try comma-grouped numbers first (requiring at least one ",ddd" group),
+        # then plain decimals, then bare integers. Requiring the comma group keeps
+        # the first alternative from splitting "1234.5" into "123" and "4.5".
+        number_pattern = r"\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+\.\d+|\d+"
+        return set(re.findall(number_pattern, text))
+
+    def _detect_number_hallucinations(self, context: str, answer: str) -> set:
+        context_numbers = self._extract_numbers(context)
+        answer_numbers = self._extract_numbers(answer)
+        return {num for num in answer_numbers if num not in context_numbers}
+
+    def _get_token_level_predictions(self, answer: str, hallucinated_numbers: set) -> list:
+        results = []
+        # Tokenize with the same number-aware alternatives as _extract_numbers so
+        # "1,234" and "3.14" stay single tokens and can match the extracted strings.
+        for match in re.finditer(r"\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+\.\d+|\w+", answer):
+            token = match.group()
+            is_hallucinated = token in hallucinated_numbers
+            results.append({
+                "token": token,
+                "pred": int(is_hallucinated),
+                "prob": 1.0 if is_hallucinated else 0.0,
+            })
+        return results
+
+    def _get_span_level_predictions(self, answer: str, hallucinated_numbers: set) -> list:
+        spans = []
+        # The trailing "?" also keeps a final sentence that lacks closing punctuation.
+        for match in re.finditer(r"[^.?!]+[.?!]?", answer):
+            sentence = match.group().strip()
+            sentence_numbers = self._extract_numbers(sentence)
+            if any(num in hallucinated_numbers for num in sentence_numbers):
+                spans.append({
+                    "text": sentence,
+                    "start": match.start(),
+                    "end": match.end(),
+                    "confidence": 1.0,
+                })
+        return spans
+
+    def predict_prompt(self, prompt, answer, output_format="tokens") -> list:
+        """Predict hallucination tokens or spans from the provided prompt and answer.
+
+        :param prompt: The prompt string.
+        :param answer: The answer string.
+        :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
+        """
+        return self.predict([prompt], answer, output_format=output_format)
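A quick way to sanity-check the new detector end to end (the sample strings are invented for illustration; the output shapes follow the methods above):

```python
from lettucedetect.detectors.factory import make_detector

detector = make_detector("number")
context = ["The company reported revenue of 12.5 million in 2020."]
answer = "The company reported revenue of 14 million in 2020."

# Token level: "14" is flagged because it never occurs in the context.
for pred in detector.predict(context, answer, output_format="tokens"):
    if pred["pred"] == 1:
        print(pred)  # {'token': '14', 'pred': 1, 'prob': 1.0}

# Span level: the whole sentence containing the unsupported number is returned.
print(detector.predict(context, answer, output_format="spans"))
```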
+ """ + return self._predict(prompt, answer, output_format) + + def predict_prompt_batch(self, prompts, answers, output_format="tokens") -> list: + """Predict hallucination tokens or spans from the provided prompts and answers. + + :param prompts: List of prompt strings. + :param answers: List of answer strings. + :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans. + """ + return [self._predict(p, a, output_format) for p, a in zip(prompts, answers)] \ No newline at end of file diff --git a/scripts/number_dataset_creation.py b/scripts/number_dataset_creation.py new file mode 100644 index 0000000..7e997e4 --- /dev/null +++ b/scripts/number_dataset_creation.py @@ -0,0 +1,41 @@ +import re +import json +from concurrent.futures import ThreadPoolExecutor + +def extract_numbers(text): + """Extract all numbers (integers and floats) from the text.""" + return set(re.findall(r'\d+(?:\.\d+)?', text)) + +def is_hallucinated_numbers(batch): + """Return True if batch contains a hallucinated number.""" + context = batch['prompt'] + answer = batch['answer'] + + context_numbers = extract_numbers(context) + answer_numbers = extract_numbers(answer) + + hallucinated = [num for num in answer_numbers if num not in context_numbers] + + return bool(hallucinated) + +def filter_batches_with_hallucinated_numbers(dataset): + """Filter batches with hallucinated numbers in parallel.""" + filtered = [] + with ThreadPoolExecutor() as executor: + results = list(executor.map(is_hallucinated_numbers, dataset)) + for batch, has_hallucination in zip(dataset, results): + if has_hallucination: + filtered.append(batch) + return filtered + +# Example usage: +if __name__ == "__main__": + # Load your dataset here (replace with actual loading if reading from file) + with open("../data/ragtruth/ragtruth_data/.json", "r") as f: + dataset = json.load(f) + + hallucinated_batches = filter_batches_with_hallucinated_numbers(dataset) + + # Save the filtered hallucinated-number-only dataset + with open("hallucinated_numbers_only.json", "w") as f: + json.dump(hallucinated_batches, f, indent=2) diff --git a/scripts/token_generation_simulator.py b/scripts/token_generation_simulator.py new file mode 100644 index 0000000..1b6d348 --- /dev/null +++ b/scripts/token_generation_simulator.py @@ -0,0 +1,39 @@ +import re +import json + +from lettucedetect.detectors.number import NumberDetector + +def simulate_llm_generation(batch, detector: NumberDetector): + context = [batch["prompt"]] + gold_answer = batch["answer"] + tokens = re.findall(r'\b\w+\b|\S', gold_answer) # Better token splitting + + # print(f"\nSimulating generation for:\nPrompt: {batch['prompt']}\n") + # generated = "" + + for idx, token in enumerate(tokens): + token_preds = detector.predict(context, token, output_format="tokens") + hallucinated = [pred for pred in token_preds if pred['pred'] == 1] + if hallucinated: + print(f"[Step {idx+1}] Token: '{token}' -> Hallucinated Number Detected!") + else: + print(f"[Step {idx+1}] Token: '{token}'") + + # print("\nFinal Answer Generated:\n", generated.strip()) + +if __name__ == "__main__": + # Load data + with open("../data/ragtruth/hallucinated_numbers_only.json", "r") as f: + dataset = json.load(f) + + # Instantiate the hallucination detector + detector = NumberDetector() + + # Simulate LLM generation on one example for debugging + simulate_llm_generation(dataset[2], detector) + + + # Run over every experiment + # Suggestion to comment else-statement in method and add break in if-statement + # 
diff --git a/scripts/token_generation_simulator.py b/scripts/token_generation_simulator.py
new file mode 100644
index 0000000..1b6d348
--- /dev/null
+++ b/scripts/token_generation_simulator.py
@@ -0,0 +1,39 @@
+import json
+import re
+
+from lettucedetect.detectors.number import NumberDetector
+
+def simulate_llm_generation(batch, detector: NumberDetector):
+    """Replay a gold answer token by token, checking each token against the context."""
+    context = [batch["prompt"]]
+    gold_answer = batch["answer"]
+    # Number-aware splitting: keep "1,234" and "3.14" intact, then fall back to
+    # words and standalone punctuation.
+    tokens = re.findall(r"\d{1,3}(?:,\d{3})+(?:\.\d+)?|\d+\.\d+|\w+|\S", gold_answer)
+
+    for idx, token in enumerate(tokens):
+        token_preds = detector.predict(context, token, output_format="tokens")
+        hallucinated = [pred for pred in token_preds if pred["pred"] == 1]
+        if hallucinated:
+            print(f"[Step {idx + 1}] Token: '{token}' -> Hallucinated Number Detected!")
+        else:
+            print(f"[Step {idx + 1}] Token: '{token}'")
+
+if __name__ == "__main__":
+    # Load data
+    with open("../data/ragtruth/hallucinated_numbers_only.json", "r") as f:
+        dataset = json.load(f)
+
+    # Instantiate the hallucination detector
+    detector = NumberDetector()
+
+    # Simulate LLM generation on one example for debugging
+    simulate_llm_generation(dataset[2], detector)
+
+    # To run over every example, uncomment the loop below; silencing the else
+    # branch in simulate_llm_generation (and breaking after the first detection)
+    # keeps the output readable.
+    # for batch in dataset:
+    #     simulate_llm_generation(batch, detector)
\ No newline at end of file
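For completeness, the prompt-based wrappers take the same path (they delegate to `predict` with the prompt as a one-element context); a small sketch with invented text that also exercises the comma-grouped number handling:

```python
from lettucedetect.detectors.number import NumberDetector

detector = NumberDetector()
preds = detector.predict_prompt(
    "The bridge opened in 1932 and is 2,737 m long.",
    "The bridge opened in 1932 and is 3,000 m long.",
    output_format="spans",
)
print(preds)  # one span: "3,000" appears in the answer but not in the prompt
```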