diff --git a/src/megatron/bridge/data/builders/hf_dataset.py b/src/megatron/bridge/data/builders/hf_dataset.py index 6b90bd129a..28d96d34c6 100644 --- a/src/megatron/bridge/data/builders/hf_dataset.py +++ b/src/megatron/bridge/data/builders/hf_dataset.py @@ -36,7 +36,11 @@ class ProcessExampleOutput(TypedDict): - """Expected output structure from a `ProcessExampleFn`.""" + """Legacy output structure for input/output-style processors. + + Process functions may return this format or any other ``dict[str, Any]`` + (e.g. ``{"messages": [...], "tools": [...]}`` for chat datasets). + """ input: str output: str @@ -44,11 +48,14 @@ class ProcessExampleOutput(TypedDict): class ProcessExampleFn(Protocol): - """Protocol defining the signature for a function that processes a single dataset example.""" + """Protocol defining the signature for a function that processes a single dataset example. - def __call__( - self, example: dict[str, Any], tokenizer: Optional[MegatronTokenizer] = None - ) -> ProcessExampleOutput: ... + The returned dict is written directly as a JSONL line, so the keys must + match what the downstream dataset class expects (e.g. ``input``/``output`` + for ``GPTSFTDataset``, or ``messages``/``tools`` for ``GPTSFTChatDataset``). + """ + + def __call__(self, example: dict[str, Any], tokenizer: MegatronTokenizer | None = None) -> dict[str, Any]: ... 
@dataclass(kw_only=True) @@ -197,15 +204,9 @@ def preprocess_and_split_data( with output_file.open("w", encoding="utf-8") as f: for example in tqdm(dataset, desc=f"Processing {split_name} split"): - json_line = {} processed_example = process_example_fn(example, tokenizer) - # Write each example as a JSON line in the output file - json_line["input"] = processed_example["input"] - json_line["output"] = processed_example["output"] - if split_name == "test": - json_line["original_answers"] = processed_example["original_answers"] - f.write(json.dumps(json_line) + "\n") + f.write(json.dumps(processed_example) + "\n") logger.info(f"{split_name} split saved to {output_file}") diff --git a/tests/unit_tests/data/test_hf_dataset_chat_format.py b/tests/unit_tests/data/test_hf_dataset_chat_format.py new file mode 100644 index 0000000000..de93942fa1 --- /dev/null +++ b/tests/unit_tests/data/test_hf_dataset_chat_format.py @@ -0,0 +1,113 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import json + +from datasets import Dataset, DatasetDict + +from megatron.bridge.data.builders.hf_dataset import preprocess_and_split_data + + +CHAT_EXAMPLES = [ + { + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ], + "tools": [{"name": "search", "description": "Search the web"}], + }, + { + "messages": [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "4"}, + ], + "tools": [], + }, +] + + +def process_chat_example(example, tokenizer=None): + """Extract messages and tools from a chat-format example.""" + return {"messages": example["messages"], "tools": example["tools"]} + + +def process_input_output_example(example, tokenizer=None): + """Extract input/output from chat messages for backward-compat testing.""" + return { + "input": example["messages"][0]["content"], + "output": example["messages"][1]["content"], + } + + +class TestPreprocessChatFormat: + """Tests that preprocess_and_split_data writes arbitrary process_example_fn output to JSONL.""" + + def test_chat_format_writes_messages_and_tools(self, tmp_path): + """Chat-format process functions should write messages/tools keys to JSONL.""" + dset = DatasetDict({"train": Dataset.from_list(CHAT_EXAMPLES)}) + + preprocess_and_split_data( + dset=dset, + dataset_name="test-chat", + dataset_root=tmp_path, + tokenizer=None, + process_example_fn=process_chat_example, + val_proportion=None, + do_validation=False, + do_test=False, + rewrite=True, + ) + + output_file = tmp_path / "training.jsonl" + assert output_file.exists() + + lines = output_file.read_text().strip().split("\n") + assert len(lines) == 2 + + for i, line in enumerate(lines): + data = json.loads(line) + assert "messages" in data, f"Line {i} missing 'messages' key" + assert "tools" in data, f"Line {i} missing 'tools' key" + assert data["messages"] == CHAT_EXAMPLES[i]["messages"] + assert data["tools"] == CHAT_EXAMPLES[i]["tools"] + assert "input" not in data + 
assert "output" not in data + + def test_input_output_format_still_works(self, tmp_path): + """Backward compat: input/output process functions should still produce correct JSONL.""" + dset = DatasetDict({"train": Dataset.from_list(CHAT_EXAMPLES)}) + + preprocess_and_split_data( + dset=dset, + dataset_name="test-io", + dataset_root=tmp_path, + tokenizer=None, + process_example_fn=process_input_output_example, + val_proportion=None, + do_validation=False, + do_test=False, + rewrite=True, + ) + + output_file = tmp_path / "training.jsonl" + assert output_file.exists() + + lines = output_file.read_text().strip().split("\n") + assert len(lines) == 2 + + for line in lines: + data = json.loads(line) + assert "input" in data + assert "output" in data + assert "messages" not in data