Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/megatron/bridge/data/builders/hf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,26 @@


class ProcessExampleOutput(TypedDict):
    """Legacy output structure for input/output-style processors.

    Process functions may return this format or any other ``dict[str, Any]``
    (e.g. ``{"messages": [...], "tools": [...]}`` for chat datasets).
    """

    # Prompt text for the example.
    input: str
    # Target completion text for the example.
    output: str
    # Reference answers retained alongside the example — presumably for
    # evaluation on held-out splits; confirm against downstream consumers.
    original_answers: list[str]


class ProcessExampleFn(Protocol):
    """Protocol defining the signature for a function that processes a single dataset example.

    The returned dict is written directly as a JSONL line, so the keys must
    match what the downstream dataset class expects (e.g. ``input``/``output``
    for ``GPTSFTDataset``, or ``messages``/``tools`` for ``GPTSFTChatDataset``).
    """

    def __call__(self, example: dict[str, Any], tokenizer: MegatronTokenizer | None = None) -> dict[str, Any]: ...


@dataclass(kw_only=True)
Expand Down Expand Up @@ -197,15 +204,8 @@ def preprocess_and_split_data(

with output_file.open("w", encoding="utf-8") as f:
for example in tqdm(dataset, desc=f"Processing {split_name} split"):
json_line = {}

processed_example = process_example_fn(example, tokenizer)
# Write each example as a JSON line in the output file
json_line["input"] = processed_example["input"]
json_line["output"] = processed_example["output"]
if split_name == "test":
json_line["original_answers"] = processed_example["original_answers"]
f.write(json.dumps(json_line) + "\n")
f.write(json.dumps(processed_example) + "\n")

logger.info(f"{split_name} split saved to {output_file}")

Expand Down
113 changes: 113 additions & 0 deletions tests/unit_tests/data/test_hf_dataset_chat_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json

from datasets import Dataset, DatasetDict

from megatron.bridge.data.builders.hf_dataset import preprocess_and_split_data


# Chat-format fixtures: each example carries a ``messages`` list of
# role/content dicts plus a ``tools`` list (non-empty for the first
# example, empty for the second) so tests cover both cases.
CHAT_EXAMPLES = [
    {
        "messages": [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ],
        "tools": [{"name": "search", "description": "Search the web"}],
    },
    {
        "messages": [
            {"role": "user", "content": "What is 2+2?"},
            {"role": "assistant", "content": "4"},
        ],
        "tools": [],
    },
]


def process_chat_example(example, tokenizer=None):
    """Pull the ``messages`` and ``tools`` fields out of a chat-format example.

    The ``tokenizer`` parameter exists only for signature compatibility with
    the process-function protocol and is ignored.
    """
    messages = example["messages"]
    tools = example["tools"]
    return {"messages": messages, "tools": tools}


def process_input_output_example(example, tokenizer=None):
    """Map the first two chat messages onto legacy ``input``/``output`` keys.

    Used to check backward compatibility: the user turn becomes ``input``
    and the assistant turn becomes ``output``. ``tokenizer`` is unused.
    """
    user_turn, assistant_turn = example["messages"][0], example["messages"][1]
    return {"input": user_turn["content"], "output": assistant_turn["content"]}


class TestPreprocessChatFormat:
    """Tests that preprocess_and_split_data writes arbitrary process_example_fn output to JSONL."""

    def test_chat_format_writes_messages_and_tools(self, tmp_path):
        """Chat-format process functions should write messages/tools keys to JSONL."""
        dataset_dict = DatasetDict({"train": Dataset.from_list(CHAT_EXAMPLES)})

        preprocess_and_split_data(
            dset=dataset_dict,
            dataset_name="test-chat",
            dataset_root=tmp_path,
            tokenizer=None,
            process_example_fn=process_chat_example,
            val_proportion=None,
            do_validation=False,
            do_test=False,
            rewrite=True,
        )

        training_file = tmp_path / "training.jsonl"
        assert training_file.exists()

        # Parse every JSONL record up front, then compare against the fixtures.
        records = [json.loads(line) for line in training_file.read_text().strip().split("\n")]
        assert len(records) == 2

        for i, (record, expected) in enumerate(zip(records, CHAT_EXAMPLES)):
            assert "messages" in record, f"Line {i} missing 'messages' key"
            assert "tools" in record, f"Line {i} missing 'tools' key"
            assert record["messages"] == expected["messages"]
            assert record["tools"] == expected["tools"]
            # Chat-format output must not carry the legacy keys.
            assert "input" not in record
            assert "output" not in record

    def test_input_output_format_still_works(self, tmp_path):
        """Backward compat: input/output process functions should still produce correct JSONL."""
        dataset_dict = DatasetDict({"train": Dataset.from_list(CHAT_EXAMPLES)})

        preprocess_and_split_data(
            dset=dataset_dict,
            dataset_name="test-io",
            dataset_root=tmp_path,
            tokenizer=None,
            process_example_fn=process_input_output_example,
            val_proportion=None,
            do_validation=False,
            do_test=False,
            rewrite=True,
        )

        training_file = tmp_path / "training.jsonl"
        assert training_file.exists()

        records = [json.loads(line) for line in training_file.read_text().strip().split("\n")]
        assert len(records) == 2

        for record in records:
            assert "input" in record
            assert "output" in record
            # Legacy-format output must not carry the chat-format key.
            assert "messages" not in record