Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flexeval/core/chat_dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from .chatbot_bench import ChatbotBench
from .openai_messages import OpenAIMessagesDataset
from .sacrebleu_dataset import SacreBleuChatDataset
from .template_based import HFChatDataset, JsonlChatDataset, TemplateChatDataset, load_jinja2_template
from .template_based import HFChatDataset, JsonlChatDataset, Preprocessor, TemplateChatDataset, load_jinja2_template
33 changes: 32 additions & 1 deletion flexeval/core/chat_dataset/template_based.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import abc
import json
from ast import literal_eval
from os import PathLike
from pathlib import Path
from typing import Any
from typing import Any, Literal

import datasets
from jinja2 import Template
Expand All @@ -15,6 +16,13 @@
from .base import ChatDataset, ChatInstance


class Preprocessor(abc.ABC):
# An abstract base class for preprocessors
@abc.abstractmethod
def __call__(self, item: dict[str, Any]) -> dict[str, Any]:
pass


def load_jinja2_template(template: str | PathLike[str]) -> Template:
path = Path(template)
try:
Expand Down Expand Up @@ -50,6 +58,9 @@ class TemplateChatDataset(ChatDataset):
The key is a Jinja2 template string to embed the item into a string, and the value is the value to keep.
remove_conditions: A dictionary to indicate the condition to remove certain items.
The key is a Jinja2 template string to embed the item into a string, and the value is the value to remove.
parse_input_utterance: If specified, parse the rendered `input_utterance` string using the given method,
`ast.literal_eval` if "literal_eval" or `json.loads` if "json_loads". If None, do not parse.
preprocessors: A list of Preprocessor instances to preprocess each item.
"""

def __init__(
Expand All @@ -64,6 +75,8 @@ def __init__(
data_range: tuple[int, int] | None = None,
keep_conditions: dict[str, str] | None = None,
remove_conditions: dict[str, str] | None = None,
parse_input_utterance: Literal["literal_eval", "json_loads"] | None = None,
preprocessors: list[Preprocessor] | None = None,
) -> None:
if reference_template and reference_list_template:
msg = "Only one of reference_template and reference_list_template can be set."
Expand Down Expand Up @@ -100,12 +113,22 @@ def __init__(
load_jinja2_template(system_message_template) if system_message_template else None
)

self.parse_input_utterance = parse_input_utterance
self.preprocessors = preprocessors

def __len__(self) -> int:
return len(self.items)

def __getitem__(self, i: int) -> ChatInstance:
item = self.items[i]
if self.preprocessors:
for preprocessor in self.preprocessors:
item = preprocessor(item)
input_utterance = self.input_template.render(**item)
if self.parse_input_utterance == "literal_eval":
input_utterance = literal_eval(input_utterance)
elif self.parse_input_utterance == "json_loads":
input_utterance = json.loads(input_utterance, strict=False)
messages = [{"role": "user", "content": input_utterance}]

if self._system_message_template:
Expand Down Expand Up @@ -166,6 +189,8 @@ def __init__(
data_range: tuple[int, int] | None = None,
keep_conditions: dict[str, str] | None = None,
remove_conditions: dict[str, str] | None = None,
parse_input_utterance: Literal["literal_eval", "json_loads"] | None = None,
preprocessors: list[Preprocessor] | None = None,
) -> None:
dataset_kwargs = dataset_kwargs or {}
dataset = datasets.load_dataset(path, name=subset, split=split, **dataset_kwargs)
Expand All @@ -182,6 +207,8 @@ def __init__(
data_range=data_range,
keep_conditions=keep_conditions,
remove_conditions=remove_conditions,
parse_input_utterance=parse_input_utterance,
preprocessors=preprocessors,
)


Expand All @@ -205,6 +232,8 @@ def __init__(
data_range: tuple[int, int] | None = None,
keep_conditions: dict[str, str] | None = None,
remove_conditions: dict[str, str] | None = None,
parse_input_utterance: Literal["literal_eval", "json_loads"] | None = None,
preprocessors: list[Preprocessor] | None = None,
) -> None:
with open(path) as f:
items = [json.loads(line) for line in f]
Expand All @@ -220,4 +249,6 @@ def __init__(
data_range=data_range,
keep_conditions=keep_conditions,
remove_conditions=remove_conditions,
parse_input_utterance=parse_input_utterance,
preprocessors=preprocessors,
)
59 changes: 59 additions & 0 deletions tests/core/chat_dataset/test_template_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,62 @@ def test_load_jinja2_template(dummy_template_file: Path) -> None:
embed_result = template_from_string.render()
assert isinstance(template_from_string, Template)
assert embed_result == "a" * 1000


@pytest.mark.parametrize(
"parse_input_utterance",
["literal_eval", "json_loads", None],
)
def test_parse_input_utterance(parse_input_utterance: str) -> None:
input_template = """[
{"type": "image_url", "image_url": {"url": "{{ image }}"}},
{"type": "text", "text": "{{ question }}"}
]"""
dataset = TemplateChatDataset(
items=[
{
"question": "Describe the color of this object.",
"answer": "red",
"image": "http://example.com/image1.jpg",
},
],
input_template=input_template,
parse_input_utterance=parse_input_utterance,
)

input_utterance = dataset[0].messages[0]["content"]

if parse_input_utterance is None:
assert isinstance(input_utterance, str)

else:
assert isinstance(input_utterance, list)
assert input_utterance[0]["type"] == "image_url"
assert input_utterance[0]["image_url"]["url"] == "http://example.com/image1.jpg"
assert input_utterance[1]["type"] == "text"
assert input_utterance[1]["text"] == "Describe the color of this object."


def test_preprocessors() -> None:
from flexeval.core.chat_dataset.template_based import Preprocessor

class ToBase64(Preprocessor):
def __call__(self, item: dict) -> dict:
image = item["image"] # noqa: F841 # simulate using the image for conversion
item["image_base64"] = "data:image/jpeg;base64,..."
return item

input_template = "{{ image_base64 }}"
dataset = TemplateChatDataset(
items=[
{
"question": "Describe the color of this object.",
"answer": "red",
"image": "http://example.com/image1.jpg",
},
],
input_template=input_template,
preprocessors=[ToBase64()],
)
input_utterance = dataset[0].messages[0]["content"]
assert input_utterance == "data:image/jpeg;base64,..."