diff --git a/flexeval/core/chat_dataset/__init__.py b/flexeval/core/chat_dataset/__init__.py index be31bd26..9192142b 100644 --- a/flexeval/core/chat_dataset/__init__.py +++ b/flexeval/core/chat_dataset/__init__.py @@ -2,4 +2,4 @@ from .chatbot_bench import ChatbotBench from .openai_messages import OpenAIMessagesDataset from .sacrebleu_dataset import SacreBleuChatDataset -from .template_based import HFChatDataset, JsonlChatDataset, TemplateChatDataset, load_jinja2_template +from .template_based import HFChatDataset, JsonlChatDataset, Preprocessor, TemplateChatDataset, load_jinja2_template diff --git a/flexeval/core/chat_dataset/template_based.py b/flexeval/core/chat_dataset/template_based.py index 55b6a3ee..c6a4de0b 100644 --- a/flexeval/core/chat_dataset/template_based.py +++ b/flexeval/core/chat_dataset/template_based.py @@ -1,10 +1,11 @@ from __future__ import annotations +import abc import json from ast import literal_eval from os import PathLike from pathlib import Path -from typing import Any +from typing import Any, Literal import datasets from jinja2 import Template @@ -15,6 +16,13 @@ from .base import ChatDataset, ChatInstance +class Preprocessor(abc.ABC): + # An abstract base class for preprocessors + @abc.abstractmethod + def __call__(self, item: dict[str, Any]) -> dict[str, Any]: + pass + + def load_jinja2_template(template: str | PathLike[str]) -> Template: path = Path(template) try: @@ -50,6 +58,9 @@ class TemplateChatDataset(ChatDataset): The key is a Jinja2 template string to embed the item into a string, and the value is the value to keep. remove_conditions: A dictionary to indicate the condition to remove certain items. The key is a Jinja2 template string to embed the item into a string, and the value is the value to remove. + parse_input_utterance: If specified, parse the rendered `input_utterance` string using the given method, + `ast.literal_eval` if "literal_eval" or `json.loads` if "json_loads". If None, do not parse. + preprocessors: A list of Preprocessor instances to preprocess each item. """ def __init__( @@ -64,6 +75,8 @@ def __init__( data_range: tuple[int, int] | None = None, keep_conditions: dict[str, str] | None = None, remove_conditions: dict[str, str] | None = None, + parse_input_utterance: Literal["literal_eval", "json_loads"] | None = None, + preprocessors: list[Preprocessor] | None = None, ) -> None: if reference_template and reference_list_template: msg = "Only one of reference_template and reference_list_template can be set." @@ -100,12 +113,22 @@ def __init__( load_jinja2_template(system_message_template) if system_message_template else None ) + self.parse_input_utterance = parse_input_utterance + self.preprocessors = preprocessors + def __len__(self) -> int: return len(self.items) def __getitem__(self, i: int) -> ChatInstance: item = self.items[i] + if self.preprocessors: + for preprocessor in self.preprocessors: + item = preprocessor(item) input_utterance = self.input_template.render(**item) + if self.parse_input_utterance == "literal_eval": + input_utterance = literal_eval(input_utterance) + elif self.parse_input_utterance == "json_loads": + input_utterance = json.loads(input_utterance, strict=False) messages = [{"role": "user", "content": input_utterance}] if self._system_message_template: @@ -166,6 +189,8 @@ def __init__( data_range: tuple[int, int] | None = None, keep_conditions: dict[str, str] | None = None, remove_conditions: dict[str, str] | None = None, + parse_input_utterance: Literal["literal_eval", "json_loads"] | None = None, + preprocessors: list[Preprocessor] | None = None, ) -> None: dataset_kwargs = dataset_kwargs or {} dataset = datasets.load_dataset(path, name=subset, split=split, **dataset_kwargs) @@ -182,6 +207,8 @@ def __init__( data_range=data_range, keep_conditions=keep_conditions, remove_conditions=remove_conditions, + parse_input_utterance=parse_input_utterance, + preprocessors=preprocessors, ) @@ -205,6 +232,8 @@ def __init__( data_range: tuple[int, int] | None = None, keep_conditions: dict[str, str] | None = None, remove_conditions: dict[str, str] | None = None, + parse_input_utterance: Literal["literal_eval", "json_loads"] | None = None, + preprocessors: list[Preprocessor] | None = None, ) -> None: with open(path) as f: items = [json.loads(line) for line in f] @@ -220,4 +249,6 @@ def __init__( data_range=data_range, keep_conditions=keep_conditions, remove_conditions=remove_conditions, + parse_input_utterance=parse_input_utterance, + preprocessors=preprocessors, ) diff --git a/tests/core/chat_dataset/test_template_based.py b/tests/core/chat_dataset/test_template_based.py index 942eea2d..fbcc3f23 100644 --- a/tests/core/chat_dataset/test_template_based.py +++ b/tests/core/chat_dataset/test_template_based.py @@ -226,3 +226,62 @@ def test_load_jinja2_template(dummy_template_file: Path) -> None: embed_result = template_from_string.render() assert isinstance(template_from_string, Template) assert embed_result == "a" * 1000 + + +@pytest.mark.parametrize( + "parse_input_utterance", + ["literal_eval", "json_loads", None], +) +def test_parse_input_utterance(parse_input_utterance: str) -> None: + input_template = """[ + {"type": "image_url", "image_url": {"url": "{{ image }}"}}, + {"type": "text", "text": "{{ question }}"} + ]""" + dataset = TemplateChatDataset( + items=[ + { + "question": "Describe the color of this object.", + "answer": "red", + "image": "http://example.com/image1.jpg", + }, + ], + input_template=input_template, + parse_input_utterance=parse_input_utterance, + ) + + input_utterance = dataset[0].messages[0]["content"] + + if parse_input_utterance is None: + assert isinstance(input_utterance, str) + + else: + assert isinstance(input_utterance, list) + assert input_utterance[0]["type"] == "image_url" + assert input_utterance[0]["image_url"]["url"] == "http://example.com/image1.jpg" + assert input_utterance[1]["type"] == "text" + assert input_utterance[1]["text"] == "Describe the color of this object." + + +def test_preprocessors() -> None: + from flexeval.core.chat_dataset.template_based import Preprocessor + + class ToBase64(Preprocessor): + def __call__(self, item: dict) -> dict: + image = item["image"] # noqa: F841 # simulate using the image for conversion + item["image_base64"] = "data:image/jpeg;base64,..." + return item + + input_template = "{{ image_base64 }}" + dataset = TemplateChatDataset( + items=[ + { + "question": "Describe the color of this object.", + "answer": "red", + "image": "http://example.com/image1.jpg", + }, + ], + input_template=input_template, + preprocessors=[ToBase64()], + ) + input_utterance = dataset[0].messages[0]["content"] + assert input_utterance == "data:image/jpeg;base64,..."