diff --git a/README.md b/README.md
index 232f162..bda443a 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
Because even AI needs a reality check! 🥬
-LettuceDetect is a lightweight and efficient tool for detecting hallucinations in Retrieval-Augmented Generation (RAG) systems. It identifies unsupported parts of an answer by comparing it to the provided context. The tool is trained and evaluated on the [RAGTruth](https://aclanthology.org/2024.acl-long.585/) dataset and leverages [ModernBERT](https://github.com/AnswerDotAI/ModernBERT) for long-context processing, making it ideal for tasks requiring extensive context windows.
+LettuceDetect is a lightweight and efficient tool for detecting hallucinations in Retrieval-Augmented Generation (RAG) systems. It identifies unsupported parts of an answer by comparing it to the provided context. The tool is trained and evaluated on the [RAGTruth](https://aclanthology.org/2024.acl-long.585/) dataset and leverages [ModernBERT](https://github.com/AnswerDotAI/ModernBERT) for English and [EuroBERT](https://huggingface.co/blog/EuroBERT/release) for multilingual support, making it ideal for tasks requiring extensive context windows.
Our models are inspired from the [Luna](https://aclanthology.org/2025.coling-industry.34/) paper which is an encoder-based model and uses a similar token-level approach.
@@ -21,9 +21,15 @@ Our models are inspired from the [Luna](https://aclanthology.org/2025.coling-ind
- LettuceDetect addresses two critical limitations of existing hallucination detection models:
- Context window constraints of traditional encoder-based methods
- Computational inefficiency of LLM-based approaches
-- Our models currently **outperforms** all other encoder-based and prompt-based models on the RAGTruth dataset and are significantly faster and smaller
+- Our models currently **outperform** all other encoder-based and prompt-based models on the RAGTruth dataset and are significantly faster and smaller
- Achieves higher score than some fine-tuned LLMs e.g. LLAMA-2-13B presented in [RAGTruth](https://aclanthology.org/2024.acl-long.585/), coming up just short of the LLM fine-tuned in the [RAG-HAT paper](https://aclanthology.org/2024.emnlp-industry.113.pdf)
-- We release the code, the model and the tool under the **MIT license**
+
+## 🚀 Latest Updates
+
+- **May 18, 2025** - Released version **0.1.7**: Multilingual support (thanks to EuroBERT) for 7 languages: English, German, French, Spanish, Italian, Polish, and Chinese!
+- Up to a **17 F1-point improvement** over baseline LLM judges like GPT-4.1-mini across different languages
+- **EuroBERT models**: We've trained base/210M (faster) and large/610M (more accurate) variants
+- You can now also use **LLM baselines** for hallucination detection (see below)
## Get going
@@ -31,7 +37,8 @@ Our models are inspired from the [Luna](https://aclanthology.org/2025.coling-ind
- ✨ **Token-level precision**: detect exact hallucinated spans
- 🚀 **Optimized for inference**: smaller model size and faster inference
-- 🧠 **4K context window** via ModernBERT
+- 🧠 **Long context window** support (4K for ModernBERT, 8K for EuroBERT)
+- 🌍 **Multilingual support**: 7 languages covered
- ⚖️ **MIT-licensed** models & code
- 🤖 **HF Integration**: one-line model loading
- 📦 **Easy to use python API**: can be downloaded from pip and few lines of code to integrate into your RAG system
@@ -45,25 +52,42 @@ pip install -e .
From pip:
```bash
-pip install lettucedetect
+pip install lettucedetect -U
```
### Quick Start
Check out our models published to Huggingface:
-- lettucedetect-base: https://huggingface.co/KRLabsOrg/lettucedect-base-modernbert-en-v1
-- lettucedetect-large: https://huggingface.co/KRLabsOrg/lettucedect-large-modernbert-en-v1
+
+**English Models**:
+- Base: [KRLabsOrg/lettucedetect-base-modernbert-en-v1](https://huggingface.co/KRLabsOrg/lettucedetect-base-modernbert-en-v1)
+- Large: [KRLabsOrg/lettucedetect-large-modernbert-en-v1](https://huggingface.co/KRLabsOrg/lettucedetect-large-modernbert-en-v1)
+
+**Multilingual Models**:
+We've trained 210M and 610M variants of EuroBERT; see our HuggingFace collection: [HF models](https://huggingface.co/collections/KRLabsOrg/multilingual-hallucination-detection-682a2549c18ecd32689231ce)
+
+
+*See the full list of models and smaller variants on our [HuggingFace page](https://huggingface.co/KRLabsOrg).*
You can get started right away with just a few lines of code.
```python
from lettucedetect.models.inference import HallucinationDetector
-# For a transformer-based approach:
+# For English:
detector = HallucinationDetector(
- method="transformer", model_path="KRLabsOrg/lettucedect-base-modernbert-en-v1"
+ method="transformer",
+ model_path="KRLabsOrg/lettucedect-base-modernbert-en-v1",
)
+# For other languages (e.g., German):
+# detector = HallucinationDetector(
+# method="transformer",
+# model_path="KRLabsOrg/lettucedect-210m-eurobert-de-v1",
+# lang="de",
+# trust_remote_code=True
+# )
+
contexts = ["France is a country in Europe. The capital of France is Paris. The population of France is 67 million.",]
question = "What is the capital of France? What is the population of France?"
answer = "The capital of France is Paris. The population of France is 69 million."
@@ -75,26 +99,39 @@ print("Predictions:", predictions)
# Predictions: [{'start': 31, 'end': 71, 'confidence': 0.9944414496421814, 'text': ' The population of France is 69 million.'}]
```
-## Performance
+Check out our [HF collection](https://huggingface.co/collections/KRLabsOrg/multilingual-hallucination-detection-682a2549c18ecd32689231ce) for more examples.
-**Example level results**
+We also provide LLM-based baselines. To use them, set your OpenAI API key:
-We evaluate our model on the test set of the [RAGTruth](https://aclanthology.org/2024.acl-long.585/) dataset. Our large model, **lettucedetect-large-v1**, achieves an overall F1 score of 79.22%, outperforming prompt-based methods like GPT-4 (63.4%) and encoder-based models like [Luna](https://aclanthology.org/2025.coling-industry.34.pdf) (65.4%). It also surpasses fine-tuned LLAMA-2-13B (78.7%) (presented in [RAGTruth](https://aclanthology.org/2024.acl-long.585/)) and is competitive with the SOTA fine-tuned LLAMA-3-8B (83.9%) (presented in the [RAG-HAT paper](https://aclanthology.org/2024.emnlp-industry.113.pdf)). Overall, **lettucedetect-large-v1** and **lettucedect-base-v1** are very performant models, while being very effective in inference settings.
+```bash
+export OPENAI_API_KEY=your_api_key
+```
-The results on the example-level can be seen in the table below.
+Then in code:
-
-
-
+```python
+from lettucedetect.models.inference import HallucinationDetector
-**Span-level results**
+# For German:
+detector = HallucinationDetector(method="llm", lang="de")
-At the span level, our model achieves the best scores across all data types, significantly outperforming previous models. The results can be seen in the table below. Note that here we don't compare to models, like [RAG-HAT](https://aclanthology.org/2024.emnlp-industry.113.pdf), since they have no span-level evaluation presented.
+# Then predict the same way
+predictions = detector.predict(context=contexts, question=question, answer=answer, output_format="spans")
+```
-
-
-
+## Performance
+
+We've evaluated our models against both encoder-based and LLM-based approaches. The key findings include:
+- In English, our models **outperform** all other encoder-based and prompt-based models on the RAGTruth dataset while being significantly faster and smaller
+- Our multilingual models outperform baseline LLM judges such as GPT-4.1-mini
+- Our models are also significantly faster and smaller than the LLM-based judges
+
+For detailed performance metrics and evaluations of our models:
+- [English model documentation](docs/README.md)
+- [Multilingual model documentation](docs/EUROBERT.md)
+- [Paper](https://arxiv.org/abs/2502.17125)
+- [Model cards](https://huggingface.co/KRLabsOrg)
## How does it work?
@@ -229,11 +266,11 @@ positional arguments:
options:
-h, --help show this help message and exit
--model MODEL Path or huggingface URL to the model. The default value is
- "KRLabsOrg/lettucedect-base-modernbert-en-v1".
+ "KRLabsOrg/lettucedetect-base-modernbert-en-v1".
--method {transformer}
Hallucination detection method. The default value is
"transformer".
-````
+```
Example using the python client library:
diff --git a/assets/lettuce_detective_multi.png b/assets/lettuce_detective_multi.png
new file mode 100644
index 0000000..d013491
Binary files /dev/null and b/assets/lettuce_detective_multi.png differ
diff --git a/docs/EUROBERT.md b/docs/EUROBERT.md
new file mode 100644
index 0000000..0380d8c
--- /dev/null
+++ b/docs/EUROBERT.md
@@ -0,0 +1,239 @@
+# 🥬 LettuceDetect Goes Multilingual: Fine-tuning EuroBERT on Synthetic RAGTruth Translations
+
+
+
+
+ Expanding hallucination detection across languages for RAG pipelines.
+
+
+---
+
+## 🏷️ TL;DR
+
+- We present the first multilingual hallucination detection framework for Retrieval-Augmented Generation (RAG).
+- We translated the [RAGTruth dataset](https://arxiv.org/abs/2401.00396) into German, French, Italian, Spanish, Polish, and Chinese while preserving hallucination annotations.
+- We fine-tuned [**EuroBERT**](https://huggingface.co/blog/EuroBERT/release) for token-level hallucination detection across all these languages.
+- Our experiments show that **EuroBERT** significantly outperforms prompt-based LLM judges like GPT-4.1-mini by up to **17 F1 points**.
+- All translated datasets, fine-tuned models, and translation scripts are available under MIT license.
+
+---
+
+## Quick Links
+
+- **GitHub**: [github.com/KRLabsOrg/LettuceDetect](https://github.com/KRLabsOrg/LettuceDetect)
+- **PyPI**: [pypi.org/project/lettucedetect](https://pypi.org/project/lettucedetect/)
+- **arXiv Paper**: [2502.17125](https://arxiv.org/abs/2502.17125)
+- **Hugging Face Models**:
+ - [Our HF collection](https://huggingface.co/collections/KRLabsOrg/multilingual-hallucination-detection-682a2549c18ecd32689231ce)
+- **Demo**: [Hugging Face Space](https://huggingface.co/spaces/KRLabsOrg/LettuceDetect-Multilingual)
+
+
+## Get Started
+
+Install the package:
+
+```bash
+pip install lettucedetect
+```
+
+### Transformer-based Hallucination Detection (German)
+
+```python
+from lettucedetect.models.inference import HallucinationDetector
+
+detector = HallucinationDetector(
+ method="transformer",
+ model_path="KRLabsOrg/lettucedect-610m-eurobert-de-v1",
+ lang="de",
+ trust_remote_code=True
+)
+
+contexts = [
+ "Frankreich ist ein Land in Europa. Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung Frankreichs beträgt 67 Millionen."
+]
+question = "Was ist die Hauptstadt von Frankreich? Wie groß ist die Bevölkerung Frankreichs?"
+answer = "Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung Frankreichs beträgt 69 Millionen."
+
+predictions = detector.predict(context=contexts, question=question, answer=answer, output_format="spans")
+print("Predictions:", predictions)
+```
+
+### LLM-based Hallucination Detection (Our Baseline)
+
+```python
+from lettucedetect.models.inference import HallucinationDetector
+
+detector = HallucinationDetector(method="llm", lang="de")
+
+contexts = [
+ "Frankreich ist ein Land in Europa. Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung Frankreichs beträgt 67 Millionen."
+]
+question = "Was ist die Hauptstadt von Frankreich? Wie hoch ist die Bevölkerung Frankreichs?"
+answer = "Die Hauptstadt von Frankreich ist Paris. Die Bevölkerung Frankreichs beträgt 82222 Millionen."
+
+predictions = detector.predict(context=contexts, question=question, answer=answer, output_format="spans")
+print("Predictions:", predictions)
+```
+
+
+## Background
+
+**LettuceDetect** ([blog](https://huggingface.co/blog/adaamko/lettucedetect)) is a lightweight, open-source hallucination detector for RAG pipelines that leverages ModernBERT for efficient token-level detection. It was originally trained on [RAGTruth](https://aclanthology.org/2024.acl-long.585/), a manually annotated English dataset for hallucination detection. The initial research demonstrated that encoder-based models can outperform large LLM judges while being significantly faster and more cost-effective.
+
+Despite these advances, many real-world RAG applications are multilingual, and detecting hallucinations across languages remains challenging due to the lack of specialized models and datasets for non-English content.
+
+
+## Our Approach
+
+To address this gap, we created multilingual versions of the RAGTruth dataset and fine-tuned EuroBERT for hallucination detection across languages. For translation, we used [google/gemma-3-27b-it](https://huggingface.co/google/gemma-3-27b-it) with [vllm](https://github.com/vllm-project/vllm) running on a single A100 GPU. This setup enabled parallel translation of approximately 30 examples at a time, with each language translation pass taking about 12 hours.
+
+Our pipeline works as follows:
+
+1. **Annotation Tagging**: In the English RAGTruth data, hallucinated answer spans are marked with inline XML-style tags.
+ Example:
+ ```
+
+ The French Revolution started in 1788.
+
+ ```
+
+2. **LLM-based Translation**: We translate the context, question, and answer while preserving all annotation tags. To simplify translation, we merge overlapping tags.
+
+3. **Extraction & Validation**: We extract the translated content and annotations, maintaining the same format as the original RAGTruth data (see the sketch after this list).
+
+4. **Fine-tuning**: We train EuroBERT models for token classification to identify hallucinated content in each language.
+
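+To make the extraction step concrete, here is a minimal, self-contained sketch of how tagged spans can be recovered as character offsets. This is an illustration rather than the repository's actual pipeline code; the `<hal>` tag name and the `extract_spans` helper are assumptions made for the example.
+
+```python
+import re
+
+# Illustrative tag name; the real dataset uses its own XML-style markup.
+TAG_RE = re.compile(r"<hal>(.*?)</hal>", re.DOTALL)
+
+
+def extract_spans(tagged_answer: str) -> tuple[str, list[dict]]:
+    """Strip the tags and return (clean_answer, spans with character offsets)."""
+    clean_parts: list[str] = []
+    spans: list[dict] = []
+    cursor = 0  # position in the clean, tag-free answer
+    last_end = 0
+    for match in TAG_RE.finditer(tagged_answer):
+        # Text before the tag is kept verbatim.
+        before = tagged_answer[last_end : match.start()]
+        clean_parts.append(before)
+        cursor += len(before)
+        # The tagged text becomes a labelled span in the clean answer.
+        text = match.group(1)
+        spans.append({"start": cursor, "end": cursor + len(text), "text": text})
+        clean_parts.append(text)
+        cursor += len(text)
+        last_end = match.end()
+    clean_parts.append(tagged_answer[last_end:])
+    return "".join(clean_parts), spans
+
+
+clean, spans = extract_spans("Die Französische Revolution begann <hal>1788</hal>.")
+print(clean)  # Die Französische Revolution begann 1788.
+print(spans)  # [{'start': 35, 'end': 39, 'text': '1788'}]
+```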
+
+### Supported Languages
+
+Our models support hallucination detection in Chinese, French, German, Italian, Spanish, and Polish.
+
+### Translation Example
+
+To illustrate our approach, here's an example from the original RAGTruth data and its German translation:
+
+**English**
+```xml
+The first quartile (Q1) splits the lowest 25% of the data, while the second quartile (Q2) splits the data into two equal halves, with the median being the middle value of the lower half. Finally, the third quartile (Q3) splits the highest 75% of the data.
+```
+
+- *The phrase "highest 75%" is hallucinated, as the reference correctly states "lowest 75% (or highest 25%)".*
+
+**German**
+
+```xml
+Das erste Quartil (Q1) teilt die unteren 25% der Daten, während das zweite Quartil (Q2) die Daten in zwei gleiche Hälften teilt, wobei der Median den Mittelpunkt der unteren Hälfte bildet. Schließlich teilt das dritte Quartil (Q3) die höchsten 75% der Daten.
+```
+
+Here, the phrase "höchsten 75%" is hallucinated, as the reference correctly states "unteren 75% (oder höchsten 25%)".
+
+
+## Model Architecture
+
+We leverage [**EuroBERT**](https://huggingface.co/blog/EuroBERT/release), a recently released transformer model that represents a significant advancement in encoder architectures with its long-context capabilities and multilingual support.
+
+Trained on a massive 5 trillion-token corpus spanning 15 languages, EuroBERT processes sequences up to 8,192 tokens. The architecture incorporates modern innovations including grouped query attention, rotary positional embeddings, and advanced normalization techniques. These features enable both high computational efficiency and strong generalization abilities.
+
+For multilingual hallucination detection, we trained both the 210M and 610M parameter variants across all supported languages.
+
+## Training Process
+
+Our EuroBERT-based hallucination detection models follow the original LettuceDetect training methodology:
+
+**Input Processing**
+- Concatenate Context, Question, and Answer with special tokens
+- Cap sequences at 4,096 tokens for computational efficiency
+- Use AutoTokenizer for appropriate tokenization and segment markers
+
+**Label Assignment**
+- Mask context and question tokens (label = -100)
+- Assign binary labels to answer tokens: 0 (supported) or 1 (hallucinated)
+
+**Model Configuration**
+- Use EuroBERT within AutoModelForTokenClassification framework
+- Add only a linear classification head without additional pretraining
+
+**Training Details**
+- AdamW optimizer (learning rate = 1 × 10⁻⁵, weight decay = 0.01)
+- Six epochs with batch size of 8
+- Dynamic padding via DataCollatorForTokenClassification
+- Single NVIDIA A100 GPU (80GB) per language
+
+During inference, tokens with hallucination probabilities above 0.5 are merged into contiguous spans, providing precise identification of problematic content.
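+
+As a rough illustration of the label assignment and model setup described above, here is a minimal sketch using Hugging Face `transformers`. The checkpoint name and the answer-boundary computation are simplifying assumptions for the example; the repository's actual training code builds labels via its `HallucinationDataset` class.
+
+```python
+import torch
+from transformers import AutoModelForTokenClassification, AutoTokenizer
+
+model_name = "EuroBERT/EuroBERT-210m"  # assumed checkpoint name for illustration
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+model = AutoModelForTokenClassification.from_pretrained(
+    model_name, num_labels=2, trust_remote_code=True
+)
+
+prompt = "passage 1: ...context...\nquestion: ...\nanswer:"
+answer = " ...model answer..."
+
+# Tokenize prompt + answer together, capped at 4,096 tokens.
+enc = tokenizer(prompt + answer, truncation=True, max_length=4096, return_tensors="pt")
+
+# Label assignment: context/question tokens -> -100 (ignored by the loss),
+# answer tokens -> 0 (supported) or 1 (hallucinated).
+labels = torch.full_like(enc["input_ids"], -100)
+answer_start = len(tokenizer(prompt)["input_ids"])  # approximate boundary for this sketch
+labels[0, answer_start:] = 0  # gold hallucinated answer tokens would be set to 1 here
+
+loss = model(**enc, labels=labels).loss  # standard token-classification loss
+loss.backward()
+```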
+
+## Results
+
+We evaluated our models on the translated RAGTruth dataset and compared them to a prompt-based baseline using GPT-4.1-mini. This baseline was implemented using few-shot prompting to identify hallucinated spans directly.
+
+### Synthetic Multilingual Results
+
+| Language | Model | Precision (%) | Recall (%) | F1 (%) | GPT-4.1-mini Precision (%) | GPT-4.1-mini Recall (%) | GPT-4.1-mini F1 (%) | Δ F1 (%) |
+|----------|-----------------|---------------|------------|--------|----------------------------|-------------------------|---------------------|----------|
+| Chinese | EuroBERT-210M | 75.46 | 73.38 | 74.41 | 43.97 | 95.55 | 60.23 | +14.18 |
+| Chinese | EuroBERT-610M | **78.90** | **75.72** | **77.27** | 43.97 | 95.55 | 60.23 | **+17.04** |
+| French | EuroBERT-210M | 58.86 | 74.34 | 65.70 | 46.45 | 94.91 | 62.37 | +3.33 |
+| French | EuroBERT-610M | **67.08** | **80.38** | **73.13** | 46.45 | 94.91 | 62.37 | **+10.76** |
+| German | EuroBERT-210M | 66.70 | 66.70 | 66.70 | 44.82 | 95.02 | 60.91 | +5.79 |
+| German | EuroBERT-610M | **77.04** | **72.96** | **74.95** | 44.82 | 95.02 | 60.91 | **+14.04** |
+| Italian | EuroBERT-210M | 60.57 | 72.32 | 65.93 | 44.87 | 95.55 | 61.06 | +4.87 |
+| Italian | EuroBERT-610M | **76.67** | **72.85** | **74.71** | 44.87 | 95.55 | 61.06 | **+13.65** |
+| Spanish | EuroBERT-210M | 69.48 | 73.38 | 71.38 | 46.56 | 94.59 | 62.40 | +8.98 |
+| Spanish | EuroBERT-610M | **76.32** | 70.41 | **73.25** | 46.56 | 94.59 | 62.40 | **+10.85** |
+| Polish | EuroBERT-210M | 63.62 | 69.57 | 66.46 | 42.92 | 95.76 | 59.27 | +7.19 |
+| Polish | EuroBERT-610M | **77.16** | 69.36 | **73.05** | 42.92 | 95.76 | 59.27 | **+13.78** |
+
+Across all languages, the EuroBERT-610M model consistently outperforms both the 210M variant and the GPT-4.1-mini baseline.
+
+### Manual Validation (German)
+
+For a more rigorous evaluation, we manually reviewed 300 examples covering all task types from RAGTruth (QA, summarization, data-to-text). After correcting any annotation errors, we found that performance remained strong, validating our translation approach:
+
+| Model | Precision (%) | Recall (%) | F1 (%) |
+|------------------|---------------|------------|--------|
+| EuroBERT-210M | 68.32 | 68.32 | 68.32 |
+| EuroBERT-610M | **74.47** | 69.31 | **71.79** |
+| GPT-4.1-mini | 44.50 | **92.08** | 60.00 |
+
+An interesting pattern: GPT-4.1-mini shows high recall but poor precision; it identifies most hallucinations but also produces many false positives, making it less reliable in production settings.
+
+## Trade-offs: Model Size vs Performance
+
+When choosing between model variants, consider these trade-offs:
+
+- **EuroBERT-210M** – Approximately 3× faster inference, smaller memory footprint, 5-10% lower F1 scores
+- **EuroBERT-610M** – Highest detection accuracy across all languages, requires more compute resources
+
+## Key Takeaways
+
+- **Translating annotations can be effective**: Preserving hallucination tags through translation enables rapid creation of multilingual detection datasets when annotated data is not available in the target language.
+- **EuroBERT is a good choice for multilingual hallucination detection**: Its long-context capabilities and efficient attention mechanisms make it ideal for RAG verification.
+- **Open framework for multilingual RAG**: All components are available under MIT license: translation, training, and inference.
+
+
+## Citation
+
+If you find this work useful, please cite it as follows:
+
+```bibtex
+@misc{Kovacs:2025,
+ title={LettuceDetect: A Hallucination Detection Framework for RAG Applications},
+ author={Ádám Kovács and Gábor Recski},
+ year={2025},
+ eprint={2502.17125},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL},
+ url={https://arxiv.org/abs/2502.17125},
+}
+```
+
+## References
+
+[1] [Niu et al., 2024, RAGTruth: A Dataset for Hallucination Detection in Retrieval-Augmented Generation](https://aclanthology.org/2024.acl-long.585/)
+
+[2] [Luna: A Simple and Effective Encoder-Based Model for Hallucination Detection in Retrieval-Augmented Generation](https://aclanthology.org/2025.coling-industry.34/)
+
+[3] [ModernBERT: A Modern BERT Model for Long-Context Processing](https://huggingface.co/blog/modernbert)
+
+[4] [Gemma 3](https://blog.google/technology/developers/gemma-3/)
+
+[5] [EuroBERT](https://huggingface.co/blog/EuroBERT/release)
\ No newline at end of file
diff --git a/lettucedetect/detectors/__init__.py b/lettucedetect/detectors/__init__.py
new file mode 100644
index 0000000..4b0df18
--- /dev/null
+++ b/lettucedetect/detectors/__init__.py
@@ -0,0 +1,13 @@
+from __future__ import annotations
+
+from lettucedetect.detectors.base import BaseDetector
+from lettucedetect.detectors.factory import make_detector as _make_detector
+from lettucedetect.detectors.llm import LLMDetector
+from lettucedetect.detectors.transformer import TransformerDetector
+
+__all__ = [
+ "BaseDetector",
+ "LLMDetector",
+ "TransformerDetector",
+ "_make_detector",
+]
diff --git a/lettucedetect/detectors/base.py b/lettucedetect/detectors/base.py
new file mode 100644
index 0000000..f9d20c6
--- /dev/null
+++ b/lettucedetect/detectors/base.py
@@ -0,0 +1,39 @@
+"""Abstract base class for hallucination detectors."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+
+
+class BaseDetector(ABC):
+ """All hallucination detectors implement a common interface."""
+
+ @abstractmethod
+ def predict(
+ self,
+ context: list[str],
+ answer: str,
+ question: str | None = None,
+ output_format: str = "tokens",
+ ) -> list:
+ """Predict hallucination tokens or spans given passages and an answer.
+
+ :param context: List of passages that were supplied to the LLM / user.
+ :param answer: Model‑generated answer to inspect.
+ :param question: Original question (``None`` for summarisation).
+ :param output_format: ``"tokens"`` for token‑level dicts, ``"spans"`` for character spans.
+ :returns: List of predictions in requested format.
+ """
+ pass
+
+ @abstractmethod
+ def predict_prompt(self, prompt: str, answer: str, output_format: str = "tokens") -> list:
+ """Predict hallucinations when you already have a *single* full prompt string."""
+ pass
+
+ @abstractmethod
+ def predict_prompt_batch(
+ self, prompts: list[str], answers: list[str], output_format: str = "tokens"
+ ) -> list:
+ """Batch version of `predict_prompt`."""
+ pass
diff --git a/lettucedetect/detectors/cache.py b/lettucedetect/detectors/cache.py
new file mode 100644
index 0000000..437b21e
--- /dev/null
+++ b/lettucedetect/detectors/cache.py
@@ -0,0 +1,34 @@
+"""Thread‑safe JSON cache with SHA‑256 keys."""
+
+from __future__ import annotations
+
+import hashlib
+import json
+import threading
+from pathlib import Path
+from typing import Any
+
+
+class CacheManager:
+ """Disk‑backed cache for expensive LLM calls."""
+
+ def __init__(self, file_path: str | Path):
+ self.path = Path(file_path)
+ self.path.parent.mkdir(parents=True, exist_ok=True)
+ self._lock = threading.Lock()
+ self._data: dict[str, Any] = (
+ json.loads(self.path.read_text("utf-8")) if self.path.exists() else {}
+ )
+
+ @staticmethod
+ def _hash(*parts: str) -> str:
+ return hashlib.sha256("||".join(parts).encode()).hexdigest()
+
+ def get(self, key: str) -> Any | None:
+ with self._lock:
+ return self._data.get(key)
+
+ def set(self, key: str, value: Any) -> None:
+ with self._lock:
+ self._data[key] = value
+ self.path.write_text(json.dumps(self._data, ensure_ascii=False), encoding="utf-8")
diff --git a/lettucedetect/detectors/factory.py b/lettucedetect/detectors/factory.py
new file mode 100644
index 0000000..430e616
--- /dev/null
+++ b/lettucedetect/detectors/factory.py
@@ -0,0 +1,27 @@
+"""Factory function for creating detector instances."""
+
+from __future__ import annotations
+
+from lettucedetect.detectors.base import BaseDetector
+
+__all__ = ["make_detector"]
+
+
+def make_detector(method: str, **kwargs) -> BaseDetector:
+ """Create a detector of the requested type with the given parameters.
+
+ :param method: One of "transformer" or "llm".
+ :param kwargs: Passed to the concrete detector constructor.
+ :return: A concrete detector instance.
+ :raises ValueError: If method is not one of "transformer" or "llm".
+ """
+ if method == "transformer":
+ from lettucedetect.detectors.transformer import TransformerDetector
+
+ return TransformerDetector(**kwargs)
+ elif method == "llm":
+ from lettucedetect.detectors.llm import LLMDetector
+
+ return LLMDetector(**kwargs)
+ else:
+ raise ValueError(f"Unknown detector method: {method}. Use one of: transformer, llm")
diff --git a/lettucedetect/detectors/llm.py b/lettucedetect/detectors/llm.py
new file mode 100644
index 0000000..c24208a
--- /dev/null
+++ b/lettucedetect/detectors/llm.py
@@ -0,0 +1,240 @@
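+"""LLM-powered hallucination detector using OpenAI function calling and a prompt template."""
+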
+from __future__ import annotations
+
+import json
+import os
+import re
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+from string import Template
+
+from openai import OpenAI
+
+from lettucedetect.detectors.cache import CacheManager
+from lettucedetect.detectors.prompt_utils import LANG_TO_PASSAGE, Lang, PromptUtils
+
+ANNOTATE_SCHEMA = [
+ {
+ "type": "function",
+ "function": {
+ "name": "annotate",
+ "description": "Return hallucinated substrings from the answer relative to the source.",
+ "parameters": {
+ "type": "object",
+ "properties": {
+ "hallucination_list": {
+ "type": "array",
+ "items": {"type": "string"},
+ }
+ },
+ "required": ["hallucination_list"],
+ },
+ },
+ }
+]
+
+
+class LLMDetector:
+ """LLM-powered hallucination detector."""
+
+ def __init__(
+ self,
+ model: str = "gpt-4.1-mini",
+ temperature: float = 0.0,
+ lang: Lang = "en",
+ zero_shot: bool = False,
+ fewshot_path: str | None = None,
+ prompt_path: str | None = None,
+ cache_file: str | None = None,
+ ):
+ """Initialize the LLMDetector.
+
+ :param model: The model to use for hallucination detection.
+ :param temperature: The temperature to use for hallucination detection.
+ :param lang: The language to use for hallucination detection.
+ :param zero_shot: Whether to use zero-shot hallucination detection.
+ :param fewshot_path: The path to the few-shot examples.
+ :param prompt_path: The path to the prompt.
+ :param cache_file: The path to the cache file.
+ """
+ if lang not in LANG_TO_PASSAGE:
+ raise ValueError(f"Invalid language. Use one of: {', '.join(LANG_TO_PASSAGE.keys())}")
+
+ self.model = model
+ self.temperature = temperature
+ self.lang = lang
+ self.zero_shot = zero_shot
+
+ # Load few-shot examples
+ if fewshot_path is None:
+ fewshot_path = (
+ Path(__file__).parent.parent / "prompts" / f"examples_{lang.lower()}.json"
+ )
+ path = Path(fewshot_path)
+ if not path.exists():
+ print(f"Warning: Few-shot examples file not found at {path}")
+ self.fewshot = json.loads(path.read_text(encoding="utf-8")) if path.exists() else []
+
+ # Load hallucination detection template
+ if prompt_path is None:
+ prompt_path = Path(__file__).parent.parent / "prompts" / "hallucination_detection.txt"
+ template_path = Path(prompt_path)
+ if not template_path.exists():
+ raise FileNotFoundError(f"Prompt template not found at {template_path}")
+ self.template = Template(template_path.read_text(encoding="utf-8"))
+
+ # Set up cache
+ if cache_file is None:
+ cache_file = (
+ Path(__file__).parent.parent
+ / "cache"
+ / f"cache_{model.replace(':', '_')}_{lang}.json"
+ )
+ print(f"Using default cache file: {cache_file}")
+ else:
+ print(f"Using provided cache file: {cache_file}")
+
+ self.cache = CacheManager(cache_file)
+
+ def _openai(self) -> OpenAI:
+ return OpenAI(
+ api_key=os.getenv("OPENAI_API_KEY") or "EMPTY",
+ base_url=os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1",
+ )
+
+ def _fewshot_block(self) -> str:
+ if self.zero_shot or not self.fewshot:
+ return ""
+ lines = []
+ for i, ex in enumerate(self.fewshot, 1):
+ lines.append(
+ f"""
+{ex["source"]}
+{ex["answer"]}
+{{"hallucination_list": {json.dumps(ex["hallucination_list"], ensure_ascii=False)} }}
+"""
+ )
+ return "\n".join(lines)
+
+ def _build_prompt(self, context: str, answer: str) -> str:
+ """Fill the template with runtime values, inserting few-shot examples.
+
+ :param context: The context string.
+ :param answer: The answer string.
+ :return: The filled template.
+ """
+ language_name = PromptUtils.get_full_language_name(self.lang)
+
+ return self.template.substitute(
+ lang=language_name,
+ context=context,
+ answer=answer,
+ fewshot_block=self._fewshot_block(),
+ )
+
+ @staticmethod
+ def _to_spans(substrs: list[str], answer: str) -> list[dict]:
+ """Convert a list of substrings to a list of spans.
+
+ :param substrs: List of substrings.
+ :param answer: The answer string.
+ :returns: List of spans.
+ """
+ spans = []
+ for sub in substrs:
+ if not sub:
+ continue
+ # Use regex for more reliable matching
+ match = re.search(re.escape(sub), answer)
+ if match:
+ spans.append({"start": match.start(), "end": match.end(), "text": sub})
+ return spans
+
+ def _predict(self, prompt: str, answer: str) -> list[dict]:
+ """Single (prompt, answer) pair → hallucination spans.
+
+ :param prompt: The prompt string.
+ :param answer: The answer string.
+ :returns: List of spans.
+ """
+ # Build the full LLM prompt using the template
+ llm_prompt = self._build_prompt(prompt, answer)
+
+ # Use the full LLM prompt for cache key calculation
+ cache_key = self.cache._hash(llm_prompt, self.model, str(self.temperature))
+
+ cached = self.cache.get(cache_key)
+ if cached is None:
+ resp = self._openai().chat.completions.create(
+ model=self.model,
+ messages=[
+ {
+ "role": "system",
+ "content": "You are an expert in detecting hallucinations in LLM outputs.",
+ },
+ # Use the full LLM prompt here, not the raw context
+ {"role": "user", "content": llm_prompt},
+ ],
+ tools=ANNOTATE_SCHEMA,
+ tool_choice={"type": "function", "function": {"name": "annotate"}},
+ temperature=self.temperature,
+ )
+ cached = resp.choices[0].message.tool_calls[0].function.arguments
+ self.cache.set(cache_key, cached)
+
+ try:
+ payload = json.loads(cached)
+ return self._to_spans(payload["hallucination_list"], answer)
+ except (json.JSONDecodeError, KeyError) as e:
+ print(f"Error parsing LLM response: {e}")
+ print(f"Raw response: {cached}")
+ return []
+
+ def predict(
+ self,
+ context: list[str],
+ answer: str,
+ question: str | None = None,
+ output_format: str = "spans",
+ ) -> list:
+ """Predict hallucination spans from the provided context, answer, and question.
+
+ :param context: List of passages that were supplied to the LLM / user.
+ :param answer: Model‑generated answer to inspect.
+ :param question: Original question (``None`` for summarisation).
+ :param output_format: ``"spans"`` for character spans.
+ :returns: List of spans.
+ """
+ if output_format != "spans":
+ raise ValueError("LLMDetector only supports 'spans' output_format.")
+ # Use PromptUtils to format the context and question
+ full_prompt = PromptUtils.format_context(context, question, self.lang)
+ return self._predict(full_prompt, answer)
+
+ def predict_prompt(self, prompt: str, answer: str, output_format: str = "spans") -> list:
+ """Predict hallucination spans from the provided prompt and answer.
+
+ :param prompt: The prompt string.
+ :param answer: The answer string.
+ :param output_format: ``"spans"`` for character spans.
+ :returns: List of spans.
+ """
+ if output_format != "spans":
+ raise ValueError("LLMDetector only supports 'spans' output_format.")
+ return self._predict(prompt, answer)
+
+ def predict_prompt_batch(
+ self, prompts: list[str], answers: list[str], output_format: str = "spans"
+ ) -> list:
+ """Predict hallucination spans from the provided prompts and answers.
+
+ :param prompts: List of prompt strings.
+ :param answers: List of answer strings.
+ :param output_format: ``"spans"`` for character spans.
+ :returns: List of spans.
+ """
+ if output_format != "spans":
+ raise ValueError("LLMDetector only supports 'spans' output_format.")
+
+ with ThreadPoolExecutor(max_workers=30) as pool:
+ futs = [pool.submit(self._predict, p, a) for p, a in zip(prompts, answers)]
+ return [f.result() for f in futs]
diff --git a/lettucedetect/detectors/prompt_utils.py b/lettucedetect/detectors/prompt_utils.py
new file mode 100644
index 0000000..cbda559
--- /dev/null
+++ b/lettucedetect/detectors/prompt_utils.py
@@ -0,0 +1,77 @@
+"""Utilities for loading and formatting prompts."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from string import Template
+
+# Type for supported languages
+Lang = str # "en", "de", "fr", "es", "it", "pl", "cn"
+
+LANG_TO_PASSAGE = {
+ "en": "passage",
+ "de": "Passage",
+ "fr": "passage",
+ "es": "pasaje",
+ "it": "brano",
+ "pl": "fragment",
+ "cn": "段落",
+}
+
+# Full language names for each language code
+LANG_TO_FULL_NAME = {
+ "en": "English",
+ "de": "German",
+ "fr": "French",
+ "es": "Spanish",
+ "it": "Italian",
+ "pl": "Polish",
+ "cn": "Chinese",
+}
+
+PROMPT_DIR = Path(__file__).parent.parent / "prompts"
+
+
+class PromptUtils:
+ """Utility class for loading and formatting prompts."""
+
+ @staticmethod
+ def load_prompt(filename: str) -> Template:
+ """Load a prompt template from the prompts directory.
+
+ :param filename: Name of the prompt file
+ :return: Template object for the prompt
+ :raises FileNotFoundError: If the prompt file doesn't exist
+ """
+ path = PROMPT_DIR / filename
+ if not path.exists():
+ raise FileNotFoundError(f"Prompt file not found: {path}")
+ return Template(path.read_text(encoding="utf-8"))
+
+ @staticmethod
+ def format_context(context: list[str], question: str | None, lang: Lang) -> str:
+ """Format context and question into a prompt.
+
+ :param context: List of passages
+ :param question: Question (None for summarization tasks)
+ :param lang: Language code
+ :return: Formatted prompt
+ """
+ p_word = LANG_TO_PASSAGE[lang]
+ ctx_block = "\n".join(f"{p_word} {i + 1}: {p}" for i, p in enumerate(context))
+
+ if question is None:
+ tmpl = PromptUtils.load_prompt(f"summary_prompt_{lang.lower()}.txt")
+ return tmpl.substitute(text=ctx_block)
+
+ tmpl = PromptUtils.load_prompt(f"qa_prompt_{lang.lower()}.txt")
+ return tmpl.substitute(question=question, num_passages=len(context), context=ctx_block)
+
+ @staticmethod
+ def get_full_language_name(lang: Lang) -> str:
+ """Get the full language name for a language code.
+
+ :param lang: Language code
+ :return: Full language name
+ """
+ return LANG_TO_FULL_NAME.get(lang, "Unknown")
diff --git a/lettucedetect/detectors/transformer.py b/lettucedetect/detectors/transformer.py
new file mode 100644
index 0000000..e289744
--- /dev/null
+++ b/lettucedetect/detectors/transformer.py
@@ -0,0 +1,173 @@
+"""Transformer‑based hallucination detector."""
+
+from __future__ import annotations
+
+import torch
+from transformers import AutoModelForTokenClassification, AutoTokenizer
+
+from lettucedetect.datasets.hallucination_dataset import HallucinationDataset
+from lettucedetect.detectors.base import BaseDetector
+from lettucedetect.detectors.prompt_utils import LANG_TO_PASSAGE, Lang, PromptUtils
+
+__all__ = ["TransformerDetector"]
+
+
+class TransformerDetector(BaseDetector):
+ """Detect hallucinations with a fine‑tuned token classifier."""
+
+ def __init__(
+ self, model_path: str, max_length: int = 4096, device=None, lang: Lang = "en", **tok_kwargs
+ ):
+ """Initialize the transformer detector.
+
+ :param model_path: Path to the pre-trained model.
+ :param max_length: Maximum length of the input sequence.
+ :param device: Device to use for inference.
+ :param lang: Language of the model.
+ :param tok_kwargs: Additional keyword arguments for the tokenizer.
+ """
+ if lang not in LANG_TO_PASSAGE:
+ raise ValueError(f"Invalid language. Choose from {', '.join(LANG_TO_PASSAGE)}")
+ self.lang, self.max_length = lang, max_length
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, **tok_kwargs)
+ self.model = AutoModelForTokenClassification.from_pretrained(model_path, **tok_kwargs)
+ self.device = device or (
+ torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ )
+ self.model.to(self.device).eval()
+
+ def _predict(self, prompt: str, answer: str, output_format: str) -> list:
+ """Predict hallucination tokens or spans from the provided prompt and answer.
+
+ :param prompt: The prompt string.
+ :param answer: The answer string.
+ :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
+ """
+ # Use the shared tokenization logic from HallucinationDataset
+ encoding, _, offsets, answer_start_token = HallucinationDataset.prepare_tokenized_input(
+ self.tokenizer, prompt, answer, self.max_length
+ )
+
+ # Create a label tensor: mark tokens before answer as -100 (ignored) and answer tokens as 0.
+ labels = torch.full_like(encoding.input_ids[0], -100, device=self.device)
+ labels[answer_start_token:] = 0
+ # Move encoding to the device
+ encoding = {
+ key: value.to(self.device)
+ for key, value in encoding.items()
+ if key in ["input_ids", "attention_mask", "labels"]
+ }
+
+ # Run model inference
+ with torch.no_grad():
+ outputs = self.model(**encoding)
+ logits = outputs.logits
+ token_preds = torch.argmax(logits, dim=-1)[0]
+ probabilities = torch.softmax(logits, dim=-1)[0]
+
+ # Mask out predictions for context tokens.
+ token_preds = torch.where(labels == -100, labels, token_preds)
+
+ if output_format == "tokens":
+ # return token probabilities for each token (with the tokens as well, if not -100)
+ token_probs = []
+ input_ids = encoding["input_ids"][0] # Get the input_ids tensor from the encoding dict
+ for i, (token, pred, prob) in enumerate(zip(input_ids, token_preds, probabilities)):
+ if not labels[i].item() == -100:
+ token_probs.append(
+ {
+ "token": self.tokenizer.decode([token]),
+ "pred": pred.item(),
+ "prob": prob[1].item(), # Get probability for class 1 (hallucination)
+ }
+ )
+ return token_probs
+ elif output_format == "spans":
+ # Compute the answer's character offset (the first token of the answer).
+ if answer_start_token < offsets.size(0):
+ answer_char_offset = offsets[answer_start_token][0].item()
+ else:
+ answer_char_offset = 0
+
+ spans: list[dict] = []
+ current_span: dict | None = None
+
+ # Iterate over tokens in the answer region.
+ for i in range(answer_start_token, token_preds.size(0)):
+ # Skip tokens marked as ignored.
+ if labels[i].item() == -100:
+ continue
+
+ token_start, token_end = offsets[i].tolist()
+ # Skip special tokens with zero length.
+ if token_start == token_end:
+ continue
+
+ # Adjust offsets relative to the answer text.
+ rel_start = token_start - answer_char_offset
+ rel_end = token_end - answer_char_offset
+
+ is_hallucination = (
+ token_preds[i].item() == 1
+ ) # assuming class 1 indicates hallucination.
+ confidence = probabilities[i, 1].item() if is_hallucination else 0.0
+
+ if is_hallucination:
+ if current_span is None:
+ current_span = {
+ "start": rel_start,
+ "end": rel_end,
+ "confidence": confidence,
+ }
+ else:
+ # Extend the current span.
+ current_span["end"] = rel_end
+ current_span["confidence"] = max(current_span["confidence"], confidence)
+ else:
+ # If we were building a hallucination span, finalize it.
+ if current_span is not None:
+ # Extract the hallucinated text from the answer.
+ span_text = answer[current_span["start"] : current_span["end"]]
+ current_span["text"] = span_text
+ spans.append(current_span)
+ current_span = None
+
+ # Append any span still in progress.
+ if current_span is not None:
+ span_text = answer[current_span["start"] : current_span["end"]]
+ current_span["text"] = span_text
+ spans.append(current_span)
+
+ return spans
+ else:
+ raise ValueError("Invalid output_format. Use 'tokens' or 'spans'.")
+
+ def predict(self, context, answer, question=None, output_format="tokens") -> list:
+ """Predict hallucination tokens or spans from the provided context, answer, and question.
+
+ :param context: List of passages that were supplied to the LLM / user.
+ :param answer: Model‑generated answer to inspect.
+ :param question: Original question (``None`` for summarisation).
+ :param output_format: ``"tokens"`` for token‑level dicts, ``"spans"`` for character spans.
+ :returns: List of predictions in requested format.
+ """
+ formatted_prompt = PromptUtils.format_context(context, question, self.lang)
+ return self._predict(formatted_prompt, answer, output_format)
+
+ def predict_prompt(self, prompt, answer, output_format="tokens") -> list:
+ """Predict hallucination tokens or spans from the provided prompt and answer.
+
+ :param prompt: The prompt string.
+ :param answer: The answer string.
+ :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
+ """
+ return self._predict(prompt, answer, output_format)
+
+ def predict_prompt_batch(self, prompts, answers, output_format="tokens") -> list:
+ """Predict hallucination tokens or spans from the provided prompts and answers.
+
+ :param prompts: List of prompt strings.
+ :param answers: List of answer strings.
+ :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
+ """
+ return [self._predict(p, a, output_format) for p, a in zip(prompts, answers)]
diff --git a/lettucedetect/models/evaluator.py b/lettucedetect/models/evaluator.py
index 2b6b451..52e9c2c 100644
--- a/lettucedetect/models/evaluator.py
+++ b/lettucedetect/models/evaluator.py
@@ -276,6 +276,74 @@ def evaluate_detector_char_level(
return {"precision": precision, "recall": recall, "f1": f1}
+def evaluate_detector_example_level_batch(
+ detector: HallucinationDetector,
+ samples: list[HallucinationSample],
+ batch_size: int = 10,
+ verbose: bool = True,
+) -> dict[str, dict[str, float]]:
+ """Evaluate the HallucinationDetector at the example level.
+
+ This function assumes that each sample is a dictionary containing:
+ - "prompt": the prompt text.
+ - "answer": the answer text.
+ - "gold_spans": a list of dictionaries where each dictionary has "start" and "end" keys
+ indicating the character indices of the gold (human-labeled) span.
+
+ """
+ example_preds: list[int] = []
+ example_labels: list[int] = []
+
+ for i in tqdm(range(0, len(samples), batch_size), desc="Evaluating", leave=False):
+ batch = samples[i : i + batch_size]
+ prompts = [sample.prompt for sample in batch]
+ answers = [sample.answer for sample in batch]
+ predicted_spans = detector.predict_prompt_batch(prompts, answers, output_format="spans")
+
+ for sample, pred_spans in zip(batch, predicted_spans):
+ true_example_label = 1 if sample.labels else 0
+ pred_example_label = 1 if pred_spans else 0
+
+ example_labels.append(true_example_label)
+ example_preds.append(pred_example_label)
+
+ precision, recall, f1, _ = precision_recall_fscore_support(
+ example_labels, example_preds, labels=[0, 1], average=None, zero_division=0
+ )
+
+ results: dict[str, dict[str, float]] = {
+ "supported": { # Class 0
+ "precision": float(precision[0]),
+ "recall": float(recall[0]),
+ "f1": float(f1[0]),
+ },
+ "hallucinated": { # Class 1
+ "precision": float(precision[1]),
+ "recall": float(recall[1]),
+ "f1": float(f1[1]),
+ },
+ }
+
+ # Calculating AUROC
+ fpr, tpr, _ = roc_curve(example_labels, example_preds)
+ auroc = auc(fpr, tpr)
+ results["auroc"] = auroc
+
+ if verbose:
+ report = classification_report(
+ example_labels,
+ example_preds,
+ target_names=["Supported", "Hallucinated"],
+ digits=4,
+ zero_division=0,
+ )
+ print("\nDetailed Example-Level Classification Report:")
+ print(report)
+ results["classification_report"] = report
+
+ return results
+
+
def evaluate_detector_example_level(
detector: HallucinationDetector,
samples: list[HallucinationSample],
diff --git a/lettucedetect/models/inference.py b/lettucedetect/models/inference.py
index f638954..035ed3c 100644
--- a/lettucedetect/models/inference.py
+++ b/lettucedetect/models/inference.py
@@ -1,214 +1,23 @@
-import hashlib
-import json
-import os
-import re
-from abc import ABC, abstractmethod
-from pathlib import Path
-from string import Template
-from typing import Literal
+# lettucedetect/models/inference.py
+"""Public façade for LettuceDetect.
-import torch
-from openai import OpenAI
-from transformers import AutoModelForTokenClassification, AutoTokenizer
+Down-stream code should keep importing **HallucinationDetector** from here;
+the concrete detector classes now live in :pymod:`lettucedetect.detectors.*`.
+Nothing in the public API has changed.
+"""
-from lettucedetect.datasets.hallucination_dataset import (
- HallucinationDataset,
-)
+from lettucedetect.detectors.factory import make_detector
-LANG_TO_PASSAGE = {
- "en": "passage",
- "de": "Passage",
- "fr": "passage",
- "es": "pasaje",
- "it": "brano",
- "pl": "fragment",
-}
+class HallucinationDetector:
+ """Facade class that delegates to a concrete detector chosen by *method*.
-# ==== Base class for all detectors ====
-class BaseDetector(ABC):
- @abstractmethod
- def predict(self, context: str, answer: str, output_format: str = "tokens") -> list:
- """Given a context and an answer, returns predictions.
-
- :param context: The context string.
- :param answer: The answer string.
- :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
- """
- pass
-
-
-# ==== Transformer-based detector ====
-class TransformerDetector(BaseDetector):
- def __init__(
- self,
- model_path: str,
- max_length: int = 4096,
- device=None,
- lang: Literal["en", "de", "fr", "es", "it", "pl"] = "en",
- **kwargs,
- ):
- """Initialize the TransformerDetector.
-
- :param model_path: The path to the model.
- :param max_length: The maximum length of the input sequence.
- :param device: The device to run the model on.
- :param lang: The language of the model.
- """
- if lang not in LANG_TO_PASSAGE:
- raise ValueError(f"Invalid language. Use one of: {', '.join(LANG_TO_PASSAGE.keys())}")
-
- self.lang = lang
- self.tokenizer = AutoTokenizer.from_pretrained(model_path, **kwargs)
- self.model = AutoModelForTokenClassification.from_pretrained(model_path, **kwargs)
- self.max_length = max_length
- self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
- self.model.to(self.device)
- self.model.eval()
-
- prompt_path = Path(__file__).parent.parent / "prompts" / f"qa_prompt_{lang.lower()}.txt"
- self.prompt_qa = Template(prompt_path.read_text(encoding="utf-8"))
- prompt_path = (
- Path(__file__).parent.parent / "prompts" / f"summary_prompt_{lang.lower()}.txt"
- )
- self.prompt_summary = Template(prompt_path.read_text(encoding="utf-8"))
-
- def _form_prompt(self, context: list[str], question: str | None) -> str:
- """Form a prompt from the provided context and question. We use different prompts for summary and QA tasks.
-
- :param context: A list of context strings.
- :param question: The question string.
- :return: The formatted prompt.
- """
- context_str = "\n".join(
- [
- f"{LANG_TO_PASSAGE[self.lang]} {i + 1}: {passage}"
- for i, passage in enumerate(context)
- ]
- )
- if question is None:
- return self.prompt_summary.substitute(text=context_str)
- else:
- return self.prompt_qa.substitute(
- question=question, num_passages=len(context), context=context_str
- )
-
- def _predict(self, context: str, answer: str, output_format: str = "tokens") -> list:
- """Predict hallucination tokens or spans from the provided context and answer.
-
- :param context: The context string.
- :param answer: The answer string.
- :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
- """
- # Use the shared tokenization logic from RagTruthDataset
- encoding, labels, offsets, answer_start_token = (
- HallucinationDataset.prepare_tokenized_input(
- self.tokenizer, context, answer, self.max_length
- )
- )
-
- # Create a label tensor: mark tokens before answer as -100 (ignored) and answer tokens as 0.
- labels = torch.full_like(encoding.input_ids[0], -100, device=self.device)
- labels[answer_start_token:] = 0
- # Move encoding to the device
- encoding = {
- key: value.to(self.device)
- for key, value in encoding.items()
- if key in ["input_ids", "attention_mask", "labels"]
- }
-
- # Run model inference
- with torch.no_grad():
- outputs = self.model(**encoding)
- logits = outputs.logits
- token_preds = torch.argmax(logits, dim=-1)[0]
- probabilities = torch.softmax(logits, dim=-1)[0]
-
- # Mask out predictions for context tokens.
- token_preds = torch.where(labels == -100, labels, token_preds)
-
- if output_format == "tokens":
- # return token probabilities for each token (with the tokens as well, if not -100)
- token_probs = []
- input_ids = encoding["input_ids"][0] # Get the input_ids tensor from the encoding dict
- for i, (token, pred, prob) in enumerate(zip(input_ids, token_preds, probabilities)):
- if not labels[i].item() == -100:
- token_probs.append(
- {
- "token": self.tokenizer.decode([token]),
- "pred": pred.item(),
- "prob": prob[1].item(), # Get probability for class 1 (hallucination)
- }
- )
- return token_probs
- elif output_format == "spans":
- # Compute the answer's character offset (the first token of the answer).
- if answer_start_token < offsets.size(0):
- answer_char_offset = offsets[answer_start_token][0].item()
- else:
- answer_char_offset = 0
-
- spans: list[dict] = []
- current_span: dict | None = None
-
- # Iterate over tokens in the answer region.
- for i in range(answer_start_token, token_preds.size(0)):
- # Skip tokens marked as ignored.
- if labels[i].item() == -100:
- continue
-
- token_start, token_end = offsets[i].tolist()
- # Skip special tokens with zero length.
- if token_start == token_end:
- continue
-
- # Adjust offsets relative to the answer text.
- rel_start = token_start - answer_char_offset
- rel_end = token_end - answer_char_offset
-
- is_hallucination = (
- token_preds[i].item() == 1
- ) # assuming class 1 indicates hallucination.
- confidence = probabilities[i, 1].item() if is_hallucination else 0.0
-
- if is_hallucination:
- if current_span is None:
- current_span = {
- "start": rel_start,
- "end": rel_end,
- "confidence": confidence,
- }
- else:
- # Extend the current span.
- current_span["end"] = rel_end
- current_span["confidence"] = max(current_span["confidence"], confidence)
- else:
- # If we were building a hallucination span, finalize it.
- if current_span is not None:
- # Extract the hallucinated text from the answer.
- span_text = answer[current_span["start"] : current_span["end"]]
- current_span["text"] = span_text
- spans.append(current_span)
- current_span = None
-
- # Append any span still in progress.
- if current_span is not None:
- span_text = answer[current_span["start"] : current_span["end"]]
- current_span["text"] = span_text
- spans.append(current_span)
-
- return spans
- else:
- raise ValueError("Invalid output_format. Use 'tokens' or 'spans'.")
-
- def predict_prompt(self, prompt: str, answer: str, output_format: str = "tokens") -> list:
- """Predict hallucination tokens or spans from the provided prompt and answer.
+ :param method: ``"transformer"`` (token-classifier) or ``"llm"`` (OpenAI function-calling).
+ :param kwargs: Passed straight through to the chosen detector’s constructor.
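+
+    Example (a minimal sketch using the English base model shown in the README):
+
+    >>> detector = HallucinationDetector(
+    ...     method="transformer",
+    ...     model_path="KRLabsOrg/lettucedect-base-modernbert-en-v1",
+    ... )
+    >>> spans = detector.predict(
+    ...     context=["The capital of France is Paris."],
+    ...     question="What is the capital of France?",
+    ...     answer="The capital of France is Paris.",
+    ...     output_format="spans",
+    ... )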
+ """
- :param prompt: The prompt string.
- :param answer: The answer string.
- :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
- """
- return self._predict(prompt, answer, output_format)
+ def __init__(self, method: str = "transformer", **kwargs):
+ self.detector = make_detector(method, **kwargs)
def predict(
self,
@@ -217,315 +26,31 @@ def predict(
question: str | None = None,
output_format: str = "tokens",
) -> list:
- """Predict hallucination tokens or spans from the provided context, answer, and question.
- This is a useful interface when we don't want to predict a specific prompt, but rather we have a list of contexts, answers, and questions. Useful to interface with RAG systems.
-
- :param context: A list of context strings.
- :param answer: The answer string.
- :param question: The question string.
- :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
- """
- prompt = self._form_prompt(context, question)
- return self._predict(prompt, answer, output_format)
-
-
-# ==== LLM-based detector ====
-ANNOTATE_SCHEMA = [
- {
- "type": "function",
- "function": {
- "name": "annotate",
- "description": "Return hallucinated substrings from the answer relative to the source.",
- "parameters": {
- "type": "object",
- "properties": {
- "hallucination_list": {
- "type": "array",
- "items": {"type": "string"},
- }
- },
- "required": ["hallucination_list"],
- },
- },
- }
-]
-
+ """Predict hallucination tokens or spans given passages and an answer.
-class LLMDetector(BaseDetector):
- """LLM-powered hallucination detector using function calling and a prompt template."""
+ This is the call most RAG pipelines use.
- def __init__(
- self,
- model: str = "gpt-4o",
- temperature: int = 0,
- lang: Literal["en", "de", "fr", "es", "it", "pl"] = "en",
- zero_shot: bool = False,
- fewshot_path: str | None = None,
- prompt_path: str | None = None,
- cache_file: str | None = None,
- ):
- """Initialize the LLMDetector.
-
- :param model: OpenAI model.
- :param temperature: model temperature.
- :param lang: language of the examples.
- :param zero_shot: whether to use zero-shot prompting.
- :param fewshot_path: path to the fewshot examples.
- :param prompt_path: path to the prompt template.
- :param cache_file: path to the cache file.
+ See the concrete detector docs for the structure of the returned list.
"""
- self.model = model
- self.temperature = temperature
-
- if lang not in LANG_TO_PASSAGE:
- raise ValueError(f"Invalid language. Use one of: {', '.join(LANG_TO_PASSAGE.keys())}")
-
- self.lang = lang
- self.zero_shot = zero_shot
- if fewshot_path is None:
- print(
- f"No fewshot path provided, using default path: {Path(__file__).parent.parent / 'prompts' / f'examples_{lang.lower()}.json'}"
- )
- fewshot_path = (
- Path(__file__).parent.parent / "prompts" / f"examples_{lang.lower()}.json"
- )
-
- if not fewshot_path.exists():
- raise FileNotFoundError(f"Fewshot file not found at {fewshot_path}")
- else:
- fewshot_path = Path(fewshot_path)
-
- if prompt_path is None:
- print(
- f"No prompt path provided, using default path: {Path(__file__).parent.parent / 'prompts' / 'hallucination_detection.txt'}"
- )
- template_path = Path(__file__).parent.parent / "prompts" / "hallucination_detection.txt"
- else:
- template_path = Path(prompt_path)
-
- prompt_qa_path = Path(__file__).parent.parent / "prompts" / f"qa_prompt_{lang.lower()}.txt"
- prompt_summary_path = (
- Path(__file__).parent.parent / "prompts" / f"summary_prompt_{lang.lower()}.txt"
- )
-
- self.template = Template(template_path.read_text(encoding="utf-8"))
- self.prompt_qa = Template(prompt_qa_path.read_text(encoding="utf-8"))
- self.prompt_summary = Template(prompt_summary_path.read_text(encoding="utf-8"))
-
- self.fewshot = json.loads(fewshot_path.read_text(encoding="utf-8"))
- self.cache_path = cache_file
-
- if cache_file is None:
- self.cache_path = (
- Path(__file__).parent.parent / "cache" / f"cache_{self.model}_{self.lang}.json"
- )
- self.cache_path.parent.mkdir(parents=True, exist_ok=True)
-
- # Read in cache
- if self.cache_path.exists():
- self.cache = json.loads(self.cache_path.read_text(encoding="utf-8"))
- else:
- self.cache = {}
-
- print(f"Cache file not provided, using default path: {self.cache_path}")
- else:
- self.cache_path = Path(cache_file)
- if not self.cache_path.exists():
- raise FileNotFoundError(f"Cache file not found at {self.cache_path}")
- self.cache = json.loads(self.cache_path.read_text(encoding="utf-8"))
-
- def _build_prompt(
- self,
- context: str,
- answer: str,
- ) -> str:
- """Fill the template with runtime values, inserting zero or many few‑shot examples.
- Uses `${placeholder}` tokens in the .txt file.
- """
- fewshot_block = ""
- if self.fewshot and not self.zero_shot:
- lines: list[str] = []
- for idx, ex in enumerate(self.fewshot, 1):
- lines.append(
- f"""
-{ex["source"]}
-{ex["answer"]}
-{{"hallucination_list": {json.dumps(ex["hallucination_list"], ensure_ascii=False)} }}
-"""
- )
- fewshot_block = "\n".join(lines)
-
- filled = self.template.substitute(
- lang=self.lang,
- context=context,
- answer=answer,
- fewshot_block=fewshot_block,
- )
- return filled
-
- def _form_context(self, context: list[str], question: str | None) -> str:
- """Form a prompt from the provided context and question. We use different prompts for summary and QA tasks.
- :param context: A list of context strings.
- :param question: The question string.
- :return: The formatted prompt.
- """
- context_str = "\n".join(
- [f"passage {i + 1}: {passage}" for i, passage in enumerate(context)]
- )
- if question is None:
- return self.prompt_summary.substitute(text=context_str)
- else:
- return self.prompt_qa.substitute(
- question=question, num_passages=len(context), context=context_str
- )
-
- def _get_openai_client(self) -> OpenAI:
- """Get OpenAI client configured from environment variables.
-
- :return: Configured OpenAI client
- :raises ValueError: If API key is not set
- """
- api_key = os.getenv("OPENAI_API_KEY") or "EMPTY"
- api_base = os.getenv("OPENAI_API_BASE") or "https://api.openai.com/v1"
-
- return OpenAI(
- api_key=api_key,
- base_url=api_base,
- )
-
- def _hash(self, prompt: str) -> str:
- """Hash the prompt."""
- return hashlib.sha256(prompt.encode("utf-8")).hexdigest()
-
- def _call_openai(self, prompt: str) -> str:
- """Call the OpenAI API.
-
- :param prompt: The prompt to call the OpenAI API with.
- :return: The response from the OpenAI API.
- """
- client = self._get_openai_client()
- resp = client.chat.completions.create(
- model=self.model,
- messages=[
- {
- "role": "system",
- "content": "You are an expert in detecting hallucinations in LLM outputs.",
- },
- {"role": "user", "content": prompt},
- ],
- tools=ANNOTATE_SCHEMA,
- tool_choice={"type": "function", "function": {"name": "annotate"}},
- temperature=self.temperature,
- )
-
- return resp.choices[0].message.tool_calls[0].function.arguments
-
- def _save_cache(self):
- """Save the cache to the cache file."""
- self.cache_path.write_text(json.dumps(self.cache, ensure_ascii=False), encoding="utf-8")
-
- def _to_spans(self, subs: list[str], answer: str) -> list[dict]:
- """Convert a list of substrings to a list of spans.
-
- :param subs: A list of substrings.
- :param answer: The answer string.
- :return: A list of spans.
- """
- spans = []
- for s in subs:
- m = re.search(re.escape(s), answer)
- if m:
- spans.append({"start": m.start(), "end": m.end(), "text": s})
- return spans
-
- def _predict(self, context: str, answer: str, output_format: str = "spans") -> list:
- """Prompts the ChatGPT model to predict hallucination spans from the provided context and answer.
-
- :param context: The context string.
- :param answer: The answer string.
- :param output_format: works only for "spans" and returns grouped spans.
- """
- if output_format == "spans":
- llm_prompt = self._build_prompt(context, answer)
-
- key = self._hash("||".join([llm_prompt, self.model, str(self.temperature)]))
-
- # Check if the response is cached
- cached_response = self.cache.get(key)
- if cached_response is None:
- cached_response = self._call_openai(llm_prompt)
- self.cache[key] = cached_response
- self._save_cache()
-
- payload = json.loads(cached_response)
- return self._to_spans(payload["hallucination_list"], answer)
- else:
- raise ValueError(
- "Invalid output_format. This model can only predict hallucination spans. Use spans."
- )
+ return self.detector.predict(context, answer, question, output_format)
def predict_prompt(self, prompt: str, answer: str, output_format: str = "tokens") -> list:
- """Predict hallucination spans from the provided prompt and answer.
+        """Predict hallucinations when the full prompt is already assembled as a single string.
:param prompt: The prompt string.
:param answer: The answer string.
- :param output_format: "spans" to return grouped spans.
- """
- return self._predict(prompt, answer, output_format)
-
- def predict(
- self,
- context: list[str],
- answer: str,
- question: str | None = None,
- output_format: str = "spans",
- ) -> list:
- """Predict hallucination spans from the provided context, answer, and question.
- This is a useful interface when we don't want to predict a specific prompt, but rather we have a list of contexts, answers, and questions. Useful to interface with RAG systems.
-
- :param context: A list of context strings.
- :param answer: The answer string.
- :param question: The question string.
- :param output_format: "spans" to return grouped spans.
- """
- prompt = self._form_context(context, question)
- return self._predict(prompt, answer, output_format=output_format)
-
-
-class HallucinationDetector:
- def __init__(self, method: str = "transformer", **kwargs):
- """Facade for the hallucination detector.
-
- :param method: "transformer" for the model-based approach.
- :param kwargs: Additional keyword arguments passed to the underlying detector.
+ :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
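+
+        Example (illustrative sketch; the prompt and answer texts are placeholders)::
+
+            detector.predict_prompt(
+                "Briefly answer the question: What is the capital of France? "
+                "passage 1: France is a country in Europe. The capital of France is Paris.",
+                "The capital of France is Paris.",
+                output_format="spans",
+            )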
"""
- if method == "transformer":
- self.detector = TransformerDetector(**kwargs)
- elif method == "llm":
- self.detector = LLMDetector(**kwargs)
- else:
- raise ValueError("Unsupported method. Choose 'transformer' or 'llm'.")
+ return self.detector.predict_prompt(prompt, answer, output_format)
- def predict(
- self,
- context: list[str],
- answer: str,
- question: str | None = None,
- output_format: str = "tokens",
+ def predict_prompt_batch(
+ self, prompts: list[str], answers: list[str], output_format: str = "tokens"
) -> list:
- """Predict hallucination tokens or spans from the provided context, answer, and question.
- This is a useful interface when we don't want to predict a specific prompt, but rather we have a list of contexts, answers, and questions. Useful to interface with RAG systems.
+ """Batch version of :py:meth:`predict_prompt`.
+        *prompts* and *answers* must have the same length.
- :param context: A list of context strings.
- :param answer: The answer string.
- :param question: The question string.
- """
- return self.detector.predict(context, answer, question, output_format)
-
- def predict_prompt(self, prompt: str, answer: str, output_format: str = "tokens") -> list:
- """Predict hallucination tokens or spans from the provided prompt and answer.
-
- :param prompt: The prompt string.
- :param answer: The answer string.
+ :param prompts: List of prompt strings.
+ :param answers: List of answer strings.
+ :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans.
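+
+        Example (illustrative sketch; texts are placeholders)::
+
+            detector.predict_prompt_batch(
+                prompts=[
+                    "Briefly answer the question: Who wrote Hamlet? "
+                    "passage 1: Hamlet was written by William Shakespeare.",
+                ],
+                answers=["Hamlet was written by William Shakespeare in 1601."],
+                output_format="spans",
+            )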
"""
- return self.detector.predict_prompt(prompt, answer, output_format)
+ return self.detector.predict_prompt_batch(prompts, answers, output_format)
diff --git a/lettucedetect/prompts/examples_cn.json b/lettucedetect/prompts/examples_cn.json
new file mode 100644
index 0000000..a15cd54
--- /dev/null
+++ b/lettucedetect/prompts/examples_cn.json
@@ -0,0 +1,35 @@
+[
+ {
+ "source": "法国的首都是什么?法国的人口是多少?法国是欧洲的一个国家。法国的首都是巴黎。法国的人口是6700万。",
+ "answer": "法国的首都是巴黎。法国的人口是6900万。",
+ "hallucination_list": [
+ "法国的人口是6900万。"
+ ]
+ },
+ {
+ "source": "法国的首都是什么?法国的人口是多少?法国的首都是巴黎。法国的人口是6700万。",
+ "answer": "法国的首都是巴黎。法国的人口是6700万,官方语言是西班牙语。",
+ "hallucination_list": [
+ "官方语言是西班牙语"
+ ]
+ },
+ {
+ "source": "奥地利的首都是什么?奥地利的人口是多少?奥地利是欧洲的一个国家。奥地利的首都是维也纳。奥地利的人口是910万。",
+ "answer": "奥地利的首都是维也纳。奥地利的人口是910万。",
+ "hallucination_list": []
+ },
+ {
+ "source": "谁是第一个在月球上行走的人?这是什么时候发生的?尼尔·阿姆斯特朗是第一个在月球上行走的人。这一历史性事件发生在1969年7月20日,在阿波罗11号任务期间。",
+ "answer": "尼尔·阿姆斯特朗是第一个在月球上行走的人。这一历史性事件发生在1969年7月16日,在阿波罗11号任务期间。",
+ "hallucination_list": [
+ "这一历史性事件发生在1969年7月16日"
+ ]
+ },
+ {
+ "source": "世界上最高的山是什么?珠穆朗玛峰是世界上最高的山,海拔8,848米。它位于喜马拉雅山脉的马哈朗古尔山脉中,在尼泊尔和西藏边境上。",
+ "answer": "珠穆朗玛峰是世界上最高的山,海拔8,848米。它位于喜马拉雅山脉的马哈朗古尔山脉中,在印度和中国边境上。",
+ "hallucination_list": [
+ "它位于喜马拉雅山脉的马哈朗古尔山脉中,在印度和中国边境上。"
+ ]
+ }
+]
diff --git a/lettucedetect/prompts/qa_prompt_cn.txt b/lettucedetect/prompts/qa_prompt_cn.txt
new file mode 100644
index 0000000..49ac532
--- /dev/null
+++ b/lettucedetect/prompts/qa_prompt_cn.txt
@@ -0,0 +1,6 @@
+请简短回答以下问题:
+${question}
+请注意,您的回答应严格基于以下${num_passages}段落:
+${context}
+如果段落中不包含回答问题所需的信息,请回复:"基于给定段落无法回答。"
+输出:
diff --git a/lettucedetect/prompts/qa_prompt_pl.txt b/lettucedetect/prompts/qa_prompt_pl.txt
index df7fe4c..0f8e4b5 100644
--- a/lettucedetect/prompts/qa_prompt_pl.txt
+++ b/lettucedetect/prompts/qa_prompt_pl.txt
@@ -1,6 +1,6 @@
-Odpowiedz krótko na następujące pytanie:
+Zwięźle odpowiedz na następujące pytanie:
${question}
Pamiętaj, że Twoja odpowiedź powinna opierać się wyłącznie na następujących ${num_passages} fragmentach:
${context}
Jeśli fragmenty nie zawierają informacji niezbędnych do udzielenia odpowiedzi na pytanie, odpowiedz: "Nie można odpowiedzieć na podstawie podanych fragmentów."
-Wynik:
\ No newline at end of file
+output:
\ No newline at end of file
diff --git a/lettucedetect/prompts/summary_prompt_cn.txt b/lettucedetect/prompts/summary_prompt_cn.txt
new file mode 100644
index 0000000..741fb3e
--- /dev/null
+++ b/lettucedetect/prompts/summary_prompt_cn.txt
@@ -0,0 +1,3 @@
+请总结以下文本:
+${text}
+输出:
diff --git a/pyproject.toml b/pyproject.toml
index 6daa51d..8995fb5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "lettucedetect"
-version = "0.1.6"
+version = "0.1.7"
description = "Lettucedetect is a framework for detecting hallucinations in RAG applications."
readme = {file = "README.md", content-type = "text/markdown"}
requires-python = ">=3.10"
diff --git a/scripts/evaluate_llm.py b/scripts/evaluate_llm.py
index cba492f..ba0d0f1 100644
--- a/scripts/evaluate_llm.py
+++ b/scripts/evaluate_llm.py
@@ -8,7 +8,7 @@
)
from lettucedetect.models.evaluator import (
evaluate_detector_char_level,
- evaluate_detector_example_level,
+ evaluate_detector_example_level_batch,
print_metrics,
)
from lettucedetect.models.inference import HallucinationDetector
@@ -28,7 +28,7 @@ def evaluate_task_samples_llm(
if evaluation_type == "example_level":
print("\n---- Example-Level Span Evaluation ----")
- metrics = evaluate_detector_example_level(detector, samples)
+ metrics = evaluate_detector_example_level_batch(detector, samples)
print_metrics(metrics)
return metrics
elif evaluation_type == "char_level":
diff --git a/scripts/upload.py b/scripts/upload.py
index d771d61..1e84a70 100644
--- a/scripts/upload.py
+++ b/scripts/upload.py
@@ -27,8 +27,8 @@ def main():
args = parser.parse_args()
print(f"Loading model and tokenizer from {args.model_path} ...")
- model = AutoModelForTokenClassification.from_pretrained(args.model_path)
- tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+ model = AutoModelForTokenClassification.from_pretrained(args.model_path, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
print(f"Uploading model to Hugging Face Hub at repo: {args.repo_id} ...")
model.push_to_hub(args.repo_id, use_auth_token=args.use_auth_token)
diff --git a/tests/test_inference_pytest.py b/tests/test_inference_pytest.py
index b356ce3..546852b 100644
--- a/tests/test_inference_pytest.py
+++ b/tests/test_inference_pytest.py
@@ -5,7 +5,9 @@
import pytest
import torch
-from lettucedetect.models.inference import HallucinationDetector, TransformerDetector
+from lettucedetect.detectors.prompt_utils import PromptUtils
+from lettucedetect.detectors.transformer import TransformerDetector
+from lettucedetect.models.inference import HallucinationDetector
@pytest.fixture
@@ -31,7 +33,7 @@ class TestHallucinationDetector:
def test_init_with_transformer_method(self):
"""Test initialization with transformer method."""
- with patch("lettucedetect.models.inference.TransformerDetector") as mock_transformer:
+ with patch("lettucedetect.detectors.transformer.TransformerDetector") as mock_transformer:
detector = HallucinationDetector(method="transformer", model_path="dummy_path")
mock_transformer.assert_called_once_with(model_path="dummy_path")
assert isinstance(detector.detector, MagicMock)
@@ -48,7 +50,7 @@ def test_predict(self):
mock_detector.predict.return_value = []
with patch(
- "lettucedetect.models.inference.TransformerDetector", return_value=mock_detector
+ "lettucedetect.detectors.transformer.TransformerDetector", return_value=mock_detector
):
detector = HallucinationDetector(method="transformer")
context = ["This is a test context."]
@@ -72,7 +74,7 @@ def test_predict_prompt(self):
mock_detector.predict_prompt.return_value = []
with patch(
- "lettucedetect.models.inference.TransformerDetector", return_value=mock_detector
+ "lettucedetect.detectors.transformer.TransformerDetector", return_value=mock_detector
):
detector = HallucinationDetector(method="transformer")
prompt = "This is a test prompt."
@@ -99,11 +101,11 @@ def setup(self, mock_tokenizer, mock_model):
# Patch the AutoTokenizer and AutoModelForTokenClassification
self.tokenizer_patcher = patch(
- "lettucedetect.models.inference.AutoTokenizer.from_pretrained",
+ "lettucedetect.detectors.transformer.AutoTokenizer.from_pretrained",
return_value=self.mock_tokenizer,
)
self.model_patcher = patch(
- "lettucedetect.models.inference.AutoModelForTokenClassification.from_pretrained",
+ "lettucedetect.detectors.transformer.AutoModelForTokenClassification.from_pretrained",
return_value=self.mock_model,
)
@@ -156,7 +158,7 @@ def test_form_prompt_with_question(self):
context = ["This is passage 1.", "This is passage 2."]
question = "What is the test?"
- prompt = detector._form_prompt(context, question)
+ prompt = PromptUtils.format_context(context, question, "en")
# Check that the prompt contains the question and passages
assert question in prompt
@@ -168,7 +170,7 @@ def test_form_prompt_without_question(self):
detector = TransformerDetector(model_path="dummy_path")
context = ["This is a text to summarize."]
- prompt = detector._form_prompt(context, None)
+ prompt = PromptUtils.format_context(context, None, "en")
# Check that the prompt contains the text to summarize
assert "This is a text to summarize." in prompt