diff --git a/lettucedetect/datasets/hallucination_dataset.py b/lettucedetect/datasets/hallucination_dataset.py
index 64e2644..2c20b39 100644
--- a/lettucedetect/datasets/hallucination_dataset.py
+++ b/lettucedetect/datasets/hallucination_dataset.py
@@ -1,10 +1,13 @@
 from dataclasses import dataclass
 from typing import Literal
 
+import nltk
 import torch
 from torch.utils.data import Dataset
 from transformers import AutoTokenizer
 
+nltk.download("punkt_tab")
+
 
 @dataclass
 class HallucinationSample:
@@ -15,9 +18,10 @@ class HallucinationSample:
     task_type: str
     dataset: Literal["ragtruth", "ragbench"]
     language: Literal["en", "de"]
+    answer_sentences: list | None = None
 
     def to_json(self) -> dict:
-        return {
+        json_dict = {
             "prompt": self.prompt,
             "answer": self.answer,
             "labels": self.labels,
@@ -27,11 +31,17 @@ def to_json(self) -> dict:
             "language": self.language,
         }
 
+        if self.answer_sentences is not None:
+            json_dict["answer_sentences"] = self.answer_sentences
+
+        return json_dict
+
     @classmethod
     def from_json(cls, json_dict: dict) -> "HallucinationSample":
         return cls(
             prompt=json_dict["prompt"],
             answer=json_dict["answer"],
+            answer_sentences=json_dict.get("answer_sentences"),
             labels=json_dict["labels"],
             split=json_dict["split"],
             task_type=json_dict["task_type"],
@@ -54,6 +64,21 @@ def from_json(cls, json_dict: list[dict]) -> "HallucinationData":
         )
 
 
+def find_hallucinated_sent(sample):
+    hallu_sent = []
+    for label in sample.labels:
+        hallu_sent.append(sample.answer[label["start"] : label["end"]])
+    return hallu_sent
+
+
+def define_sentence_label(sentences, hallucinated_sentences):
+    labels = [
+        int(any(hallu_sent in sentence for hallu_sent in hallucinated_sentences))
+        for sentence in sentences
+    ]
+    return labels
+
+
 class HallucinationDataset(Dataset):
     """Dataset for Hallucination data."""
 
@@ -61,9 +86,10 @@ def __init__(
         self,
         samples: list[HallucinationSample],
         tokenizer: AutoTokenizer,
+        method: Literal["transformer", "sentencetransformer"] = "transformer",
         max_length: int = 4096,
     ):
         """Initialize the dataset.
 
         :param samples: List of HallucinationSample objects.
         :param tokenizer: Tokenizer to use for encoding the data.
@@ -71,6 +97,7 @@ def __init__( """ self.samples = samples self.tokenizer = tokenizer + self.method = method self.max_length = max_length def __len__(self) -> int: @@ -128,6 +155,122 @@ def prepare_tokenized_input( return encoding, labels, offsets, answer_start_token + @classmethod + def encode_context_and_sentences_with_offset( + cls, + tokenizer: AutoTokenizer, + context: str, + sentences: list, + max_length: int = 4096, + ) -> dict: + max_length = max_length - 2 + + # ------------------------------------------------------------------------- + # 1) Encode the context with special tokens + # ------------------------------------------------------------------------- + encoded_context = tokenizer.encode_plus( + context, + add_special_tokens=True, + return_offsets_mapping=True, + max_length=max_length, + truncation=True, + ) + context_ids = encoded_context["input_ids"] + context_attn_mask = encoded_context["attention_mask"] + context_offsets = encoded_context["offset_mapping"] + + if len(context_ids) > 1 and context_ids[-1] == tokenizer.sep_token_id: + context_ids.pop() + context_attn_mask.pop() + context_offsets.pop() + + input_ids = context_ids[:] + attention_mask = context_attn_mask[:] + offset_mapping = context_offsets[:] + + sentence_boundaries = [] + sentence_offset_mappings = [] + + # ------------------------------------------------------------------------- + # 2) Encode each sentence and check if it fits within max_length + # ------------------------------------------------------------------------- + for sent in sentences: + # First check if adding this sentence would exceed max_length + # Encode the sentence to check its length + encoded_sent = tokenizer.encode_plus( + sent, + add_special_tokens=False, + return_offsets_mapping=True, + max_length=max_length, + truncation=True, + ) + + sent_ids = encoded_sent["input_ids"] + sent_offsets = encoded_sent["offset_mapping"] + + # +1 for [SEP] token + if len(input_ids) + len(sent_ids) + 1 > max_length: + # If this sentence won't fit, stop processing more sentences + break + + # If we get here, we can add the sentence + # Insert [SEP] for boundary + + input_ids.append(tokenizer.sep_token_id) + attention_mask.append(1) + offset_mapping.append((0, 0)) + + sent_start_idx = len(input_ids) + + # Add the sentence tokens + input_ids.extend(sent_ids) + attention_mask.extend([1] * len(sent_ids)) + offset_mapping.extend(sent_offsets) + + sent_end_idx = len(input_ids) - 1 # inclusive end + + # Mark this sentence boundary and store its offsets and label + sentence_boundaries.append((sent_start_idx, sent_end_idx)) + sentence_offset_mappings.append(sent_offsets) + + # Add final [SEP] if there's room + if len(input_ids) < max_length: + input_ids.append(tokenizer.sep_token_id) + attention_mask.append(1) + offset_mapping.append((0, 0)) + + # ------------------------------------------------------------------------- + # 3) Handle truncation by only including complete sentences + # ------------------------------------------------------------------------- + if len(input_ids) > max_length: + # Find the last complete sentence that fits + last_valid_idx = 0 + for i, (start, end) in enumerate(sentence_boundaries): + if end < max_length: + last_valid_idx = i + else: + break + + if last_valid_idx >= 0: + last_token = sentence_boundaries[last_valid_idx][1] + input_ids = input_ids[: last_token + 1] # +1 to include the last [SEP] + attention_mask = attention_mask[: last_token + 1] + offset_mapping = offset_mapping[: last_token + 1] + sentence_boundaries = 
sentence_boundaries[: last_valid_idx + 1] + sentence_offset_mappings = sentence_offset_mappings[: last_valid_idx + 1] + + # Convert to tensors + input_ids = torch.tensor(input_ids, dtype=torch.long) + attention_mask = torch.tensor(attention_mask, dtype=torch.long) + + return ( + input_ids, + attention_mask, + offset_mapping, + sentence_boundaries, + sentence_offset_mappings, + ) + def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: """Get an item from the dataset. @@ -136,39 +279,82 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: """ sample = self.samples[idx] - # Use the shared class method to perform tokenization and initial label setup. - encoding, labels, offsets, answer_start = HallucinationDataset.prepare_tokenized_input( - self.tokenizer, sample.prompt, sample.answer, self.max_length - ) - # Adjust the token labels based on the annotated hallucination spans. - # Compute the character offset of the first answer token. - - answer_char_offset = offsets[answer_start][0] if answer_start < len(offsets) else None + # ------------------------------------------------------------------------- + # 1) Token-level Model + # ------------------------------------------------------------------------- - for i in range(answer_start, encoding["input_ids"].shape[1]): - token_start, token_end = offsets[i] - # Adjust token offsets relative to answer text. - token_abs_start = ( - token_start - answer_char_offset if answer_char_offset is not None else token_start - ) - token_abs_end = ( - token_end - answer_char_offset if answer_char_offset is not None else token_end + if self.method == "transformer": + # Use the shared class method to perform tokenization and initial label setup. + encoding, labels, offsets, answer_start = HallucinationDataset.prepare_tokenized_input( + self.tokenizer, sample.prompt, sample.answer, self.max_length ) + # Adjust the token labels based on the annotated hallucination spans. + # Compute the character offset of the first answer token. - # Default label is 0 (supported content). - token_label = 0 - # If token overlaps any annotated hallucination span, mark it as hallucinated (1). - for ann in sample.labels: - if token_abs_end > ann["start"] and token_abs_start < ann["end"]: - token_label = 1 - break + answer_char_offset = offsets[answer_start][0] if answer_start < len(offsets) else None - labels[i] = token_label + for i in range(answer_start, encoding["input_ids"].shape[1]): + token_start, token_end = offsets[i] + # Adjust token offsets relative to answer text. + token_abs_start = ( + token_start - answer_char_offset + if answer_char_offset is not None + else token_start + ) + token_abs_end = ( + token_end - answer_char_offset if answer_char_offset is not None else token_end + ) - labels = torch.tensor(labels, dtype=torch.long) + # Default label is 0 (supported content). + token_label = 0 + # If token overlaps any annotated hallucination span, mark it as hallucinated (1). 
+ for ann in sample.labels: + if token_abs_end > ann["start"] and token_abs_start < ann["end"]: + token_label = 1 + break - return { - "input_ids": encoding["input_ids"].squeeze(0), - "attention_mask": encoding["attention_mask"].squeeze(0), - "labels": labels, - } + labels[i] = token_label + + labels = torch.tensor(labels, dtype=torch.long) + + return { + "input_ids": encoding["input_ids"].squeeze(0), + "attention_mask": encoding["attention_mask"].squeeze(0), + "labels": labels, + } + + # ------------------------------------------------------------------------- + # 2) Sentence-Level Model + # ------------------------------------------------------------------------- + else: + # If the sample is coming from ragbench we will use the response sentences already defined in the dataset; otherwise the sample.answer will be split using nltk library + sentences = sample.answer_sentences + if sentences is None: + sentences = nltk.sent_tokenize(sample.answer) + + ( + input_ids, + attention_mask, + offset_mapping, + sentence_boundaries, + sentence_offset_mappings, + ) = HallucinationDataset.encode_context_and_sentences_with_offset( + self.tokenizer, sample.prompt, sentences, max_length=4096 + ) + + # Add labels for included sentences + hallucinated_sentences = find_hallucinated_sent(sample=sample) + sentence_labels = define_sentence_label( + sentences=sentences[: len(sentence_boundaries)], + hallucinated_sentences=hallucinated_sentences, + ) + sentence_labels = torch.tensor(sentence_labels, dtype=torch.long) + + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + "offset_mapping": offset_mapping, + "sentence_boundaries": sentence_boundaries, + "sentence_offset_mappings": sentence_offset_mappings, + "labels": sentence_labels, + } diff --git a/lettucedetect/detectors/factory.py b/lettucedetect/detectors/factory.py index 430e616..a26b586 100644 --- a/lettucedetect/detectors/factory.py +++ b/lettucedetect/detectors/factory.py @@ -19,6 +19,10 @@ def make_detector(method: str, **kwargs) -> BaseDetector: from lettucedetect.detectors.transformer import TransformerDetector return TransformerDetector(**kwargs) + elif method == "sentencetransformer": + from lettucedetect.detectors.sentence_transformer import SentenceTransformer + + return SentenceTransformer(**kwargs) elif method == "llm": from lettucedetect.detectors.llm import LLMDetector diff --git a/lettucedetect/detectors/sentence_transformer.py b/lettucedetect/detectors/sentence_transformer.py new file mode 100644 index 0000000..90b1a94 --- /dev/null +++ b/lettucedetect/detectors/sentence_transformer.py @@ -0,0 +1,140 @@ +"""SentenceTransformer‑based hallucination detector.""" + +from __future__ import annotations + +from typing import Literal + +import nltk +import torch +from transformers import AutoModelForTokenClassification, AutoTokenizer + +from lettucedetect.datasets.hallucination_dataset import HallucinationDataset +from lettucedetect.detectors.base import BaseDetector +from lettucedetect.detectors.prompt_utils import LANG_TO_PASSAGE, Lang, PromptUtils +from lettucedetect.models.sentence_model import SentenceModel + +__all__ = ["SentenceTransformer"] + + +def to_sentences(answer): + if isinstance(answer, list): + return answer + return nltk.sent_tokenize(answer) + + +class SentenceTransformer(BaseDetector): + """Detect hallucinations with a fine‑tuned sentence classifier.""" + + def __init__( + self, + model_path: str, + max_length: int = 4096, + device=None, + lang: Literal["en", "de", "fr", "es", "it", "pl"] = "en", + 
threshold: float = 0.5, + **kwargs, + ): + """Initialize the SentenceTransformer. + :param model_path: The path to the model. + :param max_length: The maximum length of the input sequence. + :param device: The device to run the model on. + :param lang: The language of the model. + :param threshold: Confidence threshold for considering a span relevant (0.0-1.0) + """ + + self.lang = lang + self.model = SentenceModel.from_pretrained(model_path, **kwargs) + base_model = getattr(self.model.config, "model_name", "answerdotai/ModernBERT-base") + self.tokenizer = AutoTokenizer.from_pretrained(base_model) + self.max_length = max_length + self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.threshold = threshold + + self.model.to(self.device).eval() + + def _predict(self, context: str, answer: str, output_format: str): + """Predict hallucination tokens or spans from the provided context and answer. + + :param context: The context string. + :param answer: The answer string. + """ + + if output_format == "spans": + sentences = to_sentences(answer) + # Use the shared tokenization logic from HallucinationDataset + ( + input_ids, + attention_mask, + offset_mapping, + sentence_boundaries, + sentence_offset_mappings, + ) = HallucinationDataset.encode_context_and_sentences_with_offset( + self.tokenizer, context, sentences, self.max_length + ) + + input_ids = input_ids.unsqueeze(0).to(self.device) + attention_mask = attention_mask.unsqueeze(0).to(self.device) + + # Run model inference + with torch.no_grad(): + outputs = self.model(input_ids, attention_mask, [sentence_boundaries]) + # print(outputs) + # Extract hallucinated sentences + hallucinated_sentences = [] + # print(outputs) + if len(outputs) > 0 and outputs[0] is not None and len(outputs[0]) > 0: + sentence_preds = torch.nn.functional.softmax(outputs[0], dim=1) + for i, pred in enumerate(sentence_preds): + if i < len(sentences) and pred[1] > self.threshold: + hallucinated_sentences.append(sentences[i]) + + return hallucinated_sentences + else: + raise ValueError( + "Invalid output_format. This model can only predict hallucination sentences. Use spans." + ) + + def predict_prompt(self, prompt: str, answer: str, output_format: str = "spans") -> list: + """Predict hallucination tokens or spans from the provided prompt and answer. + + :param prompt: The prompt string. + :param answer: The answer string. + :param output_format: "spans" for sentences, + """ + return self._predict(prompt, answer, output_format) + + def predict( + self, + context: list[str], + answer: str, + question: str | None = None, + output_format: str = "spans", + ) -> list: + """Predict hallucination tokens or spans from the provided context, answer, and question. + This is a useful interface when we don't want to predict a specific prompt, but rather we have a list of contexts, answers, and questions. Useful to interface with RAG systems. + + :param context: A list of context strings. + :param answer: The answer string. + :param question: The question string. + :param output_format: "spans" to return sentences. + """ + formatted_prompt = PromptUtils.format_context(context, question, self.lang) + return self._predict(formatted_prompt, answer, output_format) + + def predict_prompt(self, prompt, answer, output_format="tokens") -> list: + """Predict hallucination sentences from the provided prompt and answer. + + :param prompt: The prompt string. + :param answer: The answer string. 
+ :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans. + """ + return self._predict(prompt, answer, output_format) + + def predict_prompt_batch(self, prompts, answers, output_format="tokens") -> list: + """Predict hallucination sentences from the provided prompts and answers. + + :param prompts: List of prompt strings. + :param answers: List of answer strings. + :param output_format: "tokens" to return token-level predictions, or "spans" to return grouped spans. + """ + return [self._predict(p, a, output_format) for p, a in zip(prompts, answers)] diff --git a/lettucedetect/models/evaluator.py b/lettucedetect/models/evaluator.py index 52e9c2c..cbbc425 100644 --- a/lettucedetect/models/evaluator.py +++ b/lettucedetect/models/evaluator.py @@ -1,5 +1,10 @@ +import json +import logging + import torch +import torch.nn as nn from sklearn.metrics import ( + accuracy_score, auc, classification_report, precision_recall_fscore_support, @@ -12,6 +17,9 @@ from lettucedetect.datasets.hallucination_dataset import HallucinationSample from lettucedetect.models.inference import HallucinationDetector +# Set up logger +logger = logging.getLogger(__name__) + def evaluate_model( model: Module, @@ -297,7 +305,10 @@ def evaluate_detector_example_level_batch( for i in tqdm(range(0, len(samples), batch_size), desc="Evaluating", leave=False): batch = samples[i : i + batch_size] prompts = [sample.prompt for sample in batch] - answers = [sample.answer for sample in batch] + answers = [ + sample.answer_sentences if sample.answer_sentences else sample.answer + for sample in batch + ] predicted_spans = detector.predict_prompt_batch(prompts, answers, output_format="spans") for sample, pred_spans in zip(batch, predicted_spans): @@ -375,7 +386,7 @@ def evaluate_detector_example_level( for sample in tqdm(samples, desc="Evaluating", leave=False): prompt = sample.prompt - answer = sample.answer + answer = sample.answer_sentences if sample.answer_sentences else sample.answer gold_spans = sample.labels predicted_spans = detector.predict_prompt(prompt, answer, output_format="spans") true_example_label = 1 if gold_spans else 0 @@ -420,3 +431,162 @@ def evaluate_detector_example_level( results["classification_report"] = report return results + + +def evaluate_sentence_model( + model: nn.Module, test_loader: DataLoader, device: torch.device, criterion, verbose: bool = True +) -> dict[str, dict[str, float]]: + """Evaluate the model on the test dataset""" + model.eval() + total_loss = 0.0 + step_count = 0 + + all_preds = [] + all_labels = [] + try: + with torch.no_grad(): + progress_bar = tqdm(test_loader, desc="Evaluating", leave=False) + for batch in progress_bar: + try: + input_ids = batch["input_ids"].to(device) + attention_mask = batch["attention_mask"].to(device) + sentence_boundaries = batch["sentence_boundaries"] + labels_list = batch["labels"] + + logits_list = model(input_ids, attention_mask, sentence_boundaries) + + batch_loss = 0.0 + doc_count = 0 + + for i, logits in enumerate(logits_list): + # skip if we have a mismatch in lists + if i >= len(labels_list): + logger.warning( + f"Mismatch between logits and labels list lengths: {len(logits_list)} vs {len(labels_list)}" + ) + continue + + labels_i = labels_list[i].to(device) + if logits.size(0) == 0: + continue + + # Make sure sizes match + effective_size = min(logits.size(0), labels_i.size(0)) + if logits.size(0) != labels_i.size(0): + logger.warning( + f"Mismatch between logits and labels sizes: {logits.size(0)} vs 
{labels_i.size(0)}"
+                            )
+                            logits = logits[:effective_size]
+                            labels_i = labels_i[:effective_size]
+
+                        # Calculate loss
+                        loss_i = criterion(logits, labels_i)
+                        batch_loss += loss_i
+                        doc_count += 1
+
+                        # Get predictions for metrics
+                        preds_i = torch.argmax(logits, dim=1).cpu().numpy()
+                        labels_i_np = labels_i.cpu().numpy()
+
+                        # Extend lists with batch predictions and labels
+                        all_preds.extend(preds_i)
+                        all_labels.extend(labels_i_np)
+
+                    if doc_count > 0:
+                        batch_loss = batch_loss / doc_count
+                        total_loss += batch_loss.item()
+                        step_count += 1
+
+                        # Update progress bar with loss
+                        progress_bar.set_postfix({"loss": f"{batch_loss.item():.4f}"})
+                except Exception as e:
+                    logger.error(f"Error evaluating batch: {e}")
+                    continue
+    except Exception as e:
+        logger.error(f"Error during evaluation: {e}")
+        # If evaluation failed part-way, still try to return partial metrics
+
+    # Calculate metrics
+    results = {}
+
+    if step_count == 0:
+        logger.warning("No evaluation steps completed")
+
+        results["supported"] = {  # Class 0
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1": 0.0,
+        }
+        results["hallucinated"] = {  # Class 1
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1": 0.0,
+        }
+        results["auroc"] = 0.0
+        results["accuracy"] = 0.0
+        return results
+
+    results["loss"] = total_loss / step_count
+
+    try:
+        if len(all_preds) > 0:
+            precision, recall, f1, _ = precision_recall_fscore_support(
+                all_labels, all_preds, average=None, labels=[0, 1], zero_division=0
+            )
+            accuracy = accuracy_score(all_labels, all_preds)
+
+            # Calculate AUROC
+            fpr, tpr, _ = roc_curve(all_labels, all_preds)
+            auroc = auc(fpr, tpr)
+
+            results["supported"] = {  # Class 0
+                "precision": float(precision[0]),
+                "recall": float(recall[0]),
+                "f1": float(f1[0]),
+            }
+            results["hallucinated"] = {  # Class 1
+                "precision": float(precision[1]),
+                "recall": float(recall[1]),
+                "f1": float(f1[1]),
+            }
+            results["auroc"] = auroc
+            results["accuracy"] = accuracy
+
+        else:
+            logger.warning("No predictions collected during evaluation")
+            results["supported"] = {  # Class 0
+                "precision": 0.0,
+                "recall": 0.0,
+                "f1": 0.0,
+            }
+            results["hallucinated"] = {  # Class 1
+                "precision": 0.0,
+                "recall": 0.0,
+                "f1": 0.0,
+            }
+            results["auroc"] = 0.0
+            results["accuracy"] = 0.0
+    except Exception as e:
+        logger.error(f"Error calculating metrics: {e}")
+        results["supported"] = {  # Class 0
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1": 0.0,
+        }
+        results["hallucinated"] = {  # Class 1
+            "precision": 0.0,
+            "recall": 0.0,
+            "f1": 0.0,
+        }
+        results["auroc"] = 0.0
+        results["accuracy"] = 0.0
+
+    if verbose:
+        report = classification_report(
+            all_labels, all_preds, target_names=["Supported", "Hallucinated"], digits=4
+        )
+        print("\nDetailed Classification Report:")
+        print(report)
+        results["classification_report"] = report
+
+    return results
diff --git a/lettucedetect/models/sentence_model.py b/lettucedetect/models/sentence_model.py
new file mode 100644
index 0000000..79457dc
--- /dev/null
+++ b/lettucedetect/models/sentence_model.py
@@ -0,0 +1,143 @@
+import logging
+from typing import Any
+
+import torch
+import torch.nn as nn
+from transformers import AutoConfig, AutoModel, PreTrainedModel
+
+logger = logging.getLogger(__name__)
+
+
+class SentenceModel(PreTrainedModel):
+    """
+    Sentence Level Model for sentence classification.
+ """ + + config_class = AutoConfig + base_model_prefix = "bert" + + def __init__( + self, + config=None, + model_name="answerdotai/ModernBERT-base", + hidden_dim: int = 768, + num_labels=2, + ): + # Create config if not provided + if config is None: + config = AutoConfig.from_pretrained(model_name) + # Add our custom config values + config.model_name = model_name + config.hidden_dim = hidden_dim + config.num_labels = num_labels + + super().__init__(config) + + # Set properties from config + self.model_name = getattr(config, "model_name", model_name) + self.hidden_dim = getattr(config, "hidden_dim", hidden_dim) + self.num_labels = getattr(config, "num_labels", num_labels) + + # Initialize base model + self.bert = AutoModel.from_pretrained(self.model_name) + + # Two classification heads + self.sentence_classifier = nn.Linear(self.hidden_dim, self.num_labels) + # Initialize weights + self.init_weights() + + def forward( + self, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + sentence_boundaries: list[list[tuple[int, int]]], + ) -> list[torch.Tensor]: + """ + Forward pass of the model. + + :param input_ids: Token IDs + :param attention_mask: Attention mask + :param sentence_boundaries: List of lists of tuples (start, end) for sentence boundaries + + :return: List of tensors with sentence classification logits + """ + # Get contextualized representations from BERT + outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) + sequence_output = outputs[0] # (batch_size, seq_len, hidden_size) + + # Extract sentence representations and classify each sentence + batch_size = sequence_output.size(0) + sentence_preds = [] + + for batch_idx in range(batch_size): + # Get the sentence boundaries for this batch item + batch_sentence_boundaries = sentence_boundaries[batch_idx] + + # Collect sentence representations + sentence_reprs = [] + for start, end in batch_sentence_boundaries: + # If the sentence extends beyond the sequence, adjust the end + if end >= sequence_output.size(1): + end = sequence_output.size(1) - 1 + + # Skip empty or invalid sentences + if end < start or start < 0: + continue + + # Get the token embeddings for this sentence + sentence_tokens = sequence_output[batch_idx, start : end + 1] + + # Average the token embeddings to get a sentence embedding + sentence_repr = torch.mean(sentence_tokens, dim=0) + sentence_reprs.append(sentence_repr) + + # If no valid sentences, skip this batch item + if not sentence_reprs: + sentence_preds.append(None) + continue + + # Stack and classify all sentence representations + if sentence_reprs: + stacked_reprs = torch.stack(sentence_reprs) + predictions = self.sentence_classifier(stacked_reprs) + sentence_preds.append(predictions) + else: + sentence_preds.append(None) + + return sentence_preds + + def get_config(self) -> dict[str, Any]: + """Get model configuration as a dictionary. + + :return: Model configuration + """ + return self.config.to_dict() + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path, *model_args, **kwargs + ) -> "SentenceModel": + """Load a model from a pretrained model or path. + + This overrides the from_pretrained method from PreTrainedModel to handle + our specific model architecture. + + :param pretrained_model_name_or_path: Pretrained model name or path + :param model_args: Additional model arguments + :param kwargs: Additional keyword arguments + :return: SentencekModel instance + """ + # Let HuggingFace handle the downloading, caching, etc. 
+ return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) + + def save_pretrained(self, save_directory, **kwargs) -> None: + """Save the model to a directory. + This overrides the save_pretrained method from PreTrainedModel to handle + our specific model architecture. + + :param save_directory: Directory to save the model + :param kwargs: Additional keyword arguments + :return: None + """ + # Let HuggingFace's built-in method handle the saving + super().save_pretrained(save_directory, **kwargs) diff --git a/lettucedetect/models/trainer.py b/lettucedetect/models/trainer.py index 1ec97d7..1cb6bfe 100644 --- a/lettucedetect/models/trainer.py +++ b/lettucedetect/models/trainer.py @@ -1,14 +1,24 @@ +import json +import logging +import os import time from datetime import timedelta +from pathlib import Path import torch +import torch.nn as nn +import torch.optim as optim from torch.nn import Module +from torch.nn.utils.rnn import pad_sequence from torch.optim import Optimizer from torch.utils.data import DataLoader from tqdm.auto import tqdm from transformers import PreTrainedTokenizer -from lettucedetect.models.evaluator import evaluate_model, print_metrics +from lettucedetect.models.evaluator import evaluate_model, evaluate_sentence_model, print_metrics + +# Set up logger +logger = logging.getLogger(__name__) class Trainer: @@ -81,8 +91,10 @@ def train(self) -> float: attention_mask=batch["attention_mask"].to(self.device), labels=batch["labels"].to(self.device), ) + loss = outputs.loss loss.backward() + self.optimizer.step() total_loss += loss.item() @@ -118,3 +130,365 @@ def train(self) -> float: print(f"Best F1 score: {best_f1:.4f}") return best_f1 + + +def qa_collate_fn(batch: list[dict]) -> dict: + """batch is a list of N items (N = batch_size), each item is the dict returned by HallucinationDataset.__getitem__. + We need to pad input_ids and attention_mask to the max length in this batch. 
+ + We'll keep: + - offset_mapping: list of lists + - sentence_boundaries: list of lists + - sentence_offset_mappings: list of lists + - labels: list of 1D tensors of shape [num_sentences] + """ + + if not batch: + logger.warning("Empty batch passed to qa_collate_fn") + return {} + + input_ids_list = [] + attention_mask_list = [] + offset_mappings = [] + sentence_boundaries = [] + sentence_offset_mappings = [] + labels_list = [] + + for item in batch: + try: + required_keys = [ + "input_ids", + "attention_mask", + "offset_mapping", + "sentence_boundaries", + "sentence_offset_mappings", + "labels", + ] + missing_keys = [k for k in required_keys if k not in item] + if missing_keys: + logger.warning(f"Item missing keys: {missing_keys}") + # Create empty tensors for missing keys + if "input_ids" not in item: + item["input_ids"] = torch.tensor([0], dtype=torch.long) + if "attention_mask" not in item: + item["attention_mask"] = torch.tensor([0], dtype=torch.long) + if "offset_mapping" not in item: + item["offset_mapping"] = [] + if "sentence_boundaries" not in item: + item["sentence_boundaries"] = [] + if "sentence_offset_mappings" not in item: + item["sentence_offset_mappings"] = [] + if "labels" not in item: + item["labels"] = torch.tensor([], dtype=torch.long) + + input_ids_list.append(item["input_ids"]) + attention_mask_list.append(item["attention_mask"]) + offset_mappings.append(item["offset_mapping"]) + sentence_boundaries.append(item["sentence_boundaries"]) + sentence_offset_mappings.append(item["sentence_offset_mappings"]) + labels_list.append(item["labels"]) + except Exception as e: + logger.error(f"Error processing item in collate_fn: {e}") + # Skip this item or add placeholder + # Add an empty placeholder to keep batch size consistent + input_ids_list.append(torch.tensor([0], dtype=torch.long)) + attention_mask_list.append(torch.tensor([0], dtype=torch.long)) + offset_mappings.append([]) + sentence_boundaries.append([]) + sentence_offset_mappings.append([]) + labels_list.append(torch.tensor([], dtype=torch.long)) + # If all lists are empty after processing, return empty dict + if not input_ids_list: + logger.warning("All items in batch were invalid") + return {} + + try: + # Pad input_ids and attention_mask + padded_input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=0) + padded_attention_mask = pad_sequence(attention_mask_list, batch_first=True, padding_value=0) + + return { + "input_ids": padded_input_ids, # [batch_size, max_seq_len_in_batch] + "attention_mask": padded_attention_mask, # [batch_size, max_seq_len_in_batch] + "offset_mapping": offset_mappings, # list of length batch_size + "sentence_boundaries": sentence_boundaries, # list of length batch_size + "sentence_offset_mappings": sentence_offset_mappings, + "labels": labels_list, # list of length batch_size (each is a 1D Tensor) + } + except Exception as e: + logger.error(f"Error padding sequences in collate_fn: {e}") + return { + "input_ids": torch.zeros((len(input_ids_list), 1), dtype=torch.long), + "attention_mask": torch.zeros((len(attention_mask_list), 1), dtype=torch.long), + "offset_mapping": [[] for _ in input_ids_list], + "sentence_boundaries": [[] for _ in input_ids_list], + "sentence_offset_mappings": [[] for _ in input_ids_list], + "labels": [torch.tensor([], dtype=torch.long) for _ in input_ids_list], + } + + +class SentenceTrainer: + def __init__( + self, + model: nn.Module, + train_loader: DataLoader, + test_loader: DataLoader, + tokenizer=None, + batch_size: int = 4, + learning_rate: float 
= 2e-5, + epochs: int = 6, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + save_path: str = None, + ) -> None: + """ + Simple trainer class for training a model on a dataset. + + :param model: The model to train + :param train_loader: DataLoader for training data + :param test_loader: DataLoader for test data + :param tokenizer: Tokenizer to save with the model + :param batch_size: The batch size + :param learning_rate: The learning rate + :param epochs: The number of epochs + :param device: The device to use + :param output_dir: Directory to save model checkpoints + """ + self.model = model + self.train_loader = train_loader + self.test_loader = test_loader + self.batch_size = batch_size + self.tokenizer = tokenizer + self.learning_rate = learning_rate + self.epochs = epochs + self.device = device + self.save_path = Path(save_path) if save_path else None + self.best_f1 = 0.0 + self.current_epoch = 0 + + self.criterion = nn.CrossEntropyLoss() + self.optimizer = optim.AdamW(self.model.parameters(), lr=self.learning_rate) + self.model.to(self.device) + + def _train_one_epoch(self) -> float: + """Train the model for one epoch. + + :return: Average loss for the epoch + """ + self.model.train() + total_loss = 0.0 + step_count = 0 + + progress_bar = tqdm(self.train_loader, desc="Training", leave=True) + + try: + for batch in progress_bar: + try: + input_ids = batch["input_ids"].to(self.device) + attention_mask = batch["attention_mask"].to(self.device) + sentence_boundaries = batch["sentence_boundaries"] + labels_list = batch["labels"] + + self.optimizer.zero_grad() + + # Forward pass + logits_list = self.model(input_ids, attention_mask, sentence_boundaries) + + # Compute loss + batch_loss = 0.0 + doc_count = 0 + for i, logits in enumerate(logits_list): + # Make sure we have labels and logits of the same length + if i >= len(labels_list): + logger.warning( + f"Mismatch between logits and labels list lengths: {len(logits_list)} vs {len(labels_list)}" + ) + continue + + labels_i = labels_list[i].to(self.device) # shape: [num_sentences_i] + + if logits.size(0) == 0: + # if no sentences in the document, skip + continue + + # Make sure we have enough labels for all logits + if logits.size(0) > labels_i.size(0): + logger.warning( + f"Mismatch between logits and labels sizes: {logits.size(0)} vs {labels_i.size(0)}" + ) + logits = logits[: labels_i.size(0), :] + + loss_i = self.criterion(logits, labels_i[: logits.size(0)]) + batch_loss += loss_i + doc_count += 1 + + if doc_count > 0: + # average the doc losses in the batch + batch_loss = batch_loss / doc_count + batch_loss.backward() + self.optimizer.step() + + total_loss += batch_loss.item() + step_count += 1 + + progress_bar.set_postfix( + { + "loss": f"{batch_loss.item():.4f}", + "avg_loss": f"{total_loss / step_count:.4f}" + if step_count > 0 + else "N/A", + } + ) + except RuntimeError as e: + if "out of memory" in str(e).lower(): + logger.error(f"CUDA out of memory error: {e}") + torch.cuda.empty_cache() + continue + else: + logger.error(f"Runtime error in training loop: {e}") + raise + + except Exception as e: + logger.error(f"Error processing batch: {e}") + continue + + except Exception as e: + logger.error(f"Error during training: {e}") + if step_count == 0: + # If we haven't successfully completed any steps, re-raise the error + raise + + if step_count == 0: + logger.warning("No steps completed in this epoch") + return 0.0 + + return total_loss / step_count + + def train(self) -> float: + """Train the model for multiple epochs. 
+ + :return: Best F1 score achieved during training + """ + + start_time = time.time() + print(f"\nStarting training on {self.device}") + print(f"Training samples: {len(self.train_loader.dataset)}") + if self.test_loader: + print(f"Validation samples: {len(self.test_loader.dataset)}") + + for epoch in range(self.epochs): + self.current_epoch = epoch + 1 # Update current epoch + epoch_start = time.time() + print(f"\nEpoch {self.current_epoch}/{self.epochs}") + + # Train for one epoch + train_loss = self._train_one_epoch() + + epoch_time = time.time() - epoch_start + print( + f"Epoch {self.current_epoch} completed in {timedelta(seconds=int(epoch_time))}. Average loss: {train_loss:.4f}" + ) + if self.test_loader is not None: + print("\nEvaluating...") + metrics = evaluate_sentence_model( + self.model, self.test_loader, self.device, self.criterion, verbose=True + ) + print("Validation metrics:") + print_metrics(metrics) + + # Save metrics to a JSON file + if self.save_path: + metrics_path = self.save_path / "metrics.json" + metrics_list = [] + if os.path.exists(metrics_path) and os.path.getsize(metrics_path) > 0: + with open(metrics_path, "r") as f: + try: + metrics_list = json.load(f) + if not isinstance(metrics_list, list): + metrics_list = [metrics_list] + except json.JSONDecodeError: + metrics_list = [] + metrics_data = { + "loss": float(metrics["loss"]), + "hallucinated": { + "precision": float(metrics["hallucinated"]["precision"]), + "recall": float(metrics["hallucinated"]["recall"]), + "f1": float(metrics["hallucinated"]["f1"]), + }, + "supported": { + "precision": float(metrics["supported"]["precision"]), + "recall": float(metrics["supported"]["recall"]), + "f1": float(metrics["supported"]["f1"]), + }, + "auroc": float(metrics["auroc"]), + "accuracy": float(metrics["accuracy"]), + "epoch": self.current_epoch, + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + } + metrics_list.append(metrics_data) + + with open(metrics_path, "w") as f: + json.dump(metrics_list, f, indent=2) + + # Save best model based on F1 score + if self.save_path and metrics["hallucinated"]["f1"] > self.best_f1: + self.best_f1 = metrics["hallucinated"]["f1"] + # Save directly to output_dir with standard names + model_path = self.save_path + print(model_path) + self.save_model(model_path) + + # Save best metrics separately + best_metrics_path = self.save_path / "best_metrics_sentence_model.json" + with open(best_metrics_path, "w") as f: + json.dump(metrics_data, f, indent=2) + + print(f"New best model saved with F1: {self.best_f1:.4f}") + else: + print("No validation data provided, skipping evaluation.") + + # If no validation data, save the latest model after each epoch + if self.save_path: + model_path = self.save_path + self.save_model(model_path) + print(f"Model saved after epoch {self.current_epoch}") + + # Final training summary + total_time = time.time() - start_time + hours, remainder = divmod(int(total_time), 3600) + minutes, seconds = divmod(remainder, 60) + print(f"\nTraining completed in {hours:02}:{minutes:02}:{seconds:02}") + print( + f"Best validation F1: {self.best_f1:.4f}" + if self.best_f1 > 0 + else "No validation performed" + ) + + return self.best_f1 + + def save_model(self, save_path) -> None: + """Save the model to the given path with metadata. 
+ + :param save_path: Path to save the model to + :return: None + """ + if isinstance(save_path, str) or isinstance(save_path, Path): + save_dir = Path(save_path) + if save_dir.suffix: + save_dir = save_dir.parent + else: + save_dir = save_path + + # Create metadata + metadata = { + "best_f1": float(self.best_f1), + "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), + "epochs_trained": self.current_epoch, + "learning_rate": self.learning_rate, + "batch_size": self.batch_size, + "device": str(self.device), + } + + # Use the model's save_pretrained method + self.model.save_pretrained(save_dir, tokenizer=self.tokenizer, metadata=metadata) + + return save_dir diff --git a/lettucedetect/preprocess/preprocess_ragbench.py b/lettucedetect/preprocess/preprocess_ragbench.py index 428f68d..1a9cbb7 100644 --- a/lettucedetect/preprocess/preprocess_ragbench.py +++ b/lettucedetect/preprocess/preprocess_ragbench.py @@ -86,7 +86,19 @@ def create_sample(response: dict, dataset_name: str, split: str) -> Hallucinatio ] labels = create_labels(response, hallucinations) - return HallucinationSample(prompt, answer, labels, split, dataset_name, "ragbench", "en") + mapping = { + "customer_support": ["delucionqa", "emanual", "techqa"], + "finance_numerical_reasoning": ["finqa", "tatqa"], + "biomed": ["pubmedqa", "covidqa"], + "legal": ["cuad"], + "general_knowledge": ["hotpotqa", "msmarco", "hagrid", "expertqa"], + } + task_type = next((k for k, v in mapping.items() if dataset_name in v), None) + answer_sentences = [sentence for _, sentence in response["response_sentences"]] + + return HallucinationSample( + prompt, answer, labels, split, task_type, "ragbench", "en", answer_sentences + ) def main(input_dir: str, output_dir: Path): diff --git a/pyproject.toml b/pyproject.toml index 8995fb5..5b5e535 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "scikit-learn>=1.6.1", "numpy>=2.2.2", "openai==1.66.3", + "nltk >= 3.9.1", ] [project.urls] diff --git a/scripts/evaluate_ragas.py b/scripts/evaluate_ragas.py index 3689a22..5e2b822 100644 --- a/scripts/evaluate_ragas.py +++ b/scripts/evaluate_ragas.py @@ -14,6 +14,7 @@ HallucinationData, HallucinationSample, ) +from lettucedetect.models.evaluator import print_metrics def evaluate_ragas( @@ -96,7 +97,7 @@ def evaluate_ragas( }, } results["auroc"] = auroc - + print_metrics(results) if verbose: report = classification_report( example_labels, @@ -138,6 +139,8 @@ def main(ground_truth_file: Path, ragas_baseline: Path, threshold): test_samples, task_type_map = load_data(ground_truth_file) test_samples_ragas, task_type_map_ragas = load_data(ragas_baseline) + print(len(test_samples)) + print(len(test_samples_ragas)) # Evaluate the whole dataset print("\nTask type: whole dataset") evaluate_ragas( @@ -146,15 +149,14 @@ def main(ground_truth_file: Path, ragas_baseline: Path, threshold): threshold=threshold, ) - for task_type, samples in task_type_map.items(): - for task_type_llm, samples_llm in task_type_map_ragas.items(): - print(task_type_llm) - print(f"\nTask type: {task_type_llm}") - evaluate_ragas( - test_samples, - test_samples_ragas, - threshold=threshold, - ) + for task_type_llm, samples_llm in task_type_map_ragas.items(): + print(task_type_llm) + print(f"\nTask type: {task_type_llm}") + evaluate_ragas( + test_samples, + test_samples_ragas, + threshold=threshold, + ) if __name__ == "__main__": diff --git a/scripts/evaluate_sentence_transformer.py b/scripts/evaluate_sentence_transformer.py new file mode 100644 index 0000000..ee2ad05 --- 
/dev/null +++ b/scripts/evaluate_sentence_transformer.py @@ -0,0 +1,88 @@ +import argparse +import json +from pathlib import Path + +import torch +from torch.utils.data import DataLoader +from transformers import ( + AutoModelForTokenClassification, + AutoTokenizer, + DataCollatorForTokenClassification, +) + +from lettucedetect.datasets.hallucination_dataset import ( + HallucinationData, + HallucinationDataset, +) +from lettucedetect.detectors.factory import * +from lettucedetect.models.evaluator import ( + evaluate_detector_example_level_batch, + print_metrics, +) +from lettucedetect.models.sentence_model import SentenceModel + + +def evaluate_task_samples_sentence( + samples, + detector=None, +): + print(f"\nEvaluating model on {len(samples)} samples") + print("\n---- Example-Level Span Evaluation ----") + metrics = evaluate_detector_example_level_batch(detector, samples) + print_metrics(metrics) + return metrics + + +def load_data(data_path): + data_path = Path(data_path) + hallucination_data = HallucinationData.from_json(json.loads(data_path.read_text())) + + # Filter test samples from the data + test_samples = [sample for sample in hallucination_data.samples if sample.split == "test"] + + # group samples by task type + task_type_map = {} + for sample in test_samples: + if sample.task_type not in task_type_map: + task_type_map[sample.task_type] = [] + task_type_map[sample.task_type].append(sample) + return test_samples, task_type_map + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate a hallucination detection model") + parser.add_argument("--model_path", type=str, required=True, help="Path to the saved model") + parser.add_argument( + "--data_path", + type=str, + required=True, + help="Path to the evaluation data (JSON format)", + ) + + args = parser.parse_args() + + test_samples, task_type_map = load_data(args.data_path) + + print(f"\nEvaluating model on test samples: {len(test_samples)}") + detector = make_detector( + method="sentencetransformer", + model_path=args.model_path, + ) + + # Evaluate the whole dataset + print("\nTask type: whole dataset") + evaluate_task_samples_sentence( + test_samples, + detector=detector, + ) + + for task_type, samples in task_type_map.items(): + print(f"\nTask type: {task_type}") + evaluate_task_samples_sentence( + samples, + detector=detector, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/ragas_baseline.py b/scripts/ragas_baseline.py index d7b01ae..d7645e9 100644 --- a/scripts/ragas_baseline.py +++ b/scripts/ragas_baseline.py @@ -47,6 +47,7 @@ def evaluate_metrics(sample, llm): results["faithfulness"] = metric.single_turn_score(sample) except Exception as e: results["faithfulness"] = f"Error: {e}" + print(results) return results @@ -96,20 +97,21 @@ def main( hallucination_data = HallucinationData.from_json(json.loads(input_file.read_text())) samples = [sample for sample in hallucination_data.samples if sample.split == "test"] - + print(len(samples)) hallucination_data_ragas = load_check_existing_data(output_file=output_file) num_processed = len(hallucination_data_ragas.samples) total_samples = len(hallucination_data_ragas.samples) - + print(num_processed) llm = LangchainLLMWrapper( - ChatOpenAI(model="gpt-4o", openai_api_key=get_api_key(), temperature=0) + ChatOpenAI(model="gpt-4o-mini", openai_api_key=get_api_key(), temperature=0) ) - for i, sample in enumerate(samples, start=num_processed): + samples_to_process = samples[num_processed:] + for i, sample in enumerate(samples_to_process, start=num_processed): 
print("--------", i, "--------") sample_ragas = create_sample_baseline(sample, llm) hallucination_data_ragas.samples.append(sample_ragas) - if i % 50 == 0 or i == total_samples - 1: + if (i + 1) % 50 == 0 or (i + 1) == total_samples: (output_file).write_text(json.dumps(hallucination_data_ragas.to_json(), indent=4)) diff --git a/scripts/train.py b/scripts/train.py index 7dca4f2..300454e 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -17,7 +17,8 @@ HallucinationDataset, HallucinationSample, ) -from lettucedetect.models.trainer import Trainer +from lettucedetect.models.sentence_model import SentenceModel +from lettucedetect.models.trainer import SentenceTrainer, Trainer, qa_collate_fn def set_seed(seed: int = 42): @@ -68,6 +69,12 @@ def parse_args(): parser.add_argument( "--learning-rate", type=float, default=1e-5, help="Learning rate for training" ) + parser.add_argument( + "--method", + type=str, + default="trasformer", + help="Do you want to train a token(transformer) or sentence level model (sentencetransformer)?", + ) return parser.parse_args() @@ -116,10 +123,14 @@ def main(): dev_samples.extend(ragbench_dev_samples) tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True) - data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, label_pad_token_id=-100) + data_collator = ( + DataCollatorForTokenClassification(tokenizer=tokenizer, label_pad_token_id=-100) + if args.method == "transformer" + else qa_collate_fn + ) - train_dataset = HallucinationDataset(train_samples, tokenizer) - dev_dataset = HallucinationDataset(dev_samples, tokenizer) + train_dataset = HallucinationDataset(train_samples, tokenizer, method=args.method) + dev_dataset = HallucinationDataset(dev_samples, tokenizer, method=args.method) train_loader = DataLoader( train_dataset, @@ -134,19 +145,34 @@ def main(): collate_fn=data_collator, ) - model = AutoModelForTokenClassification.from_pretrained( - args.model_name, num_labels=2, trust_remote_code=True - ) - - trainer = Trainer( - model=model, - tokenizer=tokenizer, - train_loader=train_loader, - test_loader=dev_loader, - epochs=args.epochs, - learning_rate=args.learning_rate, - save_path=args.output_dir, - ) + if args.method == "transformer": + model = AutoModelForTokenClassification.from_pretrained( + args.model_name, num_labels=2, trust_remote_code=True + ) + + trainer = Trainer( + model=model, + tokenizer=tokenizer, + train_loader=train_loader, + test_loader=dev_loader, + epochs=args.epochs, + learning_rate=args.learning_rate, + save_path=args.output_dir, + ) + else: + model = SentenceModel(model_name=args.model_name) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + trainer = SentenceTrainer( + model=model, + tokenizer=tokenizer, + train_loader=train_loader, + test_loader=dev_loader, + batch_size=args.batch_size, + epochs=args.epochs, + learning_rate=args.learning_rate, + save_path=args.output_dir, + ) trainer.train()