diff --git a/miles/rollout/data_source.py b/miles/rollout/data_source.py
index 613319d34..155404913 100644
--- a/miles/rollout/data_source.py
+++ b/miles/rollout/data_source.py
@@ -75,6 +75,7 @@ def __init__(self, args):
             apply_chat_template=args.apply_chat_template,
             apply_chat_template_kwargs=args.apply_chat_template_kwargs,
             seed=args.rollout_seed,
+            dataset_num_proc=args.dataset_num_proc,
         )
         if self.args.rollout_shuffle:
             self.dataset.shuffle(self.epoch_id)
@@ -85,15 +86,15 @@ def get_samples(self, num_samples):
         # TODO further improve code
         if self.dataset is not None:
             if self.sample_offset + num_samples <= len(self.dataset):
-                prompt_samples = self.dataset.samples[self.sample_offset : self.sample_offset + num_samples]
+                prompt_samples = [self.dataset[i] for i in range(self.sample_offset, self.sample_offset + num_samples)]
                 self.sample_offset += num_samples
             else:
-                prompt_samples = self.dataset.samples[self.sample_offset :]
+                prompt_samples = [self.dataset[i] for i in range(self.sample_offset, len(self.dataset))]
                 num_samples -= len(prompt_samples)
                 self.epoch_id += 1
                 if self.args.rollout_shuffle:
                     self.dataset.shuffle(self.epoch_id)
-                prompt_samples += self.dataset.samples[:num_samples]
+                prompt_samples += [self.dataset[i] for i in range(num_samples)]
                 self.sample_offset = num_samples
         else:
             prompt_samples = [Sample() for _ in range(num_samples)]
diff --git a/miles/rollout/sglang_rollout.py b/miles/rollout/sglang_rollout.py
index 2f3734657..84e443882 100644
--- a/miles/rollout/sglang_rollout.py
+++ b/miles/rollout/sglang_rollout.py
@@ -509,6 +509,7 @@ async def eval_rollout_single_dataset(
             tool_key=dataset_cfg.tool_key,
             apply_chat_template=args.apply_chat_template,
             apply_chat_template_kwargs=args.apply_chat_template_kwargs,
+            dataset_num_proc=args.dataset_num_proc,
         )

     dataset = EVAL_PROMPT_DATASET[cache_key]
@@ -527,7 +528,8 @@ async def eval_rollout_single_dataset(
     tasks = []
     # do multiple samples for eval prompts
     sample_index = 0
-    for _i, prompt_sample in enumerate(dataset.samples):
+    for i in range(len(dataset)):
+        prompt_sample = dataset[i]
         for j in range(dataset_cfg.n_samples_per_eval_prompt):
             # use the same prompt for multiple samples
             sample = copy.deepcopy(prompt_sample)
diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py
index 824b3a028..a6e68122b 100644
--- a/miles/utils/arguments.py
+++ b/miles/utils/arguments.py
@@ -615,6 +615,12 @@ def add_data_arguments(parser):
             "and should be set to a larger value than `max_tokens_per_gpu` if you want better performance. "
         ),
     )
+    parser.add_argument(
+        "--dataset-num-proc",
+        type=int,
+        default=8,
+        help="Number of processes for dataset initialization and filtering.",
+    )
     return parser

 def add_eval_arguments(parser):
diff --git a/miles/utils/data.py b/miles/utils/data.py
index fea0d4c46..4f00438c2 100644
--- a/miles/utils/data.py
+++ b/miles/utils/data.py
@@ -2,9 +2,10 @@
 import json
 import logging
 import os
-import random
 import re
+from functools import partial

+import datasets
 import numpy as np
 import ray

@@ -21,6 +22,55 @@
 logger = logging.getLogger(__name__)


+_FILE_TYPE_MAP = {
+    ".jsonl": "json",
+    ".parquet": "parquet",
+}
+
+
+def _filter_func(
+    example,
+    tokenizer,
+    processor,
+    max_length,
+    prompt_key,
+    multimodal_keys,
+    apply_chat_template,
+    apply_chat_template_kwargs,
+    tool_key,
+):
+    as_conversation = apply_chat_template
+    prompt = _build_messages(example, prompt_key, as_conversation, multimodal_keys)
+
+    tools = None
+    if tool_key is not None and tool_key in example:
+        tools = example[tool_key]
+        if isinstance(tools, str):
+            tools = json.loads(tools)
+        elif isinstance(tools, np.ndarray):
+            tools = tools.tolist()
+        assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
+
+    if apply_chat_template:
+        formatted_prompt = tokenizer.apply_chat_template(
+            prompt,
+            tools=tools,
+            tokenize=False,
+            add_generation_prompt=True,
+            **(apply_chat_template_kwargs or {}),
+        )
+    else:
+        formatted_prompt = prompt
+
+    if processor:
+        from miles.utils.processing_utils import process_vision_info
+
+        multimodal_inputs = process_vision_info(prompt, processor)
+    else:
+        multimodal_inputs = None
+
+    return not _should_skip_prompt(formatted_prompt, tokenizer, processor, max_length, multimodal_inputs)
+

 def read_file(path):
     path, row_slice = _parse_generalized_path(path)
@@ -161,77 +211,131 @@ def __init__(
         seed=42,
         apply_chat_template=False,
         apply_chat_template_kwargs=None,
+        dataset_num_proc=8,
     ):
-        self.origin_samples = []
-        for data in read_file(path):
-            as_conversation = apply_chat_template
-            prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys)
-
-            metadata = data.get(metadata_key) or {}
-            tools = None
-            if tool_key is not None and tool_key in data:
-                tools = data[tool_key]
-                if isinstance(tools, str):
-                    tools = json.loads(tools)
-                elif isinstance(tools, np.ndarray):
-                    tools = tools.tolist()
-                assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
-                metadata["tools"] = tools
-
-            if apply_chat_template:
-                formatted_prompt = tokenizer.apply_chat_template(
-                    prompt,
-                    tools=tools,
-                    tokenize=False,
-                    add_generation_prompt=True,
-                    **(apply_chat_template_kwargs or {}),
-                )
-            else:
-                formatted_prompt = prompt
+        # 1. Store basic config
+        self.tokenizer = tokenizer
+        self.processor = processor
+        self.max_length = max_length
+        self.prompt_key = prompt_key
+        self.multimodal_keys = multimodal_keys
+        self.label_key = label_key
+        self.tool_key = tool_key
+        self.metadata_key = metadata_key
+        self.apply_chat_template = apply_chat_template
+        self.apply_chat_template_kwargs = apply_chat_template_kwargs or {}
+        self.seed = seed
+        self.epoch_id = -1

-            if processor:
-                from miles.utils.processing_utils import process_vision_info
+        # 2. Load and process dataset
+        self.hf_dataset = self._load_and_filter_dataset(path, dataset_num_proc)
+        self.origin_hf_dataset = self.hf_dataset

-                assert isinstance(
-                    prompt, list
-                ), f"prompt must be a list when processor is not None, got {type(prompt)} instead"
-                multimodal_inputs = process_vision_info(prompt, processor)
-            else:
-                multimodal_inputs = None
+    def _get_file_type(self, path: str) -> str:
+        _, ext = os.path.splitext(path)

-            # TODO: this is slow.
-            if _should_skip_prompt(formatted_prompt, tokenizer, processor, max_length, multimodal_inputs):
-                continue
+        try:
+            return _FILE_TYPE_MAP[ext]
+        except KeyError:
+            raise ValueError(f"Unsupported format: {ext}. Supported: {list(_FILE_TYPE_MAP.keys())}") from None

-            self.origin_samples.append(
-                Sample(
-                    prompt=formatted_prompt,
-                    label=data[label_key] if label_key is not None else None,
-                    metadata=metadata,
-                    multimodal_inputs=multimodal_inputs,
-                )
+    def _load_and_filter_dataset(self, path, dataset_num_proc):
+        raw_file_path, row_slice = _parse_generalized_path(path)
+
+        if not os.path.exists(raw_file_path):
+            raise FileNotFoundError(f"Prompt dataset path '{raw_file_path}' does not exist.")
+
+        logger.info(f"Loading dataset from {raw_file_path} using Hugging Face datasets.")
+
+        # Determine file type and load using datasets library for memory-mapped access
+        file_type = self._get_file_type(raw_file_path)
+        ds = datasets.load_dataset(file_type, data_files=raw_file_path, split="train")
+
+        # Apply row slicing if specified
+        if row_slice:
+            num_rows = len(ds)
+            indices = range(num_rows)[row_slice]
+            ds = ds.select(indices)
+            logger.info(f"Applied slice {row_slice}, dataset size: {len(ds)}")
+
+        filter_kwargs = {
+            "tokenizer": self.tokenizer,
+            "processor": self.processor,
+            "max_length": self.max_length,
+            "prompt_key": self.prompt_key,
+            "multimodal_keys": self.multimodal_keys,
+            "apply_chat_template": self.apply_chat_template,
+            "apply_chat_template_kwargs": self.apply_chat_template_kwargs,
+            "tool_key": self.tool_key,
+        }
+
+        original_size = len(ds)
+
+        ds = ds.filter(
+            partial(_filter_func, **filter_kwargs), num_proc=dataset_num_proc, desc="Filtering invalid samples"
+        )
+
+        new_size = len(ds)
+        logger.info(f"Filtered dataset from {original_size} to {new_size} samples.")
+
+        return ds
+
+    def __len__(self):
+        return len(self.hf_dataset)
+
+    def __getitem__(self, idx):
+        # The underlying HF dataset handles lazy fetching
+        data = self.hf_dataset[idx]
+
+        # Process the data using existing logic
+        as_conversation = self.apply_chat_template
+        prompt = _build_messages(data, self.prompt_key, as_conversation, self.multimodal_keys)
+
+        metadata = data.get(self.metadata_key) or {}
+        tools = None
+        if self.tool_key is not None and self.tool_key in data:
+            tools = data[self.tool_key]
+            if isinstance(tools, str):
+                tools = json.loads(tools)
+                # TODO (chenyang): If the JSON parsing is heavy, we might need
+                # to use hf_dataset.map() during init to pre-process these
+                # fields into a more efficient format (Arrow-native), rather
+                # than parsing raw strings on the fly.
+            elif isinstance(tools, np.ndarray):
+                tools = tools.tolist()
+            assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
+            metadata["tools"] = tools
+
+        if self.apply_chat_template:
+            formatted_prompt = self.tokenizer.apply_chat_template(
+                prompt, tools=tools, tokenize=False, add_generation_prompt=True, **self.apply_chat_template_kwargs
             )
+        else:
+            formatted_prompt = prompt

-        self.epoch_id = -1
-        self.seed = seed
-        self.samples = self.origin_samples
+        multimodal_inputs = None
+        if self.processor:
+            from miles.utils.processing_utils import process_vision_info
+
+            multimodal_inputs = process_vision_info(prompt, self.processor)
+
+        sample = Sample(
+            prompt=formatted_prompt,
+            label=data.get(self.label_key) if self.label_key is not None else None,
+            metadata=metadata,
+            multimodal_inputs=multimodal_inputs,
+        )
+
+        return sample

     def shuffle(self, new_epoch_id):
         if self.epoch_id == new_epoch_id:
             return
-        random.seed(self.seed + new_epoch_id)
-        permutation = list(range(len(self.samples)))
-        random.shuffle(permutation)
-        self.samples = [self.origin_samples[i] for i in permutation]
+        logger.info(f"Shuffling dataset for epoch {new_epoch_id} with seed {self.seed + new_epoch_id}")
+        self.hf_dataset = self.origin_hf_dataset.shuffle(seed=self.seed + new_epoch_id)
         self.epoch_id = new_epoch_id

-    def __getitem__(self, idx):
-        return self.samples[idx]
-
-    def __len__(self):
-        return len(self.samples)
-

 def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu):
     # use first fit to get the number of micro batches
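Usage sketch (illustration only, not part of the diff): the snippet below drives the reworked prompt dataset the way data_source.py and sglang_rollout.py now consume it, through len() and lazy __getitem__ instead of a pre-built .samples list. The class name Dataset, its import location, the tokenizer checkpoint, and the file name are assumptions; only the parameter names (path, tokenizer, processor, max_length, prompt_key, label_key, apply_chat_template, dataset_num_proc) are taken from the hunks above.

    # Illustration only: `Dataset` and its import path are assumed names; the
    # checkpoint and data file are placeholders.
    from transformers import AutoTokenizer

    from miles.utils.data import Dataset  # assumed export name for the class changed above

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # placeholder model

    dataset = Dataset(
        "prompts.jsonl",          # .jsonl or .parquet, per _FILE_TYPE_MAP
        tokenizer=tokenizer,
        processor=None,
        max_length=4096,
        prompt_key="prompt",
        label_key="label",
        apply_chat_template=True,
        dataset_num_proc=8,       # new knob; mirrors the --dataset-num-proc CLI flag
    )

    # Samples are no longer pre-built into `dataset.samples`; they are materialized
    # lazily via __getitem__, which is how get_samples() and the eval loop now read them.
    dataset.shuffle(0)
    first_batch = [dataset[i] for i in range(min(4, len(dataset)))]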