diff --git a/README.md b/README.md index bda443a..4d413ab 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,11 @@ Our models are inspired from the [Luna](https://aclanthology.org/2025.coling-ind ## 🚀 Latest Updates -- **May 18, 2025** - Released version **0.1.7**: Multilingual support (thanks to EuroBERT) for 7 languages: English, German, French, Spanish, Italian, Polish, and Chinese! +- **August 31, 2025** - Released version **0.1.8**: Added TinyLettuce Ettin models for 17M, 32M, and 68M variants, Hallucination generation pipeline and added RAGFactChecker for triplet-based hallucination detection. + - See [TinyLettuce Blog Post](https://huggingface.co/KRLabsOrg/tinylettuce-68b42a66b8b6aaa4bf287bf4) for more details. + - Our collection on Hugging Face: [TinyLettuce](https://huggingface.co/collections/KRLabsOrg/tinylettuce-68b42a66b8b6aaa4bf287bf4) + - See the documentation: [TinyLettuce Documentation](docs/TINYLETTUCE.md) for more details. +- May 18, 2025 - Released version **0.1.7**: Multilingual support (thanks to EuroBERT) for 7 languages: English, German, French, Spanish, Italian, Polish, and Chinese! - Up to **17 F1 points improvement** over baseline LLM judges like GPT-4.1-mini across different languages - **EuroBERT models**: We've trained base/210M (faster) and large/610M (more accurate) variants - You can now also use **LLM baselines** for hallucination detection (see below) @@ -60,8 +64,8 @@ pip install lettucedetect -U Check out our models published to Huggingface: **English Models**: -- Base: [KRLabsOrg/lettucedetect-base-modernbert-en-v1](https://huggingface.co/KRLabsOrg/lettucedetect-base-modernbert-en-v1) -- Large: [KRLabsOrg/lettucedetect-large-modernbert-en-v1](https://huggingface.co/KRLabsOrg/lettucedetect-large-modernbert-en-v1) +- Base: [KRLabsOrg/lettucedect-base-modernbert-en-v1](https://huggingface.co/KRLabsOrg/lettucedect-base-modernbert-en-v1) +- Large: [KRLabsOrg/lettucedect-large-modernbert-en-v1](https://huggingface.co/KRLabsOrg/lettucedect-large-modernbert-en-v1) **Multilingual Models**: We've trained 210m and 610m variants of EuroBERT, see our HuggingFace collection: [HF models](https://huggingface.co/collections/KRLabsOrg/multilingual-hallucination-detection-682a2549c18ecd32689231ce) @@ -266,7 +270,7 @@ positional arguments: options: -h, --help show this help message and exit --model MODEL Path or huggingface URL to the model. The default value is - "KRLabsOrg/lettucedetect-base-modernbert-en-v1". + "KRLabsOrg/lettucedect-base-modernbert-en-v1". --method {transformer} Hallucination detection method. The default value is "transformer". diff --git a/assets/tinylettuce.jpeg b/assets/tinylettuce.jpeg new file mode 100644 index 0000000..9c04351 Binary files /dev/null and b/assets/tinylettuce.jpeg differ diff --git a/assets/tinytinylettuce.png b/assets/tinytinylettuce.png new file mode 100644 index 0000000..037b21f Binary files /dev/null and b/assets/tinytinylettuce.png differ diff --git a/demo/tinylettuce.ipynb b/demo/tinylettuce.ipynb new file mode 100644 index 0000000..cd95f7d --- /dev/null +++ b/demo/tinylettuce.ipynb @@ -0,0 +1,1020 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b9ea123a", + "metadata": {}, + "source": [ + "## 🥬 TinyLettuce: Efficient Hallucination Detection Small Models (Using Synthetic Data Generation)\n", + "\n", + "

\n", + " \"TinyLettuce\n", + "
\n", + " Small, task‑specialized encoders trained on synthetic data\n", + "

\n", + "\n", + "\n", + "[![LettuceDetect](https://img.shields.io/badge/LettuceDetect-v0.1.8-green)](https://github.com/your-username/LettuceDetect)\n", + "[![Python](https://img.shields.io/badge/Python-3.11+-blue)](https://python.org)\n", + "[![License](https://img.shields.io/badge/License-MIT-yellow)](https://opensource.org/licenses/MIT)\n", + "\n", + "## 🎯 Overview\n", + "\n", + "**The Problem**: Training robust hallucination detection models requires large datasets of both correct and hallucinated responses. Manually creating such datasets is expensive and time-consuming.\n", + "\n", + "**Our Solution**: LettuceDetect's synthetic data generation pipeline can generate realistic hallucinations from factual content.\n", + "\n", + "### What This Notebook Demonstrates\n", + "\n", + "1. **Answer-based Generation**: Inject specific error types into correct answers\n", + "2. **Batch Processing**: Efficient async generation for large datasets\n", + "3. **Training Integration**: Convert to formats ready for model training\n", + "\n", + "### Key Benefits\n", + "\n", + "- **Cost-effective**: Generate thousands of training samples at a fraction of manual annotation cost\n", + "- **Controllable**: Specify exact error types and intensity levels\n", + "- **Scalable**: Async batch processing for large scale datasets" + ] + }, + { + "cell_type": "markdown", + "id": "0086e655", + "metadata": {}, + "source": [ + "### Setup\n", + "\n", + "Install LettuceDetect:\n", + "```bash\n", + "pip install lettucedetect\n", + "```\n", + "\n", + "Then, install datasets and rich:\n", + "```bash\n", + "pip install datasets\n", + "pip install rich\n", + "```\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "95eac334", + "metadata": {}, + "outputs": [], + "source": [ + "# We recommend setting your OpenAI API key as an environment variable\n", + "# os.environ['OPENAI_API_KEY'] = 'your-api-key-here'" + ] + }, + { + "cell_type": "markdown", + "id": "eedf0e53", + "metadata": {}, + "source": [ + "### Generate Synthetic Data" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e5fb60b3", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the generator\n", + "from lettucedetect import HallucinationGenerator\n", + "\n", + "# The heart of the synthetic data generation pipeline is the HallucinationGenerator class\n", + "# GPT 5 requires temperature=1.0\n", + "generator = HallucinationGenerator(model=\"gpt-5\", temperature=1.0)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a7f90660", + "metadata": {}, + "outputs": [], + "source": [ + "# The generator can be used with any context-question-answer format\n", + "result = generator.generate(\n", + " context=[\n", + " \"Ibuprofen is an NSAID that reduces inflammation and pain. The typical adult dose is 400-600mg every 6-8 hours, not exceeding 2400mg daily.\"\n", + " ],\n", + " question=\"What is the maximum daily dose of ibuprofen?\",\n", + " answer=\"The maximum daily dose of ibuprofen for adults is 2400mg.\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8760884c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
{\n",
+       "    'original_answer': 'The maximum daily dose of ibuprofen for adults is 2400mg.',\n",
+       "    'hallucinated_answer': 'The maximum daily dose of ibuprofen for adults is 3200mg, per a 2016 FDA guideline.',\n",
+       "    'hallucinated_parts': ['3200mg', 'per a 2016 FDA guideline', '2016']\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'original_answer'\u001b[0m: \u001b[32m'The maximum daily dose of ibuprofen for adults is 2400mg.'\u001b[0m,\n", + " \u001b[32m'hallucinated_answer'\u001b[0m: \u001b[32m'The maximum daily dose of ibuprofen for adults is 3200mg, per a 2016 FDA guideline.'\u001b[0m,\n", + " \u001b[32m'hallucinated_parts'\u001b[0m: \u001b[1m[\u001b[0m\u001b[32m'3200mg'\u001b[0m, \u001b[32m'per a 2016 FDA guideline'\u001b[0m, \u001b[32m'2016'\u001b[0m\u001b[1m]\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from rich import console\n", + "\n", + "console = console.Console()\n", + "\n", + "console.print(result)" + ] + }, + { + "cell_type": "markdown", + "id": "0b049c56", + "metadata": {}, + "source": [ + "You can easily tune the error types and intensity to your needs.\n", + "\n", + "Currently, the generator supports the following error types:\n", + "- factual = Change facts/entities\n", + "- temporal = Change dates, time periods\n", + "- numerical = Change numbers, quantities\n", + "- relational = Change relationships between entities\n", + "- contextual = Add unrelated context\n", + "- omission = Remove important details\n", + "\n", + "And intensity is a float between 0 and 1, where 0 is hardly noticable and 1 is very obvious" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f09d1a5a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
{\n",
+       "    'original_answer': 'The maximum daily dose of ibuprofen for adults is 2400mg.',\n",
+       "    'hallucinated_answer': 'The maximum daily dose of ibuprofen for adults is 3200mg.',\n",
+       "    'hallucinated_parts': ['3200mg']\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'original_answer'\u001b[0m: \u001b[32m'The maximum daily dose of ibuprofen for adults is 2400mg.'\u001b[0m,\n", + " \u001b[32m'hallucinated_answer'\u001b[0m: \u001b[32m'The maximum daily dose of ibuprofen for adults is 3200mg.'\u001b[0m,\n", + " \u001b[32m'hallucinated_parts'\u001b[0m: \u001b[1m[\u001b[0m\u001b[32m'3200mg'\u001b[0m\u001b[1m]\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Lets try to generate numerical errors\n", + "result = generator.generate(\n", + " context=[\n", + " \"Ibuprofen is an NSAID that reduces inflammation and pain. The typical adult dose is 400-600mg every 6-8 hours, not exceeding 2400mg daily.\"\n", + " ],\n", + " question=\"What is the maximum daily dose of ibuprofen?\",\n", + " answer=\"The maximum daily dose of ibuprofen for adults is 2400mg.\",\n", + " error_types=[\"numerical\"],\n", + ")\n", + "\n", + "console.print(result)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1c4add75", + "metadata": {}, + "outputs": [], + "source": [ + "# Lets try with low intensity\n", + "result = generator.generate(\n", + " context=[\n", + " \"Ibuprofen is an NSAID that reduces inflammation and pain. The typical adult dose is 400-600mg every 6-8 hours, not exceeding 2400mg daily.\"\n", + " ],\n", + " question=\"What is the maximum daily dose of ibuprofen?\",\n", + " answer=\"The maximum daily dose of ibuprofen for adults is 2400mg.\",\n", + " error_types=[\"numerical\"],\n", + " intensity=0.1,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "4f612a47", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
{\n",
+       "    'original_answer': 'The maximum daily dose of ibuprofen for adults is 2400mg.',\n",
+       "    'hallucinated_answer': 'The maximum daily dose of ibuprofen for adults is 2500mg.',\n",
+       "    'hallucinated_parts': ['2500mg']\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'original_answer'\u001b[0m: \u001b[32m'The maximum daily dose of ibuprofen for adults is 2400mg.'\u001b[0m,\n", + " \u001b[32m'hallucinated_answer'\u001b[0m: \u001b[32m'The maximum daily dose of ibuprofen for adults is 2500mg.'\u001b[0m,\n", + " \u001b[32m'hallucinated_parts'\u001b[0m: \u001b[1m[\u001b[0m\u001b[32m'2500mg'\u001b[0m\u001b[1m]\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "console.print(result)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "cd939938", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
{\n",
+       "    'original_answer': 'The maximum daily dose of ibuprofen for adults is 2400mg.',\n",
+       "    'hallucinated_answer': 'The maximum daily dose of ibuprofen for adults is 3200mg.',\n",
+       "    'hallucinated_parts': ['3200mg']\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'original_answer'\u001b[0m: \u001b[32m'The maximum daily dose of ibuprofen for adults is 2400mg.'\u001b[0m,\n", + " \u001b[32m'hallucinated_answer'\u001b[0m: \u001b[32m'The maximum daily dose of ibuprofen for adults is 3200mg.'\u001b[0m,\n", + " \u001b[32m'hallucinated_parts'\u001b[0m: \u001b[1m[\u001b[0m\u001b[32m'3200mg'\u001b[0m\u001b[1m]\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Now lets try to generate factual errors\n", + "result = generator.generate(\n", + " context=[\n", + " \"Ibuprofen is an NSAID that reduces inflammation and pain. The typical adult dose is 400-600mg every 6-8 hours, not exceeding 2400mg daily.\"\n", + " ],\n", + " question=\"What is the maximum daily dose of ibuprofen?\",\n", + " answer=\"The maximum daily dose of ibuprofen for adults is 2400mg.\",\n", + " error_types=[\"factual\"],\n", + " intensity=0.4,\n", + ")\n", + "\n", + "console.print(result)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "31e3a3e6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
{\n",
+       "    'original_answer': 'Apollo 11 landed on the Moon on July 20, 1969.',\n",
+       "    'hallucinated_answer': 'Apollo 11 landed on the Moon on July 21, 1969.',\n",
+       "    'hallucinated_parts': ['July 21, 1969']\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'original_answer'\u001b[0m: \u001b[32m'Apollo 11 landed on the Moon on July 20, 1969.'\u001b[0m,\n", + " \u001b[32m'hallucinated_answer'\u001b[0m: \u001b[32m'Apollo 11 landed on the Moon on July 21, 1969.'\u001b[0m,\n", + " \u001b[32m'hallucinated_parts'\u001b[0m: \u001b[1m[\u001b[0m\u001b[32m'July 21, 1969'\u001b[0m\u001b[1m]\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Another example: temporal errors\n", + "result = generator.generate(\n", + " context=[\n", + " \"Apollo 11 was the first crewed mission to land on the Moon, touching down on July 20, 1969. Neil Armstrong and Buzz Aldrin spent about 21 hours on the lunar surface.\"\n", + " ],\n", + " question=\"On what date did Apollo 11 land on the Moon?\",\n", + " answer=\"Apollo 11 landed on the Moon on July 20, 1969.\",\n", + " error_types=[\"temporal\"],\n", + " intensity=0.5,\n", + ")\n", + "\n", + "console.print(result)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "b5f07a91", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
SimpleBatchResult(\n",
+       "    results=[\n",
+       "        HallucinationDataGeneratorOutput(\n",
+       "            generated_hlcntn_answer='The maximum daily dose of ibuprofen for adults is 2800 mg as recommended since\n",
+       "2019.',\n",
+       "            generated_non_hlcntn_answer='The maximum daily dose of ibuprofen for adults is 2400mg.',\n",
+       "            hlcntn_part=['2800 mg', 'as recommended since 2019']\n",
+       "        ),\n",
+       "        HallucinationDataGeneratorOutput(\n",
+       "            generated_hlcntn_answer='Apollo 11 landed on the Moon on July 21, 1969.',\n",
+       "            generated_non_hlcntn_answer='Apollo 11 landed on the Moon on July 20, 1969.',\n",
+       "            hlcntn_part=['July 21, 1969']\n",
+       "        )\n",
+       "    ],\n",
+       "    failed_indices=[],\n",
+       "    errors=[],\n",
+       "    total_time=16.986872911453247,\n",
+       "    successful_count=2,\n",
+       "    failed_count=0\n",
+       ")\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1;35mSimpleBatchResult\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mresults\u001b[0m=\u001b[1m[\u001b[0m\n", + " \u001b[1;35mHallucinationDataGeneratorOutput\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mgenerated_hlcntn_answer\u001b[0m=\u001b[32m'The maximum daily dose of ibuprofen for adults is 2800 mg as recommended since\u001b[0m\n", + "\u001b[32m2019.'\u001b[0m,\n", + " \u001b[33mgenerated_non_hlcntn_answer\u001b[0m=\u001b[32m'The maximum daily dose of ibuprofen for adults is 2400mg.'\u001b[0m,\n", + " \u001b[33mhlcntn_part\u001b[0m=\u001b[1m[\u001b[0m\u001b[32m'2800 mg'\u001b[0m, \u001b[32m'as recommended since 2019'\u001b[0m\u001b[1m]\u001b[0m\n", + " \u001b[1m)\u001b[0m,\n", + " \u001b[1;35mHallucinationDataGeneratorOutput\u001b[0m\u001b[1m(\u001b[0m\n", + " \u001b[33mgenerated_hlcntn_answer\u001b[0m=\u001b[32m'Apollo 11 landed on the Moon on July 21, 1969.'\u001b[0m,\n", + " \u001b[33mgenerated_non_hlcntn_answer\u001b[0m=\u001b[32m'Apollo 11 landed on the Moon on July 20, 1969.'\u001b[0m,\n", + " \u001b[33mhlcntn_part\u001b[0m=\u001b[1m[\u001b[0m\u001b[32m'July 21, 1969'\u001b[0m\u001b[1m]\u001b[0m\n", + " \u001b[1m)\u001b[0m\n", + " \u001b[1m]\u001b[0m,\n", + " \u001b[33mfailed_indices\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33merrors\u001b[0m=\u001b[1m[\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[33mtotal_time\u001b[0m=\u001b[1;36m16\u001b[0m\u001b[1;36m.986872911453247\u001b[0m,\n", + " \u001b[33msuccessful_count\u001b[0m=\u001b[1;36m2\u001b[0m,\n", + " \u001b[33mfailed_count\u001b[0m=\u001b[1;36m0\u001b[0m\n", + "\u001b[1m)\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Hallucinations can be generated in batch as well\n", + "\n", + "\n", + "async def generate_batch(contexts, questions, answers, error_types, intensity):\n", + " generator = HallucinationGenerator(model=\"gpt-5-mini\", temperature=1.0)\n", + " results = await generator.generate_batch_async(\n", + " contexts, questions, answers, error_types, intensity\n", + " )\n", + " return results\n", + "\n", + "\n", + "# Lets try to generate a batch of hallucinations\n", + "contexts = [\n", + " \"Ibuprofen is an NSAID that reduces inflammation and pain. The typical adult dose is 400-600mg every 6-8 hours, not exceeding 2400mg daily.\",\n", + " \"Apollo 11 was the first crewed mission to land on the Moon, touching down on July 20, 1969. 
Neil Armstrong and Buzz Aldrin spent about 21 hours on the lunar surface.\",\n", + "]\n", + "questions = [\n", + " \"What is the maximum daily dose of ibuprofen?\",\n", + " \"On what date did Apollo 11 land on the Moon?\",\n", + "]\n", + "answers = [\n", + " \"The maximum daily dose of ibuprofen for adults is 2400mg.\",\n", + " \"Apollo 11 landed on the Moon on July 20, 1969.\",\n", + "]\n", + "error_types = [\"numerical\", \"temporal\"]\n", + "intensity = 0.5\n", + "\n", + "results = await generate_batch(contexts, questions, answers, error_types, intensity)\n", + "console.print(results)" + ] + }, + { + "cell_type": "markdown", + "id": "e4c97223", + "metadata": {}, + "source": [ + "## The rag-mini-BioASQ dataset\n", + "\n", + "The rag-mini-BioASQ dataset is a rag dataset of biomedical questions and answers together with their corresponding context.\n", + "\n", + "We can use the HuggingFace `datasets` library to load the dataset.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "44185421", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
{\n",
+       "    'question': 'What is the applicability of the No Promoter Left Behind method?',\n",
+       "    'answer': 'No Promoter Left Behind (NPLB) is an efficient, organism-independent method for characterizing \n",
+       "promoter architectures directly from experimentally identified genome-wide TSSs, without relying on known promoter \n",
+       "elements.',\n",
+       "    'context': [\n",
+       "        'Promoters have diverse regulatory architectures and thus activate genes \\ndifferently. For example, some \n",
+       "have a TATA-box, many others do not. Even the \\nones with it can differ in its position relative to the \n",
+       "transcription start site \\n(TSS). No Promoter Left Behind (NPLB) is an efficient, organism-independent \\nmethod for\n",
+       "characterizing such diverse architectures directly from \\nexperimentally identified genome-wide TSSs, without \n",
+       "relying on known promoter \\nelements. As a test case, we show its application in identifying novel \\narchitectures \n",
+       "in the fly genome.\\nAVAILABILITY AND IMPLEMENTATION: Web-server at http://nplb.ncl.res.in Standalone \\nalso at \n",
+       "https://github.com/computationalBiology/NPLB/ (Mac OSX/Linux).\\nCONTACT: l.narlikar@ncl.res.in\\nSUPPLEMENTARY \n",
+       "INFORMATION: Supplementary data are available at Bioinformatics \\nonline.'\n",
+       "    ]\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'question'\u001b[0m: \u001b[32m'What is the applicability of the No Promoter Left Behind method?'\u001b[0m,\n", + " \u001b[32m'answer'\u001b[0m: \u001b[32m'No Promoter Left Behind \u001b[0m\u001b[32m(\u001b[0m\u001b[32mNPLB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m is an efficient, organism-independent method for characterizing \u001b[0m\n", + "\u001b[32mpromoter architectures directly from experimentally identified genome-wide TSSs, without relying on known promoter \u001b[0m\n", + "\u001b[32melements.'\u001b[0m,\n", + " \u001b[32m'context'\u001b[0m: \u001b[1m[\u001b[0m\n", + " \u001b[32m'Promoters have diverse regulatory architectures and thus activate genes \\ndifferently. For example, some \u001b[0m\n", + "\u001b[32mhave a TATA-box, many others do not. Even the \\nones with it can differ in its position relative to the \u001b[0m\n", + "\u001b[32mtranscription start site \\n\u001b[0m\u001b[32m(\u001b[0m\u001b[32mTSS\u001b[0m\u001b[32m)\u001b[0m\u001b[32m. No Promoter Left Behind \u001b[0m\u001b[32m(\u001b[0m\u001b[32mNPLB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m is an efficient, organism-independent \\nmethod for\u001b[0m\n", + "\u001b[32mcharacterizing such diverse architectures directly from \\nexperimentally identified genome-wide TSSs, without \u001b[0m\n", + "\u001b[32mrelying on known promoter \\nelements. As a test case, we show its application in identifying novel \\narchitectures \u001b[0m\n", + "\u001b[32min the fly genome.\\nAVAILABILITY AND IMPLEMENTATION: Web-server at http://nplb.ncl.res.in Standalone \\nalso at \u001b[0m\n", + "\u001b[32mhttps://github.com/computationalBiology/NPLB/ \u001b[0m\u001b[32m(\u001b[0m\u001b[32mMac OSX/Linux\u001b[0m\u001b[32m)\u001b[0m\u001b[32m.\\nCONTACT: l.narlikar@ncl.res.in\\nSUPPLEMENTARY \u001b[0m\n", + "\u001b[32mINFORMATION: Supplementary data are available at Bioinformatics \\nonline.'\u001b[0m\n", + " \u001b[1m]\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def load_rag_mini_bioasq(split: str = \"train\", filter_min_words: int = 10):\n", + " \"\"\"Load rag-mini-bioasq dataset and prepare for generation.\"\"\"\n", + " try:\n", + " from datasets import load_dataset\n", + " except ImportError:\n", + " raise ImportError(\"datasets package required. 
Install with: pip install datasets\")\n", + "\n", + " # Load dataset\n", + " qa_dataset = load_dataset(\"enelpol/rag-mini-bioasq\", \"question-answer-passages\")\n", + " corpus_dataset = load_dataset(\"enelpol/rag-mini-bioasq\", \"text-corpus\")\n", + "\n", + " # Create corpus lookup\n", + " corpus_lookup = {item[\"id\"]: item[\"passage\"] for item in corpus_dataset[\"test\"]}\n", + "\n", + " # Process data\n", + " processed_data = []\n", + " for item in qa_dataset[split]:\n", + " passage_ids = item[\"relevant_passage_ids\"]\n", + " context_passages = [corpus_lookup.get(pid, None) for pid in passage_ids]\n", + " context_passages = [p for p in context_passages if p is not None]\n", + "\n", + " # Filter by answer length\n", + " if len(item[\"answer\"].split()) >= filter_min_words:\n", + " processed_data.append(\n", + " {\n", + " \"question\": item[\"question\"],\n", + " \"answer\": item[\"answer\"],\n", + " \"context\": context_passages,\n", + " }\n", + " )\n", + "\n", + " return processed_data\n", + "\n", + "\n", + "# Lets load the dataset\n", + "data = load_rag_mini_bioasq()\n", + "\n", + "# Lets take a look at an example sample\n", + "console.print(data[3])" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "db23fb5e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
{\n",
+       "    'original_answer': 'No Promoter Left Behind (NPLB) is an efficient, organism-independent method for \n",
+       "characterizing promoter architectures directly from experimentally identified genome-wide TSSs, without relying on \n",
+       "known promoter elements.',\n",
+       "    'hallucinated_answer': 'No Promoter Left Behind (NPLB) is an efficient, organism-specific method for \n",
+       "characterizing promoter architectures from computationally inferred genome-wide TSSs, often leveraging known \n",
+       "promoter elements; it was primarily applied before 2010 and typically analyzes about 8,000 TSSs per dataset.',\n",
+       "    'hallucinated_parts': [\n",
+       "        'organism-specific',\n",
+       "        'from computationally inferred genome-wide TSSs',\n",
+       "        'often leveraging known promoter elements',\n",
+       "        'it was primarily applied before 2010',\n",
+       "        'typically analyzes about 8,000 TSSs per dataset'\n",
+       "    ]\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'original_answer'\u001b[0m: \u001b[32m'No Promoter Left Behind \u001b[0m\u001b[32m(\u001b[0m\u001b[32mNPLB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m is an efficient, organism-independent method for \u001b[0m\n", + "\u001b[32mcharacterizing promoter architectures directly from experimentally identified genome-wide TSSs, without relying on \u001b[0m\n", + "\u001b[32mknown promoter elements.'\u001b[0m,\n", + " \u001b[32m'hallucinated_answer'\u001b[0m: \u001b[32m'No Promoter Left Behind \u001b[0m\u001b[32m(\u001b[0m\u001b[32mNPLB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m is an efficient, organism-specific method for \u001b[0m\n", + "\u001b[32mcharacterizing promoter architectures from computationally inferred genome-wide TSSs, often leveraging known \u001b[0m\n", + "\u001b[32mpromoter elements; it was primarily applied before 2010 and typically analyzes about 8,000 TSSs per dataset.'\u001b[0m,\n", + " \u001b[32m'hallucinated_parts'\u001b[0m: \u001b[1m[\u001b[0m\n", + " \u001b[32m'organism-specific'\u001b[0m,\n", + " \u001b[32m'from computationally inferred genome-wide TSSs'\u001b[0m,\n", + " \u001b[32m'often leveraging known promoter elements'\u001b[0m,\n", + " \u001b[32m'it was primarily applied before 2010'\u001b[0m,\n", + " \u001b[32m'typically analyzes about 8,000 TSSs per dataset'\u001b[0m\n", + " \u001b[1m]\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# You can easily use the generator to generate hallucinations for the dataset\n", + "result = generator.generate(\n", + " context=data[3][\"context\"],\n", + " question=data[3][\"question\"],\n", + " answer=data[3][\"answer\"],\n", + ")\n", + "\n", + "console.print(result)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "b3972a0c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[\n",
+       "    {\n",
+       "        'prompt': 'Briefly answer the following question:\\nWhat is the applicability of the No Promoter Left Behind\n",
+       "method?\\nBear in mind that your response should be strictly based on the following 1 passages:\\npassage 1: \n",
+       "Promoters have diverse regulatory architectures and thus activate genes \\ndifferently. For example, some have a \n",
+       "TATA-box, many others do not. Even the \\nones with it can differ in its position relative to the transcription \n",
+       "start site \\n(TSS). No Promoter Left Behind (NPLB) is an efficient, organism-independent \\nmethod for \n",
+       "characterizing such diverse architectures directly from \\nexperimentally identified genome-wide TSSs, without \n",
+       "relying on known promoter \\nelements. As a test case, we show its application in identifying novel \\narchitectures \n",
+       "in the fly genome.\\nAVAILABILITY AND IMPLEMENTATION: Web-server at http://nplb.ncl.res.in Standalone \\nalso at \n",
+       "https://github.com/computationalBiology/NPLB/ (Mac OSX/Linux).\\nCONTACT: l.narlikar@ncl.res.in\\nSUPPLEMENTARY \n",
+       "INFORMATION: Supplementary data are available at Bioinformatics \\nonline.\\nIn case the passages do not contain the \n",
+       "necessary information to answer the question, please reply with: \"Unable to answer based on given \n",
+       "passages.\"\\noutput:',\n",
+       "        'answer': 'No Promoter Left Behind (NPLB) is an efficient, organism-independent method for characterizing \n",
+       "promoter architectures directly from experimentally identified genome-wide TSSs, without relying on known promoter \n",
+       "elements.',\n",
+       "        'labels': [],\n",
+       "        'split': 'train',\n",
+       "        'task_type': 'qa'\n",
+       "    },\n",
+       "    {\n",
+       "        'prompt': 'Briefly answer the following question:\\nWhat is the applicability of the No Promoter Left Behind\n",
+       "method?\\nBear in mind that your response should be strictly based on the following 1 passages:\\npassage 1: \n",
+       "Promoters have diverse regulatory architectures and thus activate genes \\ndifferently. For example, some have a \n",
+       "TATA-box, many others do not. Even the \\nones with it can differ in its position relative to the transcription \n",
+       "start site \\n(TSS). No Promoter Left Behind (NPLB) is an efficient, organism-independent \\nmethod for \n",
+       "characterizing such diverse architectures directly from \\nexperimentally identified genome-wide TSSs, without \n",
+       "relying on known promoter \\nelements. As a test case, we show its application in identifying novel \\narchitectures \n",
+       "in the fly genome.\\nAVAILABILITY AND IMPLEMENTATION: Web-server at http://nplb.ncl.res.in Standalone \\nalso at \n",
+       "https://github.com/computationalBiology/NPLB/ (Mac OSX/Linux).\\nCONTACT: l.narlikar@ncl.res.in\\nSUPPLEMENTARY \n",
+       "INFORMATION: Supplementary data are available at Bioinformatics \\nonline.\\nIn case the passages do not contain the \n",
+       "necessary information to answer the question, please reply with: \"Unable to answer based on given \n",
+       "passages.\"\\noutput:',\n",
+       "        'answer': 'No Promoter Left Behind (NPLB) is an efficient, organism-specific method for characterizing \n",
+       "promoter architectures from computationally inferred genome-wide TSSs, often leveraging known promoter elements; it\n",
+       "was primarily applied before 2010 and typically analyzes about 8,000 TSSs per dataset.',\n",
+       "        'labels': [\n",
+       "            {'start': 48, 'end': 65, 'label': 'hallucinated'},\n",
+       "            {'start': 115, 'end': 161, 'label': 'hallucinated'},\n",
+       "            {'start': 163, 'end': 203, 'label': 'hallucinated'},\n",
+       "            {'start': 205, 'end': 241, 'label': 'hallucinated'},\n",
+       "            {'start': 246, 'end': 293, 'label': 'hallucinated'}\n",
+       "        ],\n",
+       "        'split': 'train',\n",
+       "        'task_type': 'qa'\n",
+       "    }\n",
+       "]\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m[\u001b[0m\n", + " \u001b[1m{\u001b[0m\n", + " \u001b[32m'prompt'\u001b[0m: \u001b[32m'Briefly answer the following question:\\nWhat is the applicability of the No Promoter Left Behind\u001b[0m\n", + "\u001b[32mmethod?\\nBear in mind that your response should be strictly based on the following 1 passages:\\npassage 1: \u001b[0m\n", + "\u001b[32mPromoters have diverse regulatory architectures and thus activate genes \\ndifferently. For example, some have a \u001b[0m\n", + "\u001b[32mTATA-box, many others do not. Even the \\nones with it can differ in its position relative to the transcription \u001b[0m\n", + "\u001b[32mstart site \\n\u001b[0m\u001b[32m(\u001b[0m\u001b[32mTSS\u001b[0m\u001b[32m)\u001b[0m\u001b[32m. No Promoter Left Behind \u001b[0m\u001b[32m(\u001b[0m\u001b[32mNPLB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m is an efficient, organism-independent \\nmethod for \u001b[0m\n", + "\u001b[32mcharacterizing such diverse architectures directly from \\nexperimentally identified genome-wide TSSs, without \u001b[0m\n", + "\u001b[32mrelying on known promoter \\nelements. As a test case, we show its application in identifying novel \\narchitectures \u001b[0m\n", + "\u001b[32min the fly genome.\\nAVAILABILITY AND IMPLEMENTATION: Web-server at http://nplb.ncl.res.in Standalone \\nalso at \u001b[0m\n", + "\u001b[32mhttps://github.com/computationalBiology/NPLB/ \u001b[0m\u001b[32m(\u001b[0m\u001b[32mMac OSX/Linux\u001b[0m\u001b[32m)\u001b[0m\u001b[32m.\\nCONTACT: l.narlikar@ncl.res.in\\nSUPPLEMENTARY \u001b[0m\n", + "\u001b[32mINFORMATION: Supplementary data are available at Bioinformatics \\nonline.\\nIn case the passages do not contain the \u001b[0m\n", + "\u001b[32mnecessary information to answer the question, please reply with: \"Unable to answer based on given \u001b[0m\n", + "\u001b[32mpassages.\"\\noutput:'\u001b[0m,\n", + " \u001b[32m'answer'\u001b[0m: \u001b[32m'No Promoter Left Behind \u001b[0m\u001b[32m(\u001b[0m\u001b[32mNPLB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m is an efficient, organism-independent method for characterizing \u001b[0m\n", + "\u001b[32mpromoter architectures directly from experimentally identified genome-wide TSSs, without relying on known promoter \u001b[0m\n", + "\u001b[32melements.'\u001b[0m,\n", + " \u001b[32m'labels'\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[32m'split'\u001b[0m: \u001b[32m'train'\u001b[0m,\n", + " \u001b[32m'task_type'\u001b[0m: \u001b[32m'qa'\u001b[0m\n", + " \u001b[1m}\u001b[0m,\n", + " \u001b[1m{\u001b[0m\n", + " \u001b[32m'prompt'\u001b[0m: \u001b[32m'Briefly answer the following question:\\nWhat is the applicability of the No Promoter Left Behind\u001b[0m\n", + "\u001b[32mmethod?\\nBear in mind that your response should be strictly based on the following 1 passages:\\npassage 1: \u001b[0m\n", + "\u001b[32mPromoters have diverse regulatory architectures and thus activate genes \\ndifferently. For example, some have a \u001b[0m\n", + "\u001b[32mTATA-box, many others do not. Even the \\nones with it can differ in its position relative to the transcription \u001b[0m\n", + "\u001b[32mstart site \\n\u001b[0m\u001b[32m(\u001b[0m\u001b[32mTSS\u001b[0m\u001b[32m)\u001b[0m\u001b[32m. 
No Promoter Left Behind \u001b[0m\u001b[32m(\u001b[0m\u001b[32mNPLB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m is an efficient, organism-independent \\nmethod for \u001b[0m\n", + "\u001b[32mcharacterizing such diverse architectures directly from \\nexperimentally identified genome-wide TSSs, without \u001b[0m\n", + "\u001b[32mrelying on known promoter \\nelements. As a test case, we show its application in identifying novel \\narchitectures \u001b[0m\n", + "\u001b[32min the fly genome.\\nAVAILABILITY AND IMPLEMENTATION: Web-server at http://nplb.ncl.res.in Standalone \\nalso at \u001b[0m\n", + "\u001b[32mhttps://github.com/computationalBiology/NPLB/ \u001b[0m\u001b[32m(\u001b[0m\u001b[32mMac OSX/Linux\u001b[0m\u001b[32m)\u001b[0m\u001b[32m.\\nCONTACT: l.narlikar@ncl.res.in\\nSUPPLEMENTARY \u001b[0m\n", + "\u001b[32mINFORMATION: Supplementary data are available at Bioinformatics \\nonline.\\nIn case the passages do not contain the \u001b[0m\n", + "\u001b[32mnecessary information to answer the question, please reply with: \"Unable to answer based on given \u001b[0m\n", + "\u001b[32mpassages.\"\\noutput:'\u001b[0m,\n", + " \u001b[32m'answer'\u001b[0m: \u001b[32m'No Promoter Left Behind \u001b[0m\u001b[32m(\u001b[0m\u001b[32mNPLB\u001b[0m\u001b[32m)\u001b[0m\u001b[32m is an efficient, organism-specific method for characterizing \u001b[0m\n", + "\u001b[32mpromoter architectures from computationally inferred genome-wide TSSs, often leveraging known promoter elements; it\u001b[0m\n", + "\u001b[32mwas primarily applied before 2010 and typically analyzes about 8,000 TSSs per dataset.'\u001b[0m,\n", + " \u001b[32m'labels'\u001b[0m: \u001b[1m[\u001b[0m\n", + " \u001b[1m{\u001b[0m\u001b[32m'start'\u001b[0m: \u001b[1;36m48\u001b[0m, \u001b[32m'end'\u001b[0m: \u001b[1;36m65\u001b[0m, \u001b[32m'label'\u001b[0m: \u001b[32m'hallucinated'\u001b[0m\u001b[1m}\u001b[0m,\n", + " \u001b[1m{\u001b[0m\u001b[32m'start'\u001b[0m: \u001b[1;36m115\u001b[0m, \u001b[32m'end'\u001b[0m: \u001b[1;36m161\u001b[0m, \u001b[32m'label'\u001b[0m: \u001b[32m'hallucinated'\u001b[0m\u001b[1m}\u001b[0m,\n", + " \u001b[1m{\u001b[0m\u001b[32m'start'\u001b[0m: \u001b[1;36m163\u001b[0m, \u001b[32m'end'\u001b[0m: \u001b[1;36m203\u001b[0m, \u001b[32m'label'\u001b[0m: \u001b[32m'hallucinated'\u001b[0m\u001b[1m}\u001b[0m,\n", + " \u001b[1m{\u001b[0m\u001b[32m'start'\u001b[0m: \u001b[1;36m205\u001b[0m, \u001b[32m'end'\u001b[0m: \u001b[1;36m241\u001b[0m, \u001b[32m'label'\u001b[0m: \u001b[32m'hallucinated'\u001b[0m\u001b[1m}\u001b[0m,\n", + " \u001b[1m{\u001b[0m\u001b[32m'start'\u001b[0m: \u001b[1;36m246\u001b[0m, \u001b[32m'end'\u001b[0m: \u001b[1;36m293\u001b[0m, \u001b[32m'label'\u001b[0m: \u001b[32m'hallucinated'\u001b[0m\u001b[1m}\u001b[0m\n", + " \u001b[1m]\u001b[0m,\n", + " \u001b[32m'split'\u001b[0m: \u001b[32m'train'\u001b[0m,\n", + " \u001b[32m'task_type'\u001b[0m: \u001b[32m'qa'\u001b[0m\n", + " \u001b[1m}\u001b[0m\n", + "\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# You can easily convert this to the format LettuceDetect uses for training\n", + "from lettucedetect.detectors.prompt_utils import PromptUtils\n", + "\n", + "train_data = []\n", + "\n", + "# Add the non-hallucinated sample\n", + "train_data.append(\n", + " {\n", + " \"prompt\": PromptUtils.format_context(data[3][\"context\"], data[3][\"question\"], lang=\"en\"),\n", + " \"answer\": result[\"original_answer\"],\n", + " \"labels\": [],\n", + " \"split\": \"train\",\n", + " \"task_type\": \"qa\",\n", + " }\n", + 
")\n", + "\n", + "hallucinated_labels = []\n", + "for part in result[\"hallucinated_parts\"]:\n", + " start = result[\"hallucinated_answer\"].find(part)\n", + " if start != -1:\n", + " hallucinated_labels.append(\n", + " {\"start\": start, \"end\": start + len(part), \"label\": \"hallucinated\"}\n", + " )\n", + "# Add the hallucinated sample\n", + "train_data.append(\n", + " {\n", + " \"prompt\": PromptUtils.format_context(data[3][\"context\"], data[3][\"question\"], lang=\"en\"),\n", + " \"answer\": result[\"hallucinated_answer\"],\n", + " \"labels\": hallucinated_labels,\n", + " \"split\": \"train\",\n", + " \"task_type\": \"qa\",\n", + " }\n", + ")\n", + "\n", + "console.print(train_data)" + ] + }, + { + "cell_type": "markdown", + "id": "3cdc1dbd", + "metadata": {}, + "source": [ + "## Save and train\n", + "\n", + "Now you can save the data and train a model. First lets save the data.\n", + "\n", + "```python\n", + "import json\n", + "\n", + "with open('train_data.json', 'w') as f:\n", + " json.dump(train_data, f)\n", + "```\n", + "\n", + "Now you can train a model.\n", + "\n", + "```bash\n", + "python scripts/train.py \\\n", + " --ragtruth-path train_data.json \\\n", + " --model-name jhu-clsp/ettin-encoder-68m \\\n", + " --output-dir output/hallucination_detector \\\n", + " --batch-size 4 \\\n", + " --epochs 6 \\\n", + " --learning-rate 1e-5 \n", + "```\n", + "\n", + "**And that's it!** You have a hallucination detector that you can use to detect hallucinations in your data.\n" + ] + }, + { + "cell_type": "markdown", + "id": "8f275af5", + "metadata": {}, + "source": [ + "For the published models, we have generated **1500** samples from the rag-mini-bioasq dataset (3000 samples together with the non-hallucinated ones). We've used the `gpt-oss-120b` model for the training data generation. We haven't specified direct error types, and used the default intensity of 0.3.\n", + "\n", + "For the test set, we have generated **300** hallucinated samples (600 samples together with the non-hallucinated ones). 
We've used the `gpt-5` model for the generation to ensure the quality of the hallucinations for the test set.\n", + "\n", + "For large scale generation, use our script:\n", + "\n", + "```bash\n", + "python scripts/generate_synthetic_data.py \\\\\n", + " --dataset rag-mini-bioasq \\\\\n", + " --split train \\\\\n", + " --num-samples 100 \\\\\n", + " --model gpt-4o-mini \\\\\n", + " --output data/synthetic_train.json\n", + "```\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "fbc37ff5", + "metadata": {}, + "source": [ + "## End-to-End Workflow\n", + "\n", + "```bash\n", + "# Step 1: Generate synthetic training data\n", + "python scripts/generate_synthetic_data.py \\\n", + " --dataset rag-mini-bioasq \\\n", + " --num-samples 2000 \\\n", + " --model gpt-4o-mini \\\n", + " --batch-size 10 \\\n", + " --output data/synthetic_large.json\n", + "\n", + "# Step 2: Train TinyLettuce model\n", + "python scripts/train.py \\\n", + " --ragtruth-path data/train_combined_large.json \\\n", + " --model-name jhu-clsp/ettin-encoder-17m \\\n", + " --output-dir output/tinylettuce_17m \\\n", + " --batch-size 8 \\\n", + " --epochs 3\n", + "\n", + "# Step 3: Deploy on CPU for real-time inference\n", + "python scripts/start_api.py prod --model output/tinylettuce_17m\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "2b24eab7", + "metadata": {}, + "source": [ + "## Bonus\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9a75b564", + "metadata": {}, + "outputs": [], + "source": [ + "# We have implemented a triplet-based hallucination detection model that you can use the same way as the standard lettucecedetect models.\n", + "\n", + "from lettucedetect.models.inference import HallucinationDetector\n", + "from lettucedetect.ragfactchecker import RAGFactChecker\n", + "\n", + "detector = HallucinationDetector(\n", + " method=\"rag_fact_checker\",\n", + ")\n", + "\n", + "fact_checker = RAGFactChecker()" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "09d6f585", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[['the capital of France', 'is', 'Paris']]\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'the capital of France'\u001b[0m, \u001b[32m'is'\u001b[0m, \u001b[32m'Paris'\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Get triplets for a sample\n", + "triplets = fact_checker.generate_triplets(\"The capital of France is Paris.\")\n", + "console.print(triplets)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "49953813", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
{\n",
+       "    'answer_triplets': [['France', 'is', 'a country in Europe']],\n",
+       "    'reference_triplets': [['France', 'is', 'a country in Asia']],\n",
+       "    'comparison': {\n",
+       "        'fact_check_results': {0: False},\n",
+       "        'raw_output': FactCheckerOutput(fact_check_prediction_binary={0: False})\n",
+       "    }\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'answer_triplets'\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'France'\u001b[0m, \u001b[32m'is'\u001b[0m, \u001b[32m'a country in Europe'\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[32m'reference_triplets'\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'France'\u001b[0m, \u001b[32m'is'\u001b[0m, \u001b[32m'a country in Asia'\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[32m'comparison'\u001b[0m: \u001b[1m{\u001b[0m\n", + " \u001b[32m'fact_check_results'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1;36m0\u001b[0m: \u001b[3;91mFalse\u001b[0m\u001b[1m}\u001b[0m,\n", + " \u001b[32m'raw_output'\u001b[0m: \u001b[1;35mFactCheckerOutput\u001b[0m\u001b[1m(\u001b[0m\u001b[33mfact_check_prediction_binary\u001b[0m=\u001b[1m{\u001b[0m\u001b[1;36m0\u001b[0m: \u001b[3;91mFalse\u001b[0m\u001b[1m}\u001b[0m\u001b[1m)\u001b[0m\n", + " \u001b[1m}\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "compare = fact_checker.analyze_text_pair(\n", + " \"France is a country in Europe.\", \"France is a country in Asia.\"\n", + ")\n", + "console.print(compare)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "77bfe6e4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
{\n",
+       "    'spans': [\n",
+       "        {\n",
+       "            'start': 0,\n",
+       "            'end': 31,\n",
+       "            'text': 'The capital of France is Berlin',\n",
+       "            'confidence': 0.9,\n",
+       "            'triplet': ['the capital of France', 'is', 'Berlin']\n",
+       "        }\n",
+       "    ],\n",
+       "    'triplets': {\n",
+       "        'answer': [['the capital of France', 'is', 'Berlin']],\n",
+       "        'context': [['The capital of France', 'is', 'Paris']],\n",
+       "        'hallucinated': [['the capital of France', 'is', 'Berlin']]\n",
+       "    },\n",
+       "    'fact_check_results': {0: False}\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'spans'\u001b[0m: \u001b[1m[\u001b[0m\n", + " \u001b[1m{\u001b[0m\n", + " \u001b[32m'start'\u001b[0m: \u001b[1;36m0\u001b[0m,\n", + " \u001b[32m'end'\u001b[0m: \u001b[1;36m31\u001b[0m,\n", + " \u001b[32m'text'\u001b[0m: \u001b[32m'The capital of France is Berlin'\u001b[0m,\n", + " \u001b[32m'confidence'\u001b[0m: \u001b[1;36m0.9\u001b[0m,\n", + " \u001b[32m'triplet'\u001b[0m: \u001b[1m[\u001b[0m\u001b[32m'the capital of France'\u001b[0m, \u001b[32m'is'\u001b[0m, \u001b[32m'Berlin'\u001b[0m\u001b[1m]\u001b[0m\n", + " \u001b[1m}\u001b[0m\n", + " \u001b[1m]\u001b[0m,\n", + " \u001b[32m'triplets'\u001b[0m: \u001b[1m{\u001b[0m\n", + " \u001b[32m'answer'\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'the capital of France'\u001b[0m, \u001b[32m'is'\u001b[0m, \u001b[32m'Berlin'\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[32m'context'\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'The capital of France'\u001b[0m, \u001b[32m'is'\u001b[0m, \u001b[32m'Paris'\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m,\n", + " \u001b[32m'hallucinated'\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m[\u001b[0m\u001b[32m'the capital of France'\u001b[0m, \u001b[32m'is'\u001b[0m, \u001b[32m'Berlin'\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\n", + " \u001b[1m}\u001b[0m,\n", + " \u001b[32m'fact_check_results'\u001b[0m: \u001b[1m{\u001b[0m\u001b[1;36m0\u001b[0m: \u001b[3;91mFalse\u001b[0m\u001b[1m}\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# You can use it for detecting hallucinations in your data\n", + "result = detector.predict(\n", + " context=\"The capital of France is Paris.\",\n", + " question=\"What is the capital of France?\",\n", + " answer=\"The capital of France is Berlin.\",\n", + " output_format=\"detailed\",\n", + ")\n", + "console.print(result)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "143e383a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lettuce", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/TINYLETTUCE.md b/docs/TINYLETTUCE.md new file mode 100644 index 0000000..55fafb7 --- /dev/null +++ b/docs/TINYLETTUCE.md @@ -0,0 +1,406 @@ +# 🥬 TinyLettuce: Efficient Hallucination Detection with 17–68M Encoders + +

+ TinyLettuce Detective +
+ Small, task‑specialized encoders trained on synthetic data +

+ +--- + +We present **TinyLettuce**, our approach to efficient hallucination detection. By training tiny Ettin encoders (17-68M parameters), we achieve better accuracy than billion-parameter LLM judges while running in real-time on CPU. + +## TL;DR + +- We're releasing a pipeline for generating synthetic training data for hallucination detection and training tiny Ettin encoders on it. +- **TinyLettuce‑17M** (17M parameters) reaches **90.87% F1** 🎯 on synthetic test data, outperforming GPT‑5‑mini (83.69%), GPT‑OSS‑120B (83.38%), and Qwen3‑235B (79.84%) +- Runs in **real-time on CPU** with low latency and large throughput +- **Synthetic data generation** creates training data **significantly cheaper** than manual annotation +- Complete **end‑to‑end pipeline** for domain-specific model training - generate data and train in minutes +- All models and code are **MIT licensed** and ready for production deployment + +--- + +## Quick Links + +- **GitHub**: [github.com/KRLabsOrg/LettuceDetect](https://github.com/KRLabsOrg/LettuceDetect) +- **PyPI**: [pypi.org/project/lettucedetect](https://pypi.org/project/lettucedetect/) +- **Hugging Face Models**: + - [TinyLettuce Collection](https://huggingface.co/collections/KRLabsOrg/tinylettuce-68b42a66b8b6aaa4bf287bf4) +- **Notebook/Demo**: [TinyLettuce end‑to‑end](https://github.com/KRLabsOrg/LettuceDetect/blob/main/demo/tinylettuce.ipynb) +- **Ettin Paper (LightOn)**: https://huggingface.co/papers/2507.11412 + +--- + +## Quickstart + +Install: + +```bash +pip install lettucedetect +``` + +### Detect Hallucinations (Real-time CPU) + +Take one of our pre-trained models and use it for detecting hallucinations in your data: + +```python +from lettucedetect.models.inference import HallucinationDetector + +# Load tiny but powerful model +detector = HallucinationDetector( + method="transformer", + model_path="KRLabsOrg/tinylettuce-ettin-17m-en-v1" +) + +# Detect hallucinations in medical context +spans = detector.predict( + context=[ + "Ibuprofen is an NSAID that reduces inflammation and pain. The typical adult dose is 400-600mg every 6-8 hours, not exceeding 2400mg daily." + ], + question="What is the maximum daily dose of ibuprofen?", + answer="The maximum daily dose of ibuprofen for adults is 3200mg.", + output_format="spans", +) +print(spans) +# Output: [{"start": 51, "end": 57, "text": "3200mg"}] +``` + +### Generate Synthetic Training Data + +With **lettucedetect**, you can create training data automatically with controllable error types using the HallucinationGenerator class. Generate domain-specific training data with just a few lines of code while controlling error types and intensity. + +```python +from lettucedetect import HallucinationGenerator + +# Initialize generator (GPT‑5 requires temperature=1.0) +generator = HallucinationGenerator(model="gpt-5-mini", temperature=1.0) + +# Generate numerical error +result_medical = generator.generate( + context=[ + "Ibuprofen is an NSAID that reduces inflammation and pain. The typical adult dose is 400-600mg every 6-8 hours, not exceeding 2400mg daily." 
+ ], + question="What is the maximum daily dose of ibuprofen?", + answer="The maximum daily dose of ibuprofen for adults is 2400mg.", + error_types=["numerical"], + intensity=0.4, +) +print(f"Original: {result_medical['original_answer']}") +print(f"Hallucinated: {result_medical['hallucinated_answer']}") + +# Generate temporal error +result_historical = generator.generate( + context=[ + "Apollo 11 was the first crewed mission to land on the Moon, touching down on July 20, 1969." + ], + question="On what date did Apollo 11 land on the Moon?", + answer="Apollo 11 landed on the Moon on July 20, 1969.", + error_types=["temporal"], + intensity=0.5, +) +print(f"Original: {result_historical['original_answer']}") +print(f"Hallucinated: {result_historical['hallucinated_answer']}") +``` + +**See the notebook for complete end‑to‑end examples**: [TinyLettuce notebook](https://github.com/KRLabsOrg/LettuceDetect/blob/main/demo/tinylettuce.ipynb) + +--- + +## Motivation + +RAG systems require hallucination detection, but current solutions have painful trade-offs between accuracy, cost, and speed. + +**Current hallucination detection approaches:** + +1. **Prompt-based detectors** - Use LLM APIs for zero/few-shot detection + - Can be expensive for large-scale production deployments + - Latency issues (2-10s per request) unsuitable for real-time use + - Multiple API calls per detection increase costs + +2. **Fine-tuned LLM detectors** - Large models (Llama-2-13B, Llama-3-8B) fine-tuned for detection + - High accuracy but resource-intensive to train and deploy + - Need GPU clusters, slow inference, high operational costs + +3. **Encoder-based detectors** - BERT-style models for token classification + - Fast and efficient but historically limited by short context (512 tokens) + - Can't handle typical RAG contexts which often exceed this limit + +**LettuceDetect's novel approach**: We solved the context problem by leveraging ModernBERT's 8K token capacity, achieving better accuracy than fine-tuned LLMs at a fraction of the computational cost. This shows that encoder-based detection could work at scale. + +**Can we go even smaller and faster?** + +**Enter TinyLettuce with Ettin encoders**: Ettin encoders released by LightOn (see the [HF collection](https://huggingface.co/collections/jhu-clsp/encoders-vs-decoders-the-ettin-suite-686303e16142257eed8e6aeb) and [paper](https://huggingface.co/papers/2507.11412)) are small, long‑context encoders with modern architectures. These lightweight transformers (17–68M parameters) support long contexts and are optimized for classification and retrieval, focusing on efficient representation learning for fast, accurate detection. + +**The key insight**: With the right synthetic training data, a 17M parameter Ettin encoder can outperform 235B parameter LLMs at hallucination detection while running real-time on CPU. TinyLettuce makes it easy to use small models for hallucination detection by making it accessible, fast, and cost-effective for any deployment. + +## Approach + +**Specialized training data can matter more than parameter count**: + +1. **Generate synthetic data** using LettuceDetect's HallucinationGenerator class - no manual annotation needed +2. **Train tiny Ettin encoders** (17M-68M parameters) on this specialized data +3. **Deploy on CPU** for real-time inference with low latency and high throughput +4. 
**Scale effortlessly** - no GPU clusters or API limits (it's just a trained model) + +## Synthetic Hallucination Data + +You can use LettuceDetect's HallucinationGenerator class to generate training pairs automatically at scale. + +### Production-Scale Generation + +For large datasets, use our generation script: + +```bash +# Generate 2,000 training samples (1,000 hallucinated + 1,000 non-hallucinated) +python scripts/generate_synthetic_data.py \ + --dataset rag-mini-bioasq \ + --num-samples 2000 \ + --model gpt-oss-120b \ + --output-format ragtruth \ + --output data/synthetic_2k.json +``` + +### Data Schema (RAGTruth format) + +Minimal entry used for training: + +```json +{ + "prompt": "...", + "answer": "...", + "labels": [{"start": 31, "end": 71, "label": "hallucinated"}], + "split": "train", + "task_type": "qa", + "dataset": "synthetic", + "language": "en" +} +``` + +## TinyLettuce Models (Ettin Encoders) + +Built on the **Ettin encoder** (LightOn) — a lightweight, efficient transformer optimized for classification — these models achieve strong accuracy with low latency. + +### Model Family + +| Model | Parameters | Context Length | Key Advantage | +|-------|------------|----------------|---------------| +| **Ettin-17M** | 17 million | 8K tokens | Edge deployment | +| **Ettin-32M** | 32 million | 8K tokens | Very fast, good accuracy | +| **Ettin-68M** | 68 million | 8K tokens | Higher accuracy, still very fast | + +Why Ettin encoders work well: +- 8K token context windows (longer than most inputs) +- Modern transformer design (RoPE, GLU activations) +- Optimized for token classification, not generation +- Efficient CPU inference without GPU overhead (smaller than ModernBERT models) + +### Data & Training Setup (Published Models) + +We show two training approaches for TinyLettuce models: + +**1. General-Purpose Models (RAGTruth + Synthetic):** +- Base: Original RAGTruth dataset for broad hallucination detection capabilities +- Synthetic augmentation: 3,000 total samples (1,500 hallucinated + 1,500 non-hallucinated) from `enelpol/rag-mini-bioasq` generated using GPT-OSS-120b +- Training recipe: Ettin encoders fine-tuned on combined RAGTruth + synthetic data for robust performance across domains + +**2. Domain-Specific Models (Synthetic-Only):** +- Pure synthetic data generation for targeted domain applications +- Controllable error types and intensity for specific use cases +- Faster training and deployment for specialized scenarios +- Trained on 3,000 synthetic samples (1,500 hallucinated + 1,500 non-hallucinated) + +### Training Hyperparameters (Released Models) + +- Optimizer: AdamW; learning rate `1e-5`; weight decay `0.01`. +- Epochs: 5 +- Batch size: 16; max sequence length: 4096 tokens. +- Tokenization: `AutoTokenizer`; label pad `-100`; `DataCollatorForTokenClassification`. + +## Results + +We trained several variants of Ettin encoders on synthetic data and tested them against larger scale LLM judges and fine-tuned encoders. + +### Synthetic Data Evaluation (example-level) + +Metrics are computed at example level (answer contains any hallucination vs none). Precision/recall/F1 reflect this binary decision; thresholds and post‑processing can affect absolute values. + +**Test Set**: 600 synthetic examples (300 hallucinated + 300 non-hallucinated) generated with GPT-5-mini for fair evaluation. 
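+
+As a rough illustration of the example-level protocol described above (the supported path is the `scripts/evaluate.py` command shown under Evaluation Method below), the scores can also be computed directly from span predictions. This is only a sketch: the test file name and its `context`/`question`/`answer`/`is_hallucinated` fields are illustrative, not a fixed schema.
+
+```python
+import json
+
+from lettucedetect.models.inference import HallucinationDetector
+
+detector = HallucinationDetector(
+    method="transformer", model_path="KRLabsOrg/tinylettuce-ettin-17m-en-v1"
+)
+
+# Hypothetical test file: a list of {"context": [...], "question": "...", "answer": "...", "is_hallucinated": true/false}
+with open("data/synthetic_test.json") as f:
+    test_set = json.load(f)
+
+tp = fp = fn = 0
+for sample in test_set:
+    spans = detector.predict(
+        context=sample["context"],
+        question=sample["question"],
+        answer=sample["answer"],
+        output_format="spans",
+    )
+    predicted = len(spans) > 0  # example level: any predicted span marks the answer as hallucinated
+    if predicted and sample["is_hallucinated"]:
+        tp += 1
+    elif predicted and not sample["is_hallucinated"]:
+        fp += 1
+    elif not predicted and sample["is_hallucinated"]:
+        fn += 1
+
+precision = tp / (tp + fp) if tp + fp else 0.0
+recall = tp / (tp + fn) if tp + fn else 0.0
+f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
+print(f"Precision={precision:.2%}  Recall={recall:.2%}  F1={f1:.2%}")
+```
+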
+ +*When trained and evaluated on domain-specific synthetic data, tiny models dominate (LettuceDetect-base shown without synthetic training):* + +| Model | Parameters | Precision (%) | Recall (%) | F1 (%) | Hardware | +|-------|------------|---------------|------------|---------|----------| +| **TinyLettuce-17M** | **17M** | 84.56 | 98.21 | **90.87** | **CPU** | +| **TinyLettuce-32M** | **32M** | 80.36 | 99.10 | 88.76 | **CPU** | +| **TinyLettuce-68M** | **68M** | **89.54** | 95.96 | **92.64** | **CPU** | +| LettuceDetect-base (ModernBERT) | 150M | 79.06 | 98.21 | 87.60 | GPU | +| GPT-5-mini | ~200B | 71.95 | **100.00** | 83.69 | API/GPU | +| GPT-OSS-120B | 120B | 72.21 | 98.64 | 83.38 | GPU | +| Qwen3-235B | 235B | 66.74 | 99.32 | 79.84 | GPU | + +### RAGTruth Benchmark Evaluation (example-level) + +*Strong performance on standard benchmarks (Ettin models trained on RAGTruth + synthetic data):* + +| Model | Parameters | F1 (%) | +|-------|------------|---------| +| **TinyLettuce-17M** | **17M** | 68.52 | +| **TinyLettuce-32M** | **32M** | 72.15 | +| **TinyLettuce-68M** | **68M** | **74.97** | +| LettuceDetect-base (ModernBERT) | 150M | 76.07 | +| LettuceDetect-large (ModernBERT) | 395M | **79.22** | +| Llama-2-13B (RAGTruth FT) | 13B | 78.70 | + +TinyLettuce Ettin models demonstrate impressive performance given their compact size. These models are trained on both RAGTruth and synthetic data, achieving strong results across both evaluation sets. While ModernBERT models achieve slightly higher accuracy, TinyLettuce offers 6-23x parameter reduction with competitive results, making them ideal for resource-constrained deployments. + +Baselines and judges: we compare against commonly used LLM judges (e.g., GPT‑5‑mini, GPT‑OSS‑120B, Qwen3‑235B) and fine‑tuned encoders/decoders reported in RAGTruth and follow‑up work (e.g., Llama‑2‑13B FT). Beyond benchmarks, deployment characteristics often determine real‑world value. + +### Evaluation Method + +- Span construction from tokens: threshold 0.5 on token hallucination prob; contiguous tokens merged into spans. +- Reported F1 is example‑level unless explicitly noted. +- Example command: + +```bash +python scripts/evaluate.py \ + --model_path output/tinylettuce_68m \ + --data_path data/ragtruth/ragtruth_data.json \ + --evaluation_type example_level +``` + +## Real‑Time CPU Inference + +TinyLettuce's biggest advantage isn't just accuracy — it's accessibility ⚡. These models run in real time on standard CPUs, making hallucination detection practical to deploy widely. + +### End-to-End Workflow + +```bash +# Step 1: Generate synthetic training data +python scripts/generate_synthetic_data.py \ + --dataset rag-mini-bioasq \ + --num-samples 50000 \ + --model gpt-oss-120b \ + --batch-size 50 \ + --output-format ragtruth \ + --output data/synthetic_large.json + +# Step 2: Train TinyLettuce model +python scripts/train.py \ + --ragtruth-path data/train_combined_large.json \ + --model-name jhu-clsp/ettin-encoder-17m \ + --output-dir output/tinylettuce_17m \ + --batch-size 8 \ + --epochs 3 + +# Step 3: Deploy on CPU for real-time inference +python scripts/start_api.py prod --model output/tinylettuce_17m +``` + + +--- + +## Bonus: Triplet‑Based RAGFactChecker + +We have implemented a triplet-based hallucination detection model that you can use the same way as the standard lettucedetect models. 
+
+Generate triplets from any text:
+```python
+from lettucedetect.models.inference import HallucinationDetector
+from lettucedetect.ragfactchecker import RAGFactChecker
+
+detector = HallucinationDetector(
+    method="rag_fact_checker",
+)
+
+rag = RAGFactChecker(model="gpt-5-mini")  # requires OPENAI_API_KEY
+triplets = rag.generate_triplets("Paris is the capital of France.")
+print(triplets)  # e.g., [["Paris", "is_capital_of", "France"]]
+```
+
+Compare triplets against each other:
+```python
+# Reuse the `rag` instance created above
+compare = rag.analyze_text_pair(
+    "France is a country in Europe.", "France is a country in Asia."
+)
+print(compare)
+#{
+#  'answer_triplets': [['France', 'is', 'a country in Europe']],
+#  'reference_triplets': [['France', 'is', 'a country in Asia']],
+#  'comparison': {
+#    'fact_check_results': {0: False},
+#    'raw_output': FactCheckerOutput(fact_check_prediction_binary={0: False})
+#  }
+#}
+```
+
+Use it for detecting hallucinations in your data:
+```python
+# The detector returns spans plus the triplets that triggered them
+result = detector.predict(
+    context=["The capital of France is Paris."],
+    question="What is the capital of France?",
+    answer="The capital of France is Berlin.",
+    output_format="detailed",
+)
+print(result)
+#{
+#  'spans': [
+#    {
+#      'start': 0,
+#      'end': 31,
+#      'text': 'The capital of France is Berlin',
+#      'confidence': 0.9,
+#      'triplet': ['the capital of France', 'is', 'Berlin']
+#    }
+#  ],
+#  'triplets': {
+#    'answer': [['the capital of France', 'is', 'Berlin']],
+#    'context': [['The capital of France', 'is', 'Paris']],
+#    'hallucinated': [['the capital of France', 'is', 'Berlin']]
+#  },
+#  'fact_check_results': {0: False}
+#}
+```
+
+This complements token/span detectors with interpretable, fact-level explanations.
+
+---
+
+## Limitations & Notes
+
+- Results labeled “synthetic” reflect evaluation on generated data; real‑world performance depends on domain match. Consider adding a small, manually curated eval set.
+- Baselines: we report GPT‑5‑mini and open‑source LLM baselines where available; prompt configuration impacts absolute scores.
+- Metrics: synthetic and RAGTruth F1 are example-level unless otherwise noted; thresholds and post‑processing influence outcomes (a minimal sketch of the token-to-span step follows this list).
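+
+As referenced in the last note above, span construction is a thin post-processing step: tokens whose hallucination probability clears the 0.5 threshold are merged into contiguous character spans. The sketch below is illustrative only; the function and variable names are assumptions, not the library API.
+
+```python
+# Illustrative sketch: merge above-threshold tokens into character spans.
+from typing import Dict, List, Tuple
+
+
+def tokens_to_spans(
+    token_probs: List[float],
+    token_offsets: List[Tuple[int, int]],
+    threshold: float = 0.5,
+) -> List[Dict[str, float]]:
+    spans: List[Dict[str, float]] = []
+    current = None
+    for prob, (start, end) in zip(token_probs, token_offsets):
+        if prob >= threshold:
+            if current is None:
+                current = {"start": start, "end": end, "confidence": prob}
+            else:  # extend the open span across contiguous flagged tokens
+                current["end"] = end
+                current["confidence"] = max(current["confidence"], prob)
+        elif current is not None:
+            spans.append(current)
+            current = None
+    if current is not None:
+        spans.append(current)
+    return spans
+```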
+ +--- + +## Citation + +If you find this work useful, please cite it as follows: + +```bibtex +@misc{Kovacs:2025, + title={LettuceDetect: A Hallucination Detection Framework for RAG Applications}, + author={Ádám Kovács and Gábor Recski}, + year={2025}, + eprint={2502.17125}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2502.17125}, +} +``` + +--- + +## References + +[1] [RAGTruth: A Dataset for Hallucination Detection in Retrieval-Augmented Generation](https://aclanthology.org/2024.acl-long.585/) + +[2] [LettuceDetect: A Hallucination Detection Framework for RAG Applications](https://arxiv.org/abs/2502.17125) + +[3] [Ettin: Encoder Models by LightOn (paper)](https://huggingface.co/papers/2507.11412) + +[4] [Ettin Encoder Models (HF models)](https://huggingface.co/jhu-clsp/ettin-encoder-68m) + +[5] [RAGFactChecker](https://github.com/KRLabsOrg/RAGFactChecker) diff --git a/lettucedetect/__init__.py b/lettucedetect/__init__.py index e69de29..928b8f6 100644 --- a/lettucedetect/__init__.py +++ b/lettucedetect/__init__.py @@ -0,0 +1,27 @@ +"""LettuceDetect: Hallucination detection and generation for RAG systems.""" + +# Main detection interface +# Core data structures +from lettucedetect.datasets.hallucination_dataset import ( + HallucinationData, + HallucinationDataset, + HallucinationSample, +) + +# Generation interface +from lettucedetect.models.generation import HallucinationGenerator +from lettucedetect.models.inference import HallucinationDetector + +# Direct RAGFactChecker access for advanced users +from lettucedetect.ragfactchecker import RAGFactChecker + +__version__ = "0.1.7" + +__all__ = [ + "HallucinationData", + "HallucinationDataset", + "HallucinationDetector", + "HallucinationGenerator", + "HallucinationSample", + "RAGFactChecker", # Direct access to triplet functionality +] diff --git a/lettucedetect/detectors/__init__.py b/lettucedetect/detectors/__init__.py index 4b0df18..86e1189 100644 --- a/lettucedetect/detectors/__init__.py +++ b/lettucedetect/detectors/__init__.py @@ -3,11 +3,13 @@ from lettucedetect.detectors.base import BaseDetector from lettucedetect.detectors.factory import make_detector as _make_detector from lettucedetect.detectors.llm import LLMDetector +from lettucedetect.detectors.rag_fact_checker import RAGFactCheckerDetector from lettucedetect.detectors.transformer import TransformerDetector __all__ = [ "BaseDetector", "LLMDetector", + "RAGFactCheckerDetector", "TransformerDetector", "_make_detector", ] diff --git a/lettucedetect/detectors/factory.py b/lettucedetect/detectors/factory.py index 430e616..fc13219 100644 --- a/lettucedetect/detectors/factory.py +++ b/lettucedetect/detectors/factory.py @@ -10,10 +10,10 @@ def make_detector(method: str, **kwargs) -> BaseDetector: """Create a detector of the requested type with the given parameters. - :param method: One of "transformer" or "llm". + :param method: One of "transformer", "llm", or "rag_fact_checker". :param kwargs: Passed to the concrete detector constructor. :return: A concrete detector instance. - :raises ValueError: If method is not one of "transformer" or "llm". + :raises ValueError: If method is not supported. 
""" if method == "transformer": from lettucedetect.detectors.transformer import TransformerDetector @@ -23,5 +23,11 @@ def make_detector(method: str, **kwargs) -> BaseDetector: from lettucedetect.detectors.llm import LLMDetector return LLMDetector(**kwargs) + elif method == "rag_fact_checker": + from lettucedetect.detectors.rag_fact_checker import RAGFactCheckerDetector + + return RAGFactCheckerDetector(**kwargs) else: - raise ValueError(f"Unknown detector method: {method}. Use one of: transformer, llm") + raise ValueError( + f"Unknown detector method: {method}. Use one of: transformer, llm, rag_fact_checker" + ) diff --git a/lettucedetect/detectors/llm.py b/lettucedetect/detectors/llm.py index c24208a..4c4b010 100644 --- a/lettucedetect/detectors/llm.py +++ b/lettucedetect/detectors/llm.py @@ -12,25 +12,26 @@ from lettucedetect.detectors.cache import CacheManager from lettucedetect.detectors.prompt_utils import LANG_TO_PASSAGE, Lang, PromptUtils -ANNOTATE_SCHEMA = [ - { - "type": "function", - "function": { - "name": "annotate", - "description": "Return hallucinated substrings from the answer relative to the source.", - "parameters": { - "type": "object", - "properties": { - "hallucination_list": { - "type": "array", - "items": {"type": "string"}, - } - }, - "required": ["hallucination_list"], +# JSON schema for structured response format +HALLUCINATION_SCHEMA = { + "type": "json_schema", + "json_schema": { + "name": "hallucination_detection", + "schema": { + "type": "object", + "properties": { + "hallucination_list": { + "type": "array", + "items": {"type": "string"}, + "description": "List of exact text spans from the answer that are hallucinated", + } }, + "required": ["hallucination_list"], + "additionalProperties": False, }, - } -] + "strict": True, + }, +} class LLMDetector: @@ -174,11 +175,10 @@ def _predict(self, prompt: str, answer: str) -> list[dict]: # Use the full LLM prompt here, not the raw context {"role": "user", "content": llm_prompt}, ], - tools=ANNOTATE_SCHEMA, - tool_choice={"type": "function", "function": {"name": "annotate"}}, + response_format=HALLUCINATION_SCHEMA, temperature=self.temperature, ) - cached = resp.choices[0].message.tool_calls[0].function.arguments + cached = resp.choices[0].message.content self.cache.set(cache_key, cached) try: @@ -204,8 +204,10 @@ def predict( :param output_format: ``"spans"`` for character spans. :returns: List of spans. """ - if output_format != "spans": - raise ValueError("LLMDetector only supports 'spans' output_format.") + if output_format not in ["tokens", "spans"]: + raise ValueError( + f"LLMDetector doesn't support '{output_format}' format. Use 'tokens' or 'spans'" + ) # Use PromptUtils to format the context and question full_prompt = PromptUtils.format_context(context, question, self.lang) return self._predict(full_prompt, answer) @@ -218,8 +220,10 @@ def predict_prompt(self, prompt: str, answer: str, output_format: str = "spans") :param output_format: ``"spans"`` for character spans. :returns: List of spans. """ - if output_format != "spans": - raise ValueError("LLMDetector only supports 'spans' output_format.") + if output_format not in ["tokens", "spans"]: + raise ValueError( + f"LLMDetector doesn't support '{output_format}' format. Use 'tokens' or 'spans'" + ) return self._predict(prompt, answer) def predict_prompt_batch( @@ -232,8 +236,10 @@ def predict_prompt_batch( :param output_format: ``"spans"`` for character spans. :returns: List of spans. 
""" - if output_format != "spans": - raise ValueError("LLMDetector only supports 'spans' output_format.") + if output_format not in ["tokens", "spans"]: + raise ValueError( + f"LLMDetector doesn't support '{output_format}' format. Use 'tokens' or 'spans'" + ) with ThreadPoolExecutor(max_workers=30) as pool: futs = [pool.submit(self._predict, p, a) for p, a in zip(prompts, answers)] diff --git a/lettucedetect/detectors/rag_fact_checker.py b/lettucedetect/detectors/rag_fact_checker.py new file mode 100644 index 0000000..fb6cb73 --- /dev/null +++ b/lettucedetect/detectors/rag_fact_checker.py @@ -0,0 +1,223 @@ +"""Simple RAGFactChecker detector wrapper for lettuceDetect factory pattern.""" + +from typing import Any, Dict, List + +from lettucedetect.detectors.base import BaseDetector + + +class RAGFactCheckerDetector(BaseDetector): + """Simple wrapper around RAGFactChecker for lettuceDetect's factory pattern. + + This provides a minimal adapter between lettuceDetect's detector interface + and our clean RAGFactChecker wrapper. + """ + + def __init__( + self, + openai_api_key: str = None, + model: str = "gpt-4o", + base_url: str = None, + temperature: float = 0.0, + **kwargs, + ): + """Initialize the RAGFactChecker detector. + + :param openai_api_key: OpenAI API key + :param model: OpenAI model to use (default: "gpt-4o") + :param base_url: Optional base URL for API (e.g., "http://localhost:1234/v1" for local servers) + :param temperature: Temperature for model sampling (default: 0.0 for deterministic outputs) + :param kwargs: Additional arguments (ignored for simplicity) + :return: RAGFactChecker instance + """ + from lettucedetect.ragfactchecker import RAGFactChecker + + # Use our simple, clean wrapper internally + self.rag = RAGFactChecker( + openai_api_key=openai_api_key, model=model, base_url=base_url, temperature=temperature + ) + + def predict( + self, + context: List[str], + answer: str, + question: str = None, + output_format: str = "tokens", + **kwargs, + ) -> List[Dict[str, Any]] | Dict[str, Any]: + """Predict hallucinations using RAGFactChecker. + + :param context: List of context documents + :param answer: Answer text to check for hallucinations + :param question: Question (optional) + :param output_format: "tokens", "spans", or "detailed" + :param kwargs: Additional arguments + + :return: List of predictions in lettuceDetect format, or dict for detailed format + """ + if output_format not in ["tokens", "spans", "detailed"]: + raise ValueError( + f"Invalid output format '{output_format}'. 
" + "RAGFactChecker supports 'tokens', 'spans', or 'detailed'" + ) + + # Use our simple wrapper's detection method + result = self.rag.detect_hallucinations(context, answer, question) + + # Convert to lettuceDetect's expected format + if output_format == "detailed": + return { + "spans": self._convert_to_spans(answer, result), + "triplets": { + "answer": result.get("answer_triplets", []), + "context": result.get("context_triplets", []), + "hallucinated": result.get("hallucinated_triplets", []), + }, + "fact_check_results": result.get("fact_check_results", {}), + } + elif output_format == "spans": + return self._convert_to_spans(answer, result) + else: # tokens + return self._convert_to_tokens(answer, result) + + def predict_prompt( + self, prompt: str, answer: str, output_format: str = "tokens" + ) -> List[Dict[str, Any]]: + """Predict using a single prompt string as context.""" + return self.predict([prompt], answer, output_format=output_format) + + def predict_prompt_batch( + self, prompts: List[str], answers: List[str], output_format: str = "tokens" + ) -> List[List[Dict[str, Any]]]: + """Batch prediction using RAGFactChecker's batch processing.""" + if len(prompts) != len(answers): + raise ValueError("Number of prompts must match number of answers") + + contexts = [[prompt] for prompt in prompts] # Convert prompts to context lists + rag_results = self.rag.detect_hallucinations_batch(contexts, answers) + + # Convert each result to lettuceDetect format + converted_results = [] + for i, (answer, rag_result) in enumerate(zip(answers, rag_results)): + if output_format == "tokens": + converted = self._convert_to_tokens(answer, rag_result) + elif output_format == "spans": + converted = self._convert_to_spans(answer, rag_result) + else: + raise ValueError(f"Unknown output format: {output_format}") + converted_results.append(converted) + + return converted_results + + def _convert_to_tokens(self, answer: str, rag_result: Dict[str, Any]) -> List[Dict[str, Any]]: + """Convert RAGFactChecker result to token format.""" + tokens = answer.split() + hallucinated_triplets = rag_result.get("hallucinated_triplets", []) + + token_predictions = [] + for i, token in enumerate(tokens): + # Simple check if token appears in any hallucinated triplet + is_hallucinated = any( + token.lower() in " ".join(triplet).lower() for triplet in hallucinated_triplets + ) + + token_predictions.append( + { + "token": token, + "pred": 1 if is_hallucinated else 0, + "prob": 0.9 if is_hallucinated else 0.1, + } + ) + + return token_predictions + + def _convert_to_spans(self, answer: str, rag_result: Dict[str, Any]) -> List[Dict[str, Any]]: + """Convert RAGFactChecker result to span format with improved triplet matching.""" + spans = [] + hallucinated_triplets = rag_result.get("hallucinated_triplets", []) + + for triplet in hallucinated_triplets: + if len(triplet) < 3: + continue + + # Try different patterns to find triplet elements in text + patterns = [ + f"{triplet[0]} {triplet[1]} {triplet[2]}", # Full triplet phrase + f"{triplet[0]} {triplet[2]}", # Subject + object + triplet[2], # Object (often contains the hallucination) + triplet[0], # Subject + triplet[1], # Predicate + ] + + found_span = False + for pattern in patterns: + if not pattern or not pattern.strip(): + continue + + # Try exact match first, then case-insensitive + start = answer.find(pattern) + if start == -1: + start = answer.lower().find(pattern.lower()) + if start != -1: + # Get the actual text from the answer with correct case + pattern = answer[start : 
start + len(pattern)] + + if start != -1: + spans.append( + { + "start": start, + "end": start + len(pattern), + "text": pattern, + "confidence": 0.9, + "triplet": triplet, # Include source triplet for transparency + } + ) + found_span = True + break + + # If no pattern matched, try individual words from the triplet + if not found_span: + for element in triplet: + if element and element.strip() and len(element) > 3: # Skip short words + start = answer.lower().find(element.lower()) + if start != -1: + actual_text = answer[start : start + len(element)] + spans.append( + { + "start": start, + "end": start + len(element), + "text": actual_text, + "confidence": 0.7, # Lower confidence for partial matches + "triplet": triplet, + } + ) + break + + # Merge overlapping spans + return self._merge_overlapping_spans(spans) + + def _merge_overlapping_spans(self, spans: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Merge overlapping spans to avoid duplicates.""" + if not spans: + return spans + + # Sort spans by start position + sorted_spans = sorted(spans, key=lambda x: x["start"]) + merged = [sorted_spans[0]] + + for current in sorted_spans[1:]: + last = merged[-1] + + # Check if spans overlap + if current["start"] <= last["end"]: + # Merge spans - extend the end and combine triplets + merged[-1] = { + "start": last["start"], + "end": max(last["end"], current["end"]), + "text": last["text"], # Keep original text + "confidence": max(last["confidence"], current["confidence"]), + "triplet": last.get("triplet", current.get("triplet")), + } + else: + merged.append(current) + + return merged diff --git a/lettucedetect/detectors/transformer.py b/lettucedetect/detectors/transformer.py index e289744..3ef8c23 100644 --- a/lettucedetect/detectors/transformer.py +++ b/lettucedetect/detectors/transformer.py @@ -43,6 +43,11 @@ def _predict(self, prompt: str, answer: str, output_format: str) -> list: :param answer: The answer string. :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans. """ + if output_format not in ["tokens", "spans"]: + raise ValueError( + f"TransformerDetector doesn't support '{output_format}' format. " + "Use 'tokens' or 'spans'" + ) # Use the shared tokenization logic from HallucinationDataset encoding, _, offsets, answer_start_token = HallucinationDataset.prepare_tokenized_input( self.tokenizer, prompt, answer, self.max_length diff --git a/lettucedetect/integrations/__init__.py b/lettucedetect/integrations/__init__.py new file mode 100644 index 0000000..de7bde5 --- /dev/null +++ b/lettucedetect/integrations/__init__.py @@ -0,0 +1,5 @@ +"""LettuceDetect integrations with popular frameworks. + +This package provides clean, professional integrations between LettuceDetect +and popular AI/ML frameworks for seamless hallucination detection. +""" diff --git a/lettucedetect/integrations/elysia/README.md b/lettucedetect/integrations/elysia/README.md new file mode 100644 index 0000000..894f324 --- /dev/null +++ b/lettucedetect/integrations/elysia/README.md @@ -0,0 +1,42 @@ +# LettuceDetect + Elysia Integration + +Automatic hallucination detection for Elysia AI decision trees. 
+ +## Installation + +```bash +pip install lettucedetect elysia-ai +``` + +## Usage + +```python +from elysia import Tree +from lettucedetect.integrations.elysia import detect_hallucinations + +# Create tree with hallucination detection +tree = Tree() +tree.add_tool(detect_hallucinations) + +# The AI can now automatically validate responses +response = tree(""" +Context: Python was created by Guido van Rossum in 1991. +Question: When was Python created? +Please answer and verify your response for accuracy. +""") +``` + +## What It Does + +The `detect_hallucinations` tool automatically: +- ✅ Analyzes AI responses against provided context +- ✅ Identifies unsupported claims and factual errors +- ✅ Provides confidence scores and exact text spans +- ✅ Guides the AI to self-correct when needed + +## Tool Details + +**detect_hallucinations**: Main hallucination detection tool +- Compares generated answers against source context +- Returns structured data about problematic spans +- Supports multiple detection methods (transformer, LLM, fact-checker) \ No newline at end of file diff --git a/lettucedetect/integrations/elysia/__init__.py b/lettucedetect/integrations/elysia/__init__.py new file mode 100644 index 0000000..eddb65e --- /dev/null +++ b/lettucedetect/integrations/elysia/__init__.py @@ -0,0 +1,9 @@ +"""LettuceDetect integration for Elysia. + +This integration provides hallucination detection tools that can be used +directly in Elysia decision trees for automatic quality control of AI responses. +""" + +from .tools import detect_hallucinations + +__all__ = ["detect_hallucinations"] diff --git a/lettucedetect/integrations/elysia/example.py b/lettucedetect/integrations/elysia/example.py new file mode 100644 index 0000000..096b285 --- /dev/null +++ b/lettucedetect/integrations/elysia/example.py @@ -0,0 +1,15 @@ +"""Example of using LettuceDetect with Elysia for automatic hallucination detection.""" + +from elysia import Tree + +from lettucedetect.integrations.elysia import detect_hallucinations + +# Create an Elysia tree with hallucination detection capabilities +tree = Tree() + +# Add LettuceDetect tools to the tree +tree.add_tool(detect_hallucinations) + +tree( + "How many data they generated in Kovacs et al. 2025? Please answer and verify your response for credibility." +) diff --git a/lettucedetect/integrations/elysia/tools.py b/lettucedetect/integrations/elysia/tools.py new file mode 100644 index 0000000..042555d --- /dev/null +++ b/lettucedetect/integrations/elysia/tools.py @@ -0,0 +1,110 @@ +"""LettuceDetect integration tools for Elysia.""" + +from typing import List, Optional + +from elysia import tool + +from lettucedetect import HallucinationDetector + + +@tool +async def detect_hallucinations( + context: List[str], + answer: str, + question: Optional[str] = None, +): + """Verify AI-generated answers by comparing them against source context by using detecting hallucinations. + + This tool analyzes whether statements in an answer are supported by the provided context, + identifying specific spans of text that may be hallucinated or unsupported. It uses + advanced NLP models to perform token-level analysis and provides detailed feedback + about problematic content. + + Args: + context: List of source documents or passages that should support the answer. + Each string represents a separate context document or paragraph. + These are the "ground truth" sources the answer should be based on. + answer: The AI-generated response to analyze for potential hallucinations. 
+ This is the text that will be checked against the context. + question: Optional original question that was asked. Providing this improves + detection accuracy by understanding what information was requested. + + This tool performs the following analysis: + 1. Tokenizes the answer and compares each segment against the context + 2. Identifies spans that are not supported by any context document + 3. Assigns confidence scores to problematic spans + 4. Returns structured results with exact character positions + + The tool will identify various types of hallucinations: + - Factual errors (wrong dates, names, numbers) + - Unsupported claims not present in context + - Contradictions to the provided information + - Invented details not mentioned in sources + + Always use this tool when you need to: + - Verify AI responses against source documents in RAG systems + - Implement quality control for generated content + - Build fact-checking pipelines + - Ensure accuracy in knowledge-based applications + - Validate information before presenting to users + + Example scenario: + Context: ["Python was created in 1991 by Guido van Rossum", "It's known for readable syntax"] + Answer: "Python was created in 1985 by James Gosling and is known for complex syntax" + + This tool would identify: + - "1985" as hallucinated (should be 1991) + - "James Gosling" as hallucinated (should be Guido van Rossum) + - "complex syntax" as hallucinated (context says readable syntax) + + """ + try: + # Initialize detector with transformer method + detector = HallucinationDetector( + method="transformer", model_path="KRLabsOrg/lettucedect-base-modernbert-en-v1" + ) + + # Perform hallucination detection + spans = detector.predict( + context=context, answer=answer, question=question, output_format="spans" + ) + + # Calculate overall metrics + has_issues = len(spans) > 0 + max_confidence = max([span.get("confidence", 0) for span in spans], default=0) + + # Create structured result + result = { + "has_issues": has_issues, + "confidence": max_confidence, + "issue_count": len(spans), + "spans": spans, + } + + # Yield structured data for the AI agent + yield result + + # Create human-readable summary + if has_issues: + issue_details = [] + for span in spans[:5]: # Show up to 5 examples + text = span.get("text", "unknown") + conf = span.get("confidence", 0) + start = span.get("start", 0) + end = span.get("end", 0) + issue_details.append(f"'{text}' at position {start}-{end} (confidence: {conf:.2f})") + + summary = f"Detected {len(spans)} potential hallucination(s) in the answer. " + summary += f"Most problematic spans: {', '.join(issue_details)}. " + summary += ( + "The AI should revise these unsupported claims or provide additional context." + ) + else: + summary = "No hallucinations detected. The answer appears to be well-supported by the provided context." + + yield summary + + except Exception as e: + error_msg = f"Hallucination detection failed: {e!s}" + yield {"error": True, "message": str(e)} + yield error_msg diff --git a/lettucedetect/integrations/langchain/README.md b/lettucedetect/integrations/langchain/README.md new file mode 100644 index 0000000..75864c9 --- /dev/null +++ b/lettucedetect/integrations/langchain/README.md @@ -0,0 +1,95 @@ +# LettuceDetect + LangChain Integration + +Real-time hallucination detection for RAG pipelines. 
+ +## Installation + +```bash +pip install lettucedetect +pip install langchain langchain-openai langchain-community langchain-chroma +export OPENAI_API_KEY=your_key +``` + +## Usage + +```python +from langchain.chains import RetrievalQA +from langchain.text_splitter import CharacterTextSplitter +from langchain_community.vectorstores import Chroma +from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from lettucedetect.integrations.langchain import stream_with_detection + +# Set up your RAG pipeline +documents = ["Your documents here..."] +embeddings = OpenAIEmbeddings() +text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=0) +docs = text_splitter.create_documents(documents) +vectorstore = Chroma.from_documents(docs, embeddings) + +# Create streaming RAG chain +llm = ChatOpenAI(model="gpt-4o-mini", streaming=True) +chain = RetrievalQA.from_chain_type( + llm=llm, + chain_type="stuff", + retriever=vectorstore.as_retriever(search_kwargs={"k": 3}) +) + +# Get context and stream with detection +question = "Your question here" +context = [doc.page_content for doc in vectorstore.similarity_search(question, k=3)] + +# Stream tokens and hallucination detection in real-time +for event in stream_with_detection(chain, {"query": question}, context, check_every=10): + if event["type"] == "token": + print(event["content"], end="", flush=True) # Stream response + elif event["type"] == "detection" and event["has_issues"]: + print(f"\nHallucination detected: {event['issue_count']} issues") + # Handle detection - log, alert, stop generation, etc. +``` + +## Direct Callback Usage + +For more control, use `LettuceStreamingCallback` directly: + +```python +from lettucedetect.integrations.langchain import LettuceStreamingCallback + +# Create callback with your settings +callback = LettuceStreamingCallback( + context=context, + question=question, + check_every=10, + method="transformer" # or "rag_fact_checker" +) + +# Use with any LangChain chain +result = chain.invoke({"query": question}, config={"callbacks": [callback]}) + +# Stream events as they arrive +for event in callback.stream_events(): + if event["type"] == "token": + print(event["content"], end="") + elif event["type"] == "detection": + handle_detection(event) +``` + +## What You Get + +**Token Events**: Real-time text as it's generated +**Detection Events**: Hallucination analysis with confidence scores and exact spans + +Each detection event includes: +- `has_issues`: Boolean if hallucinations found +- `issue_count`: Number of problematic spans +- `confidence`: Detection confidence (0-1) +- `spans`: Array of problematic text spans with positions + +## Live Demo + +See it in action: +```bash +streamlit run lettucedetect/integrations/langchain/examples/streamlit_app.py +python lettucedetect/integrations/langchain/examples/rag_example.py +``` + +Perfect for building streaming chat apps, real-time APIs, and production RAG systems with automatic quality control. \ No newline at end of file diff --git a/lettucedetect/integrations/langchain/__init__.py b/lettucedetect/integrations/langchain/__init__.py new file mode 100644 index 0000000..3e78ef4 --- /dev/null +++ b/lettucedetect/integrations/langchain/__init__.py @@ -0,0 +1,39 @@ +"""LangChain integration for LettuceDetect hallucination detection. + +This module provides a clean, minimal callback for integrating LettuceDetect +with LangChain applications. The callback automatically detects hallucinations +in LLM responses when used with retrieval chains. 
+ +Example usage: + + from integrations.langchain import LettuceDetectCallback, detect_in_chain + from langchain.chains import RetrievalQA + + # Basic usage + callback = LettuceDetectCallback(verbose=True) + result = chain.run("Your question", callbacks=[callback]) + + if callback.has_issues(): + print("Potential hallucinations detected") + + # Or use convenience function + result = detect_in_chain(chain, "Your question") + print(f"Answer: {result['answer']}") + print(f"Issues: {result['has_issues']}") +""" + +from .callbacks import ( + LettuceDetectCallback, + LettuceStreamingCallback, + detect_in_chain, + stream_with_detection, +) + +__all__ = [ + "LettuceDetectCallback", + "LettuceStreamingCallback", + "detect_in_chain", + "stream_with_detection", +] + +__version__ = "1.0.0" diff --git a/lettucedetect/integrations/langchain/callbacks.py b/lettucedetect/integrations/langchain/callbacks.py new file mode 100644 index 0000000..35bb2d9 --- /dev/null +++ b/lettucedetect/integrations/langchain/callbacks.py @@ -0,0 +1,409 @@ +"""Clean, minimal LangChain callbacks for LettuceDetect integration.""" + +import threading +from queue import Queue +from typing import Any, Callable, Dict, List, Optional + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import LLMResult +from langchain.schema.document import Document + +from lettucedetect import HallucinationDetector + + +class LettuceDetectCallback(BaseCallbackHandler): + """Simple callback for post-generation hallucination detection. + + Automatically detects hallucinations in LLM responses when used with + retrieval chains or when context is provided manually. + """ + + def __init__( + self, + method: str = "rag_fact_checker", + model_path: Optional[str] = None, + on_result: Optional[Callable[[Dict[str, Any]], None]] = None, + verbose: bool = False, + ): + """Initialize the callback. 
+ + Args: + method: Detection method ("transformer", "llm", "rag_fact_checker") + model_path: Path to model (for transformer method) + on_result: Optional function to handle detection results + verbose: Whether to print results + + """ + super().__init__() + self.detector = HallucinationDetector(method=method, model_path=model_path) + self.on_result = on_result + self.verbose = verbose + + # State + self.context: List[str] = [] + self.question: Optional[str] = None + self.results: List[Dict[str, Any]] = [] + + def set_context(self, context: List[str]) -> None: + """Manually set context documents.""" + self.context = context + + def set_question(self, question: str) -> None: + """Manually set the question.""" + self.question = question + + def on_retriever_end(self, documents: List[Document], **kwargs: Any) -> None: + """Store retrieved context.""" + self.context = [doc.page_content for doc in documents] + + def on_chain_start( + self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any + ) -> None: + """Extract question from chain inputs.""" + for key in ["question", "query", "input"]: + if key in inputs: + self.question = inputs[key] + break + + def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: + """Run hallucination detection on LLM response.""" + if not self.context or not response.generations: + return + + for generation in response.generations: + if not generation: + continue + + text = generation[0].text + if not text.strip(): + continue + + try: + spans = self.detector.predict( + context=self.context, answer=text, question=self.question, output_format="spans" + ) + + result = { + "text": text, + "question": self.question, + "context": self.context.copy(), + "has_issues": len(spans) > 0, + "confidence": max([s.get("confidence", 0) for s in spans], default=0), + "spans": spans, + "issue_count": len(spans), + } + + self.results.append(result) + + if self.verbose: + status = "ISSUES DETECTED" if result["has_issues"] else "CLEAN" + print(f"LettuceDetect: {status} (confidence: {result['confidence']:.3f})") + + if self.on_result: + self.on_result(result) + + except Exception as e: + if self.verbose: + print(f"LettuceDetect: Detection error: {e}") + + def get_results(self) -> List[Dict[str, Any]]: + """Get all detection results.""" + return self.results.copy() + + def get_last_result(self) -> Optional[Dict[str, Any]]: + """Get the most recent detection result.""" + return self.results[-1] if self.results else None + + def has_issues(self) -> bool: + """Check if any results had issues.""" + return any(r["has_issues"] for r in self.results) + + def reset(self) -> None: + """Reset callback state.""" + self.context = [] + self.question = None + self.results = [] + + +class LettuceStreamingCallback(BaseCallbackHandler): + """Real-time hallucination detection with JSON event streaming. + + Provides true streaming of both tokens and detection results through + a queue-based system that works with any LangChain component. + """ + + def __init__( + self, + method: str = "transformer", + model_path: Optional[str] = "output/hallucination_detection_ettin_17m", + context: Optional[List[str]] = None, + question: Optional[str] = None, + check_every: int = 10, + on_detection: Optional[Callable[[Dict[str, Any]], None]] = None, + verbose: bool = False, + ): + """Initialize streaming callback. 
+ + Args: + method: Detection method + model_path: Path to model (for transformer method) + context: Context documents for detection + question: Question being answered + check_every: Run detection every N tokens + on_detection: Function called when detection runs + verbose: Whether to print detection results + + """ + super().__init__() + self.detector = HallucinationDetector(method=method, model_path=model_path) + self.context = context or [] + self.question = question + self.check_every = check_every + self.on_detection = on_detection + self.verbose = verbose + + # Streaming state + self.accumulated_text = "" + self.token_count = 0 + + # Queue for true streaming of JSON events + self.event_queue = Queue() + + def set_context(self, context: List[str]) -> None: + """Set context documents.""" + self.context = context + + def set_question(self, question: str) -> None: + """Set the question being answered.""" + self.question = question + + def on_llm_start(self, *args, **kwargs): + """Reset state when streaming starts.""" + self.accumulated_text = "" + self.token_count = 0 + + def on_chat_model_start(self, *args, **kwargs): + """Handle chat model start for newer LangChain versions.""" + self.on_llm_start(*args, **kwargs) + + def on_llm_new_token(self, token: str, **kwargs): + """Process new token and run detection periodically.""" + self.accumulated_text += token + self.token_count += 1 + + # Stream token event immediately + self.event_queue.put( + {"type": "token", "content": token, "position": len(self.accumulated_text)} + ) + + # Run detection every N tokens + if ( + self.token_count >= self.check_every + and len(self.accumulated_text.strip()) > 20 + and self.context + ): + try: + # Run detection on accumulated text + spans = self.detector.predict( + context=self.context, + answer=self.accumulated_text, + question=self.question, + output_format="spans", + ) + + # Create detection result + result = { + "text": self.accumulated_text, + "has_issues": len(spans) > 0, + "spans": spans, + "confidence": max([s.get("confidence", 0) for s in spans], default=0), + "issue_count": len(spans), + "token_count": len(self.accumulated_text.split()), + "is_incremental": True, + } + + # Stream detection event immediately + self.event_queue.put( + { + "type": "detection", + "has_issues": len(spans) > 0, + "spans": spans, + "confidence": max([s.get("confidence", 0) for s in spans], default=0), + "issue_count": len(spans), + "text_length": len(self.accumulated_text), + "is_incremental": True, + } + ) + + # Call user handler + if self.on_detection: + self.on_detection(result) + + # Verbose output + if self.verbose and result["has_issues"]: + print(f"Real-time detection: {result['issue_count']} issues found") + + # Reset token counter + self.token_count = 0 + + except Exception as e: + if self.verbose: + print(f"Streaming detection error: {e}") + + def on_llm_end(self, response, **kwargs): + """Run final detection on complete response.""" + if self.accumulated_text and self.context: + try: + spans = self.detector.predict( + context=self.context, + answer=self.accumulated_text, + question=self.question, + output_format="spans", + ) + + final_result = { + "text": self.accumulated_text, + "has_issues": len(spans) > 0, + "spans": spans, + "confidence": max([s.get("confidence", 0) for s in spans], default=0), + "issue_count": len(spans), + "token_count": len(self.accumulated_text.split()), + "is_final": True, + } + + # Stream final detection event + self.event_queue.put( + { + "type": "detection", + "has_issues": 
len(spans) > 0, + "spans": spans, + "confidence": max([s.get("confidence", 0) for s in spans], default=0), + "issue_count": len(spans), + "text_length": len(self.accumulated_text), + "is_final": True, + } + ) + + if self.on_detection: + self.on_detection(final_result) + + if self.verbose: + status = "Issues found" if final_result["has_issues"] else "Clean" + print(f"Final detection: {status}") + + except Exception as e: + if self.verbose: + print(f"Final detection error: {e}") + + # Signal completion + self.event_queue.put(None) # End signal + + def on_chat_model_end(self, response, **kwargs): + """Handle chat model end for newer LangChain versions.""" + self.on_llm_end(response, **kwargs) + + def stream_events(self): + """Generator that yields JSON events as they arrive. + + Yields events with types: + - "token": Individual tokens as they arrive + - "detection": Hallucination detection results + + This allows developers to: + - Stream JSON events to clients in real-time + - Handle tokens and detections immediately + - Build real-time UIs and APIs + """ + while True: + event = self.event_queue.get() + if event is None: # End signal + break + yield event + + +def detect_in_chain( + chain, query: str, context: Optional[List[str]] = None, **kwargs +) -> Dict[str, Any]: + """Convenience function to run a chain with automatic hallucination detection. + + Args: + chain: LangChain chain to execute + query: Query/question to ask + context: Optional context documents (if not using retrieval) + **kwargs: Additional arguments passed to chain + + Returns: + Dictionary with chain result and detection info + + """ + callback = LettuceDetectCallback(**kwargs) + + if context: + callback.set_context(context) + callback.set_question(query) + + # Run chain with callback + chain_result = chain.invoke({"query": query}, config={"callbacks": [callback]}) + result = chain_result.get("result", "") + + detection_result = callback.get_last_result() + + return { + "answer": result, + "detection": detection_result, + "has_issues": detection_result["has_issues"] if detection_result else False, + } + + +def stream_with_detection(chain_or_llm, input_data, context, **callback_kwargs): + """Stream JSON events from any LangChain chain/LLM with hallucination detection. + + Works with RetrievalQA, ConversationChain, raw LLMs, or any LangChain component. + + Args: + chain_or_llm: Any LangChain chain or LLM + input_data: Input for the chain (query string, messages, etc.) 
+ context: Context documents for hallucination detection + **callback_kwargs: Additional arguments for LettuceStreamingCallback + + Yields: + dict: JSON events with "type": "token" or "detection" + + Example: + # With RAG chain + chain = RetrievalQA.from_llm(llm, retriever) + for event in stream_with_detection(chain, "Your question", context): + if event["type"] == "token": + await websocket.send_json(event) + elif event["type"] == "detection": + print(f"Detection: {event['has_issues']}") + + """ + callback = LettuceStreamingCallback(context=context, **callback_kwargs) + + # Start chain/LLM in background thread + def run_generation(): + try: + if hasattr(chain_or_llm, "invoke"): + # Modern LangChain interface + chain_or_llm.invoke(input_data, config={"callbacks": [callback]}) + elif hasattr(chain_or_llm, "run"): + # Legacy chain interface + chain_or_llm.run(input_data, callbacks=[callback]) + else: + # Try direct call + chain_or_llm(input_data, callbacks=[callback]) + except Exception as e: + # Put error event and complete + callback.event_queue.put({"type": "error", "message": str(e)}) + callback.event_queue.put(None) + + thread = threading.Thread(target=run_generation) + thread.start() + + # Stream events as they arrive + try: + for event in callback.stream_events(): + yield event + finally: + thread.join() # Ensure thread completes diff --git a/lettucedetect/integrations/langchain/examples/rag_example.py b/lettucedetect/integrations/langchain/examples/rag_example.py new file mode 100644 index 0000000..5bfae75 --- /dev/null +++ b/lettucedetect/integrations/langchain/examples/rag_example.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +"""Professional LettuceDetect + LangChain RAG example. + +Demonstrates automatic hallucination detection in a retrieval-augmented +generation pipeline using clean, production-ready code. 
+ +Requirements: +- pip install -r lettucedetect/integrations/langchain/requirements.txt +- export OPENAI_API_KEY=your_key +""" + +import os + +from langchain.chains import RetrievalQA + +# LangChain imports +from langchain.text_splitter import CharacterTextSplitter +from langchain_community.vectorstores import Chroma +from langchain_openai import ChatOpenAI, OpenAI, OpenAIEmbeddings + +# LettuceDetect integration +from lettucedetect.integrations.langchain.callbacks import ( + LettuceDetectCallback, + detect_in_chain, + stream_with_detection, +) + +# Sample documents for demonstration +SAMPLE_DOCUMENTS = [ + "The Pacific Ocean is the largest ocean on Earth, covering about 46% of the water surface.", + "Python was created by Guido van Rossum and first released in 1991.", + "Machine learning is a subset of artificial intelligence focused on data-driven predictions.", + "The human brain contains approximately 86 billion neurons.", + "Photosynthesis converts light energy into chemical energy in plants.", +] + + +def create_rag_chain(): + """Create a simple RAG chain with vector retrieval.""" + # Create embeddings and vector store + embeddings = OpenAIEmbeddings() + + # Split documents and create vector store + text_splitter = CharacterTextSplitter(chunk_size=200, chunk_overlap=0) + docs = text_splitter.create_documents(SAMPLE_DOCUMENTS) + vectorstore = Chroma.from_documents(docs, embeddings) + + # Create retrieval chain + llm = OpenAI(model="gpt-4o-mini") + chain = RetrievalQA.from_chain_type( + llm=llm, + chain_type="stuff", + retriever=vectorstore.as_retriever(search_kwargs={"k": 2}), + return_source_documents=False, + ) + + return chain + + +def example_basic_rag_detection(): + """Basic RAG with post-generation hallucination detection.""" + print("Basic RAG + Detection Example") + print("-" * 40) + + chain = create_rag_chain() + + # Questions to test + questions = [ + "What is the Pacific Ocean?", # Should be clean + "Who created Python and when was it invented?", # Should be clean + "How does Python relate to ocean exploration?", # Likely hallucinated + ] + + for question in questions: + print(f"Q: {question}") + + # Use convenience function for simple post-generation detection + result = detect_in_chain(chain, question, verbose=True) + + print(f"A: {result['answer']}") + + if result["has_issues"]: + detection = result["detection"] + print(f"🚨 Issues detected: {detection['issue_count']} spans") + print(f"Max confidence: {detection['confidence']:.3f}") + else: + print("✅ No issues detected") + + print() + + +def example_rag_streaming_detection(): + """RAG with real-time streaming detection - simplified to show working approach.""" + print("RAG + Real-time Streaming Detection Example") + print("-" * 40) + print("Shows structured JSON events during streaming") + print() + + # Setup RAG chain + embeddings = OpenAIEmbeddings() + text_splitter = CharacterTextSplitter(chunk_size=200, chunk_overlap=0) + docs = text_splitter.create_documents(SAMPLE_DOCUMENTS) + vectorstore = Chroma.from_documents(docs, embeddings) + + llm = ChatOpenAI(model="gpt-4o-mini", streaming=True) + chain = RetrievalQA.from_chain_type( + llm=llm, chain_type="stuff", retriever=vectorstore.as_retriever(search_kwargs={"k": 2}) + ) + + question = "How does Python relate to ocean exploration and marine biology?" 
+ context = [doc.page_content for doc in vectorstore.similarity_search(question, k=2)] + + print(f"Q: {question}") + print(f"Context: {context[0][:50]}...") + print() + print("Streaming Events:") + print("-" * 18) + + # Use the working streaming approach + event_count = 0 + for event in stream_with_detection(chain, {"query": question}, context, check_every=8): + event_count += 1 + if event["type"] == "token": + print(event["content"], end="", flush=True) + elif event["type"] == "detection" and event["has_issues"]: + print( + f"\n[Detection {event_count}: {event['issue_count']} issues, confidence: {event['confidence']:.3f}]", + end="", + flush=True, + ) + + print("\n") + print(f"Total events processed: {event_count}") + + +def example_simple_json_streaming(): + """Simple example showing TRUE JSON streaming - perfect for API developers.""" + print("Simple JSON Streaming Example") + print("-" * 35) + print("Shows real-time JSON events - exactly what API developers need!") + print() + + # Setup simple RAG chain + embeddings = OpenAIEmbeddings() + text_splitter = CharacterTextSplitter(chunk_size=200, chunk_overlap=0) + docs = text_splitter.create_documents(SAMPLE_DOCUMENTS) + vectorstore = Chroma.from_documents(docs, embeddings) + + llm = ChatOpenAI(model="gpt-4o-mini", streaming=True) + chain = RetrievalQA.from_chain_type( + llm=llm, chain_type="stuff", retriever=vectorstore.as_retriever(search_kwargs={"k": 2}) + ) + + question = "How does Python relate to ocean exploration?" + context = [doc.page_content for doc in vectorstore.similarity_search(question, k=2)] + + print(f"Q: {question}") + print(f"Context: {context[0][:50]}...") + print() + print("JSON Events Stream:") + print("-" * 18) + + # THIS IS THE MAGIC - Stream JSON events in real-time! + for event in stream_with_detection(chain, {"query": question}, context, check_every=5): + # Each event is a JSON-serializable dict + import json + + print(json.dumps(event)) + + # In your API: + # if event["type"] == "token": + # await websocket.send_json(event) + # elif event["type"] == "detection" and event["has_issues"]: + # await websocket.send_json({"alert": "hallucination_detected", "spans": event["spans"]}) + + print() + print("Perfect for:") + print(" - FastAPI streaming responses") + print(" - WebSocket real-time chat") + print(" - Server-sent events (SSE)") + print(" - Any API that needs live updates") + + +def example_with_manual_context(): + """Example providing context manually (without retrieval).""" + print("Manual Context Example") + print("-" * 40) + + # Simple LLM without retrieval + llm = OpenAI(model="gpt-4o-mini") + + # Manual context + context = [ + "Python is a programming language created by Guido van Rossum in 1991.", + "It is known for its simple syntax and readability.", + ] + + callback = LettuceDetectCallback(verbose=True) + callback.set_context(context) + callback.set_question("What is Python?") + + # Direct LLM call + response = llm.generate(["What is Python?"], callbacks=[callback]) + answer = response.generations[0][0].text + + print(f"A: {answer}") + + result = callback.get_last_result() + if result: + print(f"Detection: {'Issues found' if result['has_issues'] else 'Clean'}") + + +def main(): + """Run all examples.""" + if not os.getenv("OPENAI_API_KEY"): + print("ERROR: OPENAI_API_KEY environment variable required") + return + + try: + example_basic_rag_detection() + print("=" * 60) + example_simple_json_streaming() # TRUE JSON streaming! 
+ print("=" * 60) + example_rag_streaming_detection() # Detailed streaming analysis + print("=" * 60) + example_with_manual_context() + + except Exception as e: + print(f"Error: {e}") + print( + "Make sure you have: pip install -r lettucedetect/integrations/langchain/requirements.txt" + ) + + +if __name__ == "__main__": + main() diff --git a/lettucedetect/integrations/langchain/examples/streamlit_app.py b/lettucedetect/integrations/langchain/examples/streamlit_app.py new file mode 100644 index 0000000..2513a5f --- /dev/null +++ b/lettucedetect/integrations/langchain/examples/streamlit_app.py @@ -0,0 +1,236 @@ +#!/usr/bin/env python3 +"""Clean Streamlit demo for LettuceDetect + LangChain real-time detection. + +Run with: streamlit run lettucedetect/integrations/langchain/examples/streamlit_app.py + +Requirements: +- pip install streamlit langchain langchain-openai lettucedetect +- export OPENAI_API_KEY=your_key +""" + +import os +import time + +import streamlit as st +import streamlit.components.v1 as components + +# LangChain imports with compatibility handling +try: + from langchain_openai import ChatOpenAI + + try: + ChatOpenAI.model_rebuild() + except Exception: + pass +except ImportError: + ChatOpenAI = None + +from langchain.callbacks.base import BaseCallbackHandler +from langchain.schema import HumanMessage + +from lettucedetect import HallucinationDetector + +# LettuceDetect integration +from lettucedetect.integrations.langchain.callbacks import LettuceStreamingCallback + + +def create_interactive_text(text: str, spans: list[dict]) -> str: + """Create clean interactive HTML with highlighting (matching original demo style).""" + html_text = text + + # Apply highlighting (reverse order to preserve indices) + for span in sorted(spans, key=lambda x: x.get("start", 0), reverse=True): + start = span.get("start", 0) + end = span.get("end", 0) + confidence = span.get("confidence", 0) + + if 0 <= start < end <= len(text): + span_text = text[start:end] + highlighted_span = f'{span_text}' + html_text = html_text[:start] + highlighted_span + html_text[end:] + + return f""" + +
{html_text}
+ """ + + +class StreamlitRealtimeHandler(BaseCallbackHandler): + """Simple handler for real-time streaming with HTML display.""" + + def __init__(self, html_placeholder): + super().__init__() + self.html_placeholder = html_placeholder + self.text = "" + self.spans = [] + + def on_llm_start(self, *args, **kwargs): + self.text = "" + self.spans = [] + self._update_display() + + def on_chat_model_start(self, *args, **kwargs): + self.on_llm_start(*args, **kwargs) + + def on_llm_new_token(self, token: str, **kwargs): + self.text += token + # Update display with current text and any spans + self._update_display() + # sleep for 0.1 seconds + time.sleep(0.1) + + def update_with_detection(self, spans): + """Update display with detection results.""" + self.spans = spans + self._update_display() + + def _update_display(self): + """Update the HTML display with current text and spans.""" + if not self.text.strip(): + html_content = ( + "
Generating response...
" + ) + else: + html_content = create_interactive_text(self.text, self.spans) + + with self.html_placeholder: + components.html(html_content, height=max(200, len(self.text) // 4)) + + +def create_prompt(question: str, context: str) -> str: + """Create prompt from context and question.""" + return f"""Based on the following context, answer the question: + +Context: {context} + +Question: {question} + +Answer based only on the provided context:""" + + +def main(): + """Main Streamlit application - clean and simple like the original demo.""" + st.set_page_config(page_title="LettuceDetect Real-time Demo") + + # Show lettuce detective image like original + st.image( + "https://github.com/KRLabsOrg/LettuceDetect/blob/main/assets/lettuce_detective.png?raw=true", + width=600, + ) + + st.title("Real-time Hallucination Detection") + + # Check requirements + if not os.getenv("OPENAI_API_KEY"): + st.error("OPENAI_API_KEY environment variable required") + st.stop() + + if ChatOpenAI is None: + st.error("langchain-openai not installed") + st.stop() + + # Simple form like original demo + context = st.text_area( + "Context", + "Python is a high-level programming language created by Guido van Rossum in 1991. " + "It is known for its simple, readable syntax and extensive standard library.", + height=100, + ) + + question = st.text_area( + "Question", + "What is Python and who created it?", + height=100, + ) + + # Initialize components + @st.cache_resource + def get_llm(): + return ChatOpenAI(model="gpt-4o-mini", streaming=True) + + @st.cache_resource + def get_detector(): + model_path = "KRLabsOrg/tinylettuce-ettin-17m-en" + if os.path.exists(model_path): + return HallucinationDetector(method="transformer", model_path=model_path) + else: + return HallucinationDetector(method="rag_fact_checker") + + llm = get_llm() + detector = get_detector() + + # Single response area for HTML display + html_placeholder = st.empty() + + # Simple detect button like original + if st.button("Generate with Real-time Detection"): + if not context.strip() or not question.strip(): + st.warning("Please provide both context and question") + return + + # State for real-time detection + final_spans = [] + + def handle_detection(result): + """Handle detection results by passing to output handler.""" + nonlocal final_spans + spans = result.get("spans", []) + + # Pass detection results to the output handler + output_handler.update_with_detection(spans) + + if result.get("is_final", False): + final_spans = spans + + # Create callbacks + detection_callback = LettuceStreamingCallback( + method="transformer", + model_path="KRLabsOrg/tinylettuce-ettin-17m-en", + context=[context], + question=question, + check_every=10, + on_detection=handle_detection, + verbose=False, + ) + + output_handler = StreamlitRealtimeHandler(html_placeholder) + callbacks = [detection_callback, output_handler] + + # Generate response + try: + messages = [HumanMessage(content=create_prompt(question, context))] + + with st.spinner("Generating..."): + llm.invoke(messages, config={"callbacks": callbacks}) + + # Show final status message + issue_count = len(final_spans) + if issue_count > 0: + st.warning( + f"⚠️ {issue_count} potential issue{'s' if issue_count > 1 else ''} detected" + ) + else: + st.success("✅ Response appears clean") + + except Exception as e: + st.error(f"Error: {e}") + + +if __name__ == "__main__": + main() diff --git a/lettucedetect/integrations/langchain/requirements.txt b/lettucedetect/integrations/langchain/requirements.txt new file mode 100644 
index 0000000..361a790 --- /dev/null +++ b/lettucedetect/integrations/langchain/requirements.txt @@ -0,0 +1,10 @@ +# LangChain Integration Requirements for LettuceDetect +langchain>=0.1.0 +langchain-openai>=0.1.0 +langchain-community>=0.0.20 + +# For Streamlit demo +streamlit>=1.28.0 + +# For RAG example +langchain-chroma>=0.1.0 \ No newline at end of file diff --git a/lettucedetect/models/generation.py b/lettucedetect/models/generation.py new file mode 100644 index 0000000..939b3f0 --- /dev/null +++ b/lettucedetect/models/generation.py @@ -0,0 +1,127 @@ +"""Simple hallucination generation using RAGFactChecker.""" + +from typing import Any, Dict, List, Optional + +from lettucedetect.ragfactchecker import RAGFactChecker + + +class HallucinationGenerator: + """Simple hallucination generator using RAGFactChecker. + + This provides the same interface as before but uses our clean RAGFactChecker wrapper. + """ + + def __init__( + self, + method: str = "rag_fact_checker", + openai_api_key: str = None, + model: str = "gpt-4o", + base_url: str = None, + temperature: float = 0.0, + **kwargs, + ): + """Initialize hallucination generator. + + :param method: Method name (kept for compatibility, only "rag_fact_checker" exists) + :param openai_api_key: OpenAI API key + :param model: OpenAI model to use (default: "gpt-4o") + :param base_url: Optional base URL for API (e.g., "http://localhost:1234/v1" for local servers) + :param temperature: Temperature for model sampling (default: 0.0 for deterministic outputs) + :param kwargs: Additional arguments (ignored) + + """ + self.rag = RAGFactChecker( + openai_api_key=openai_api_key, model=model, base_url=base_url, temperature=temperature + ) + + def generate( + self, + context: List[str], + question: str, + answer: str = None, + error_types: Optional[List[str]] = None, + intensity: float = 0.3, + ) -> Dict[str, Any]: + """Generate hallucinated content. + + :param context: List of context documents + :param question: Question to generate answer for + :param answer: Original answer (optional, for answer-based generation) + :param error_types: Error types to inject (answer-based generation only) + :param intensity: Error intensity between 0.1 and 1.0 (answer-based generation only) + + :return: Generation results + + """ + if answer: + # Answer-based generation + return self.rag.generate_hallucination_from_answer( + answer, question, error_types, intensity + ) + else: + # Context-based generation (error_types and intensity only apply to answer-based generation) + return self.rag.generate_hallucination_from_context(context, question) + + def generate_batch( + self, + contexts: List[List[str]], + questions: List[str], + answers: List[str] = None, + error_types: Optional[List[str]] = None, + intensity: float = 0.3, + ) -> List[Dict[str, Any]]: + """Generate hallucinated content for multiple inputs.
+ + :param contexts: List of context lists + :param questions: List of questions + :param answers: List of answers (optional) + :param error_types: Error types to inject (answer-based generation only) + :param intensity: Error intensity between 0.1 and 1.0 (answer-based generation only) + + :return: List of generation results + """ + if error_types: + error_types = [error_types] * len(contexts) + if intensity: + intensity = [intensity] * len(contexts) + + if answers: + return self.rag.generate_hallucination_from_answer_batch( + answers, questions, error_types, intensity + ) + else: + # error_types/intensity only apply to answer-based generation + return self.rag.generate_hallucination_from_context_batch(contexts, questions) + + async def generate_batch_async( + self, + contexts: List[List[str]], + questions: List[str], + answers: List[str] = None, + error_types: Optional[List[str]] = None, + intensity: float = 0.3, + ) -> List[Dict[str, Any]]: + """Generate hallucinated content for multiple inputs. + + :param contexts: List of context lists + :param questions: List of questions + :param answers: List of answers (optional) + :param error_types: Error types to inject (answer-based generation only) + :param intensity: Error intensity between 0.1 and 1.0 (answer-based generation only) + + :return: List of generation results + + """ + if error_types: + error_types = [error_types] * len(contexts) + if intensity: + intensity = [intensity] * len(contexts) + + if answers: + return await self.rag.generate_hallucination_from_answer_batch_async( + answers, questions, error_types, intensity + ) + else: + # error_types/intensity only apply to answer-based generation + return await self.rag.generate_hallucination_from_context_batch_async(contexts, questions) diff --git a/lettucedetect/ragfactchecker.py b/lettucedetect/ragfactchecker.py new file mode 100644 index 0000000..c62d6f6 --- /dev/null +++ b/lettucedetect/ragfactchecker.py @@ -0,0 +1,360 @@ +"""Simple, clean RAGFactChecker wrapper for LettuceDetect.""" + +import logging +import os +from typing import Any, Dict, List, Optional + + +class RAGFactChecker: + """Simple wrapper around RAGFactChecker with a clean, unified API. + + This provides all RAGFactChecker functionality through one interface: + - Triplet generation and comparison + - Hallucination detection + - Hallucination generation + - Batch processing + """ + + def __init__( + self, + openai_api_key: Optional[str] = None, + model: str = "gpt-4o", + base_url: Optional[str] = None, + temperature: float = 0.0, + ): + """Initialize RAGFactChecker. + + :param openai_api_key: OpenAI API key. If None, uses OPENAI_API_KEY env var. + :param model: OpenAI model to use (default: "gpt-4o"). Options: "gpt-4o", "gpt-4", "gpt-3.5-turbo", etc. + :param base_url: Optional base URL for API (e.g., "http://localhost:1234/v1" for local servers). + :param temperature: Temperature for model sampling (default: 0.0 for deterministic outputs). + + :return: RAGFactChecker instance + """ + self.openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY") + if not self.openai_api_key: + raise ValueError( + "OpenAI API key required. Set OPENAI_API_KEY env var or pass explicitly."
+ ) + + self.model = model + self.base_url = base_url + self.temperature = temperature + self.logger = logging.getLogger(__name__) + self._setup_components() + + def _setup_components(self): + """Initialize RAGFactChecker components.""" + try: + from rag_fact_checker.data import Config + from rag_fact_checker.model.fact_checker import LLMFactChecker + from rag_fact_checker.model.hallucination_data_generator import ( + AnswerBasedHallucinationDataGenerator, + LLMHallucinationDataGenerator, + ) + from rag_fact_checker.model.triplet_generator import LLMTripletGenerator + + # Create config with defaults and API key + self.config = Config() + self.config.model.llm.api_key = self.openai_api_key + self.config.model.llm.generator_model = self.model + self.config.model.llm.temperature = self.temperature + if self.base_url: + self.config.model.llm.base_url = self.base_url + + # Initialize components + self.triplet_generator = LLMTripletGenerator(self.config, self.logger) + self.fact_checker = LLMFactChecker(self.config, self.logger) + self.reference_generator = LLMHallucinationDataGenerator(self.config, self.logger) + self.answer_generator = AnswerBasedHallucinationDataGenerator(self.config, self.logger) + + except ImportError as e: + raise ImportError( + "RAGFactChecker not available. Install with: pip install rag-fact-checker" + ) from e + + # ============ TRIPLET OPERATIONS ============ + + def generate_triplets(self, text: str) -> List[List[str]]: + """Generate triplets from text. + + :param text: Input text + + :return: List of triplets [subject, predicate, object] + List of triplets [subject, predicate, object] + + """ + result = self.triplet_generator.forward(text) + return result.triplets + + def compare_triplets( + self, answer_triplets: List[List[str]], reference_triplets: List[List[str]] + ) -> Dict[str, Any]: + """Compare answer triplets against reference triplets. + + :param answer_triplets: Triplets from answer to check + :param reference_triplets: Reference triplets to compare against + + :return: Dict with fact check results + + """ + result = self.fact_checker.forward( + answer_triplets=answer_triplets, reference_triplets=[reference_triplets] + ) + return {"fact_check_results": result.fact_check_prediction_binary, "raw_output": result} + + def analyze_text_pair(self, answer_text: str, reference_text: str) -> Dict[str, Any]: + """Generate and compare triplets for two texts. + + :param answer_text: Text to analyze + :param reference_text: Reference text to compare against + + :return: Complete analysis with triplets and comparison results + + """ + answer_triplets = self.generate_triplets(answer_text) + reference_triplets = self.generate_triplets(reference_text) + comparison = self.compare_triplets(answer_triplets, reference_triplets) + + return { + "answer_triplets": answer_triplets, + "reference_triplets": reference_triplets, + "comparison": comparison, + } + + # ============ HALLUCINATION DETECTION ============ + + def detect_hallucinations( + self, context: List[str], answer: str, question: Optional[str] = None + ) -> Dict[str, Any]: + """Detect hallucinations in answer given context. 
+ + :param context: List of context documents + :param answer: Answer to check + :param question: Optional question for context + + :return: Detection results with triplets and fact checking + + """ + # Generate triplets + answer_triplets = self.generate_triplets(answer) + context_text = "\n".join(context) + context_triplets = self.generate_triplets(context_text) + + # Fact check + comparison = self.compare_triplets(answer_triplets, context_triplets) + + return { + "answer_triplets": answer_triplets, + "context_triplets": context_triplets, + "fact_check_results": comparison["fact_check_results"], + "hallucinated_triplets": [ + answer_triplets[i] + for i, fact_is_true in comparison["fact_check_results"].items() + if not fact_is_true and i < len(answer_triplets) + ], + } + + # ============ HALLUCINATION GENERATION ============ + + def generate_hallucination_from_context( + self, context: List[str], question: str + ) -> Dict[str, Any]: + """Generate hallucinated content from context and question. + + :param context: List of context documents + :param question: Question to answer + + :return: Generated hallucinated and non-hallucinated answers + + """ + context_text = "\n".join(context) + result = self.reference_generator.generate_hlcntn_data(context_text, question) + + return { + "hallucinated_answer": result.generated_hlcntn_answer, + "non_hallucinated_answer": result.generated_non_hlcntn_answer, + "hallucinated_parts": result.hlcntn_part, + } + + def generate_hallucination_from_answer( + self, + correct_answer: str, + question: str, + error_types: Optional[List[str]] = None, + intensity: float = 0.3, + ) -> Dict[str, Any]: + """Generate hallucinated version of a correct answer. + + :param correct_answer: The correct answer to modify + :param question: Original question for context + :param error_types: Types of errors to inject (factual, temporal, numerical, etc.) 
+ :param intensity: Error intensity 0.1-1.0 + + :return: Generated hallucinated version with error details + + """ + # Convert string error types to ErrorType enums if provided + error_type_enums = None + if error_types: + from rag_fact_checker.model.hallucination_data_generator.answer_based_hallucination_data_generator import ( + ErrorType, + ) + + error_type_enums = [] + for error_type in error_types: + if hasattr(ErrorType, error_type.upper()): + error_type_enums.append(getattr(ErrorType, error_type.upper())) + + result = self.answer_generator.generate_answer_based_hallucination( + correct_answer=correct_answer, + question=question, + error_types=error_type_enums, + intensity=intensity, + ) + + return { + "original_answer": result.generated_non_hlcntn_answer, + "hallucinated_answer": result.generated_hlcntn_answer, + "hallucinated_parts": result.hlcntn_part, + } + + # ============ BATCH OPERATIONS ============ + + async def generate_hallucination_from_answer_batch_async( + self, + correct_answers: List[str], + questions: List[str], + error_types: Optional[List[List[str]]] = None, + intensities: Optional[List[float]] = None, + ) -> List[Dict[str, Any]]: + """Generate hallucinated version of multiple correct answers.""" + error_type_enums_list = None + if error_types: + from rag_fact_checker.model.hallucination_data_generator.answer_based_hallucination_data_generator import ( + ErrorType, + ) + + error_type_enums_list = [] + for error_type in error_types: + error_type_enums = [] + for error_type in error_type: + if hasattr(ErrorType, error_type.upper()): + error_type_enums.append(getattr(ErrorType, error_type.upper())) + error_type_enums_list.append(error_type_enums) + + result = await self.answer_generator.generate_answer_based_hallucination_batch_async( + correct_answers=correct_answers, + questions=questions, + error_types_list=error_type_enums_list, + intensities=intensities, + ) + return result + + async def generate_hallucination_from_context_batch_async( + self, + contexts: List[List[str]], + questions: List[str], + ) -> List[Dict[str, Any]]: + """Generate hallucinated version of multiple correct answers.""" + result = await self.reference_generator.generate_hlcntn_data_batch_async( + contexts, questions + ) + return result + + def generate_hallucination_from_answer_batch( + self, + correct_answers: List[str], + questions: List[str], + error_types: Optional[List[List[str]]] = None, + intensities: Optional[List[float]] = None, + ) -> List[Dict[str, Any]]: + """Generate hallucinated version of multiple correct answers. + + :param correct_answers: List of correct answers to modify + :param questions: List of original questions for context + :param error_types: List of lists of types of errors to inject (factual, temporal, numerical, etc.) 
+ :param intensities: List of error intensities 0.1-1.0 + + :return: List of generated hallucinated versions with error details + + """ + error_type_enums_list = None + if error_types: + from rag_fact_checker.model.hallucination_data_generator.answer_based_hallucination_data_generator import ( + ErrorType, + ) + + error_type_enums_list = [] + for error_type in error_types: + error_type_enums = [] + for error_type in error_type: + if hasattr(ErrorType, error_type.upper()): + error_type_enums.append(getattr(ErrorType, error_type.upper())) + error_type_enums_list.append(error_type_enums) + + result = self.answer_generator.generate_answer_based_hallucination_batch( + correct_answers=correct_answers, + questions=questions, + error_types_list=error_type_enums_list, + intensities=intensities, + ) + return result + + def generate_hallucination_from_context_batch( + self, + contexts: List[List[str]], + questions: List[str], + ) -> List[Dict[str, Any]]: + """Generate hallucinated version of multiple correct answers. + + :param contexts: List of context document lists + :param questions: List of original questions for context + + :return: List of generated hallucinated versions with error details + + """ + result = self.reference_generator.generate_hlcntn_data_batch(contexts, questions) + return result + + def generate_triplets_batch(self, texts: List[str]) -> List[List[List[str]]]: + """Generate triplets for multiple texts. + + :param texts: List of input texts + + :return: List of triplet lists for each text + + """ + batch_result = self.triplet_generator.forward_batch(texts) + + # Create results list with empty lists for failed items + results = [[] for _ in texts] # Initialize with empty lists + + # Fill in successful results + result_index = 0 + for i in range(len(texts)): + if i not in batch_result.failed_indices: + if result_index < len(batch_result.results): + results[i] = batch_result.results[result_index].triplets + result_index += 1 + + return results + + def detect_hallucinations_batch( + self, contexts: List[List[str]], answers: List[str], questions: Optional[List[str]] = None + ) -> List[Dict[str, Any]]: + """Detect hallucinations for multiple context-answer pairs. + + :param contexts: List of context document lists + :param answers: List of answers to check + :param questions: Optional list of questions + + :return: List of detection results + + """ + results = [] + for i, (context, answer) in enumerate(zip(contexts, answers)): + question = questions[i] if questions and i < len(questions) else None + result = self.detect_hallucinations(context, answer, question) + results.append(result) + return results diff --git a/pyproject.toml b/pyproject.toml index 8995fb5..0d62884 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "lettucedetect" -version = "0.1.7" +version = "0.1.8" description = "Lettucedetect is a framework for detecting hallucinations in RAG applications." 
readme = {file = "README.md", content-type = "text/markdown"} requires-python = ">=3.10" @@ -23,7 +23,8 @@ dependencies = [ "tqdm>=4.65.0", "scikit-learn>=1.6.1", "numpy>=2.2.2", - "openai==1.66.3", + "openai>=1.66.3", + "rag-fact-checker", ] [project.urls] diff --git a/scripts/generate_synthetic_data.py b/scripts/generate_synthetic_data.py new file mode 100755 index 0000000..9fa7cb2 --- /dev/null +++ b/scripts/generate_synthetic_data.py @@ -0,0 +1,526 @@ +#!/usr/bin/env python3 +"""Generate synthetic hallucination data using RAGFactChecker.""" + +import argparse +import asyncio +import json +import logging +import os +import random +import sys +import time +from typing import Any, Dict, List, Optional + +from lettucedetect import HallucinationGenerator +from lettucedetect.detectors.prompt_utils import PromptUtils + +# Setup rich logging +try: + from rich.console import Console + from rich.logging import RichHandler + from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TaskID, + TextColumn, + TimeElapsedColumn, + ) + + RICH_AVAILABLE = True + console = Console() +except ImportError: + RICH_AVAILABLE = False + console = None + + +def setup_logging(verbose: bool = False) -> logging.Logger: + """Setup logging with rich output if available.""" + level = logging.DEBUG if verbose else logging.INFO + + if RICH_AVAILABLE: + logging.basicConfig( + level=level, + format="%(message)s", + datefmt="[%X]", + handlers=[RichHandler(rich_tracebacks=True, show_path=False)], + ) + else: + logging.basicConfig( + level=level, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) + + return logging.getLogger(__name__) + + +def load_rag_mini_bioasq(split: str = "train", filter_min_words: int = 10) -> List[Dict[str, Any]]: + """Load rag-mini-bioasq dataset and prepare for generation.""" + try: + from datasets import load_dataset + except ImportError: + raise ImportError("datasets package required. 
Install with: pip install datasets") + + logger = logging.getLogger(__name__) + logger.info(f"Loading rag-mini-bioasq dataset ({split} split)...") + + # Load dataset + qa_dataset = load_dataset("enelpol/rag-mini-bioasq", "question-answer-passages") + corpus_dataset = load_dataset("enelpol/rag-mini-bioasq", "text-corpus") + + # Create corpus lookup + corpus_lookup = {item["id"]: item["passage"] for item in corpus_dataset["test"]} + + # Process data + processed_data = [] + for item in qa_dataset[split]: + passage_ids = item["relevant_passage_ids"] + context_passages = [corpus_lookup.get(pid, None) for pid in passage_ids] + context_passages = [p for p in context_passages if p is not None] + + # Filter by answer length + if len(item["answer"].split()) >= filter_min_words: + processed_data.append( + { + "question": item["question"], + "answer": item["answer"], + "context": context_passages, + } + ) + + logger.info( + f"Loaded {len(processed_data)} samples after filtering (min {filter_min_words} words)" + ) + return processed_data + + +def load_custom_dataset(file_path: str) -> List[Dict[str, Any]]: + """Load custom dataset from JSON file.""" + logger = logging.getLogger(__name__) + logger.info(f"Loading custom dataset from {file_path}...") + + with open(file_path) as f: + data = json.load(f) + + # Validate format + required_fields = ["question", "context"] + for i, item in enumerate(data): + for field in required_fields: + if field not in item: + raise ValueError(f"Missing required field '{field}' in item {i}") + + logger.info(f"Loaded {len(data)} samples from custom dataset") + return data + + +async def generate_batch_async( + generator: HallucinationGenerator, + samples: List[Dict[str, Any]], + method: str = "answer_based", + error_types: Optional[List[str]] = None, + intensity: float = 0.3, +) -> List[Dict[str, Any]]: + """Generate hallucinated data for a batch of samples.""" + logger = logging.getLogger(__name__) + + if method == "answer_based": + # Use existing answers + contexts = [sample["context"] for sample in samples] + questions = [sample["question"] for sample in samples] + answers = [sample["answer"] for sample in samples] + + result = await generator.generate_batch_async( + contexts=contexts, + questions=questions, + answers=answers, + error_types=error_types, + intensity=intensity, + ) + else: + # Context-based generation + contexts = [sample["context"] for sample in samples] + questions = [sample["question"] for sample in samples] + + result = await generator.generate_batch_async( + contexts=contexts, questions=questions, error_types=error_types, intensity=intensity + ) + + return result.results if hasattr(result, "results") else result + + +def convert_to_ragtruth_format( + samples: List[Dict[str, Any]], + results: List[Any], + language: str = "en", + dataset_name: str = "synthetic", +) -> List[Dict[str, Any]]: + """Convert generation results to RAGTruth format.""" + ragtruth_data = [] + + for i, (sample, result) in enumerate(zip(samples, results)): + # Format context using prompt utils + formatted_prompt = PromptUtils.format_context( + sample["context"], sample["question"], lang=language + ) + + # Original answer (non-hallucinated) + if hasattr(result, "generated_non_hlcntn_answer"): + real_answer = result.generated_non_hlcntn_answer + else: + real_answer = sample.get("answer", "") + + ragtruth_data.append( + { + "prompt": formatted_prompt, + "answer": real_answer, + "labels": [], + "split": "train", + "task_type": "qa", + "dataset": dataset_name, + "language": language, + } 
+ ) + + # Hallucinated answer with labels + if hasattr(result, "generated_hlcntn_answer"): + hallucinated_answer = result.generated_hlcntn_answer + hallucinated_labels = [] + + # Create span labels from hallucinated parts + if hasattr(result, "hlcntn_part") and result.hlcntn_part: + for part in result.hlcntn_part: + if isinstance(part, str) and part in hallucinated_answer: + start = hallucinated_answer.find(part) + if start != -1: + hallucinated_labels.append( + {"start": start, "end": start + len(part), "label": "hallucinated"} + ) + + ragtruth_data.append( + { + "prompt": formatted_prompt, + "answer": hallucinated_answer, + "labels": hallucinated_labels, + "split": "train", + "task_type": "qa", + "dataset": dataset_name, + "language": language, + } + ) + + return ragtruth_data + + +async def generate_synthetic_data( + samples: List[Dict[str, Any]], + num_samples: int, + model: str = "gpt-4o", + base_url: Optional[str] = None, + temperature: float = 0.0, + method: str = "answer_based", + error_types: Optional[List[str]] = None, + intensity: float = 0.3, + batch_size: int = 10, + output_format: str = "json", + language: str = "en", + dataset_name: str = "synthetic", +) -> List[Dict[str, Any]]: + """Generate synthetic hallucination data.""" + logger = logging.getLogger(__name__) + + # Initialize generator + generator = HallucinationGenerator( + method="rag_fact_checker", model=model, base_url=base_url, temperature=temperature + ) + + # Sample data if needed + if num_samples < len(samples): + samples = random.sample(samples, num_samples) + logger.info(f"Randomly sampled {num_samples} examples from dataset") + else: + samples = samples[:num_samples] + + # Process in batches + all_results = [] + + if RICH_AVAILABLE: + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + TimeElapsedColumn(), + console=console, + ) as progress: + task = progress.add_task(f"Generating hallucinations ({method})", total=len(samples)) + + for i in range(0, len(samples), batch_size): + batch = samples[i : i + batch_size] + + try: + batch_results = await generate_batch_async( + generator, batch, method, error_types, intensity + ) + all_results.extend(batch_results) + + progress.update(task, advance=len(batch)) + logger.debug( + f"Completed batch {i // batch_size + 1}/{(len(samples) + batch_size - 1) // batch_size}" + ) + + except Exception as e: + logger.error(f"Error processing batch {i // batch_size + 1}: {e}") + continue + else: + # Fallback without rich + for i in range(0, len(samples), batch_size): + batch = samples[i : i + batch_size] + logger.info( + f"Processing batch {i // batch_size + 1}/{(len(samples) + batch_size - 1) // batch_size}" + ) + + try: + batch_results = await generate_batch_async( + generator, batch, method, error_types, intensity + ) + all_results.extend(batch_results) + + except Exception as e: + logger.error(f"Error processing batch {i // batch_size + 1}: {e}") + continue + + logger.info(f"Generated {len(all_results)} hallucination samples") + + # Convert to requested format + if output_format == "ragtruth": + return convert_to_ragtruth_format(samples, all_results, language, dataset_name) + else: + # Standard JSON format + formatted_results = [] + for sample, result in zip(samples, all_results): + formatted_result = { + "question": sample["question"], + "context": sample["context"], + "method": method, + "model": model, + "temperature": temperature, + } + + if hasattr(result, 
"generated_non_hlcntn_answer"): + formatted_result["original_answer"] = result.generated_non_hlcntn_answer + if hasattr(result, "generated_hlcntn_answer"): + formatted_result["hallucinated_answer"] = result.generated_hlcntn_answer + if hasattr(result, "hlcntn_part"): + formatted_result["hallucinated_parts"] = result.hlcntn_part + + formatted_results.append(formatted_result) + + return formatted_results + + +def print_statistics(results: List[Dict[str, Any]], output_format: str): + """Print generation statistics.""" + logger = logging.getLogger(__name__) + + if not results: + logger.warning("No results to analyze") + return + + total_samples = len(results) + + if output_format == "ragtruth": + # Count hallucinated vs non-hallucinated samples + hallucinated_count = sum(1 for r in results if r.get("labels")) + non_hallucinated_count = total_samples - hallucinated_count + + logger.info("📊 Generation Statistics:") + logger.info(f" Total samples: {total_samples}") + logger.info(f" Hallucinated samples: {hallucinated_count}") + logger.info(f" Non-hallucinated samples: {non_hallucinated_count}") + + if hallucinated_count > 0: + # Average number of hallucination spans + total_spans = sum(len(r.get("labels", [])) for r in results if r.get("labels")) + avg_spans = total_spans / hallucinated_count + logger.info(f" Average spans per hallucinated sample: {avg_spans:.1f}") + else: + logger.info("📊 Generation Statistics:") + logger.info(f" Total samples: {total_samples}") + + # Calculate average lengths + if results and "hallucinated_answer" in results[0]: + avg_hal_len = ( + sum(len(r["hallucinated_answer"].split()) for r in results) / total_samples + ) + logger.info(f" Average hallucinated answer length: {avg_hal_len:.1f} words") + + +async def main(): + """Main function.""" + parser = argparse.ArgumentParser( + description="Generate synthetic hallucination data using RAGFactChecker", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Generate from rag-mini-bioasq dataset + python scripts/generate_synthetic_data.py \\ + --dataset rag-mini-bioasq \\ + --split train \\ + --num-samples 100 \\ + --model gpt-4o-mini \\ + --output data/synthetic_train.json + + # Generate with custom parameters + python scripts/generate_synthetic_data.py \\ + --dataset rag-mini-bioasq \\ + --split test \\ + --num-samples 50 \\ + --model gpt-4o \\ + --temperature 0.7 \\ + --error-types factual temporal numerical \\ + --intensity 0.5 \\ + --output-format ragtruth \\ + --output data/synthetic_test_ragtruth.json + """, + ) + + # Data source + data_group = parser.add_mutually_exclusive_group(required=True) + data_group.add_argument("--dataset", choices=["rag-mini-bioasq"], help="Use built-in dataset") + data_group.add_argument("--custom-data", type=str, help="Path to custom JSON dataset file") + + # Dataset options + parser.add_argument( + "--split", + choices=["train", "test"], + default="train", + help="Dataset split to use (default: train)", + ) + parser.add_argument( + "--num-samples", type=int, default=100, help="Number of samples to generate (default: 100)" + ) + parser.add_argument( + "--filter-min-words", + type=int, + default=10, + help="Minimum words in answer for filtering (default: 10)", + ) + + # Generation parameters + parser.add_argument("--model", default="gpt-4o", help="OpenAI model to use (default: gpt-4o)") + parser.add_argument( + "--base-url", type=str, help="Base URL for OpenAI-compatible API (for local models)" + ) + parser.add_argument( + "--temperature", type=float, 
default=0.0, help="Temperature for generation (default: 0.0)" + ) + parser.add_argument( + "--method", + choices=["context_based", "answer_based"], + default="answer_based", + help="Generation method (default: answer_based)", + ) + parser.add_argument( + "--error-types", + nargs="+", + choices=["factual", "temporal", "numerical", "logical", "causal"], + default=None, + help="Error types for answer-based generation (default: None)", + ) + parser.add_argument( + "--intensity", type=float, default=0.3, help="Error intensity 0.1-1.0 (default: 0.3)" + ) + parser.add_argument( + "--batch-size", type=int, default=5, help="Batch size for processing (default: 5)" + ) + + # Output options + parser.add_argument("--output", required=True, help="Output file path") + parser.add_argument( + "--output-format", + choices=["json", "ragtruth"], + default="json", + help="Output format (default: json)", + ) + parser.add_argument( + "--language", default="en", help="Language code for RAGTruth format (default: en)" + ) + parser.add_argument( + "--dataset-name", + default="synthetic", + help="Dataset name for RAGTruth format (default: synthetic)", + ) + + # Logging + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging") + + args = parser.parse_args() + + # Setup logging + logger = setup_logging(args.verbose) + + # Check API key + if not os.getenv("OPENAI_API_KEY"): + logger.error("OPENAI_API_KEY environment variable is required") + sys.exit(1) + + # Load data + try: + if args.dataset == "rag-mini-bioasq": + samples = load_rag_mini_bioasq(args.split, args.filter_min_words) + else: + samples = load_custom_dataset(args.custom_data) + + except Exception as e: + logger.error(f"Failed to load dataset: {e}") + sys.exit(1) + + # Validate parameters + if args.num_samples <= 0: + logger.error("Number of samples must be positive") + sys.exit(1) + + if not (0.1 <= args.intensity <= 1.0): + logger.error("Intensity must be between 0.1 and 1.0") + sys.exit(1) + + # Generate data + start_time = time.time() + + try: + results = await generate_synthetic_data( + samples=samples, + num_samples=args.num_samples, + model=args.model, + base_url=args.base_url, + temperature=args.temperature, + method=args.method, + error_types=args.error_types, + intensity=args.intensity, + batch_size=args.batch_size, + output_format=args.output_format, + language=args.language, + dataset_name=args.dataset_name, + ) + + # Save results + os.makedirs(os.path.dirname(args.output), exist_ok=True) + with open(args.output, "w") as f: + json.dump(results, f, indent=2) + + elapsed_time = time.time() - start_time + + # Print statistics + print_statistics(results, args.output_format) + logger.info(f"Generated {len(results)} samples in {elapsed_time:.1f}s") + logger.info(f"Results saved to {args.output}") + + except KeyboardInterrupt: + logger.info("Generation interrupted by user") + sys.exit(1) + except Exception as e: + logger.error(f"Generation failed: {e}") + sys.exit(1) + + +if __name__ == "__main__": + asyncio.run(main())
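
Usage note (illustrative, not part of the diff above): a minimal sketch of how the `HallucinationGenerator` and `RAGFactChecker` APIs introduced in this PR might be used together. It assumes `lettucedetect` 0.1.8 with `rag-fact-checker` installed and `OPENAI_API_KEY` set in the environment; the `gpt-4o-mini` model choice, the `"numerical"` error type, and the ibuprofen sample (mirroring the TinyLettuce notebook) are example inputs, not requirements.

```python
# Minimal sketch: generate one hallucinated training pair, then verify it with triplets.
# Assumes: pip install lettucedetect rag-fact-checker, and OPENAI_API_KEY set in the environment.
from lettucedetect import HallucinationGenerator
from lettucedetect.ragfactchecker import RAGFactChecker

context = [
    "Ibuprofen is an NSAID that reduces inflammation and pain. The typical adult dose is "
    "400-600mg every 6-8 hours, not exceeding 2400mg daily."
]
question = "What is the maximum daily dose of ibuprofen?"
answer = "The maximum daily dose of ibuprofen for adults is 2400mg."

# Answer-based generation: inject numerical errors into a known-correct answer.
generator = HallucinationGenerator(model="gpt-4o-mini")
sample = generator.generate(
    context=context,
    question=question,
    answer=answer,
    error_types=["numerical"],
    intensity=0.3,
)
print(sample["hallucinated_answer"])
print(sample["hallucinated_parts"])

# Triplet-based check of the generated answer against the original context.
checker = RAGFactChecker(model="gpt-4o-mini")
report = checker.detect_hallucinations(
    context=context,
    answer=sample["hallucinated_answer"],
    question=question,
)
print(report["hallucinated_triplets"])
```

For larger datasets, `generate_batch_async` or the new `scripts/generate_synthetic_data.py` CLI wraps the same calls with async batching and optional RAGTruth-format output.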