7 changes: 4 additions & 3 deletions miles/rollout/data_source.py

@@ -75,6 +75,7 @@ def __init__(self, args):
             apply_chat_template=args.apply_chat_template,
             apply_chat_template_kwargs=args.apply_chat_template_kwargs,
             seed=args.rollout_seed,
+            dataset_num_proc=args.dataset_num_proc,
         )
         if self.args.rollout_shuffle:
             self.dataset.shuffle(self.epoch_id)
@@ -85,15 +86,15 @@ def get_samples(self, num_samples):
         # TODO further improve code
         if self.dataset is not None:
             if self.sample_offset + num_samples <= len(self.dataset):
-                prompt_samples = self.dataset.samples[self.sample_offset : self.sample_offset + num_samples]
+                prompt_samples = [self.dataset[i] for i in range(self.sample_offset, self.sample_offset + num_samples)]
                 self.sample_offset += num_samples
             else:
-                prompt_samples = self.dataset.samples[self.sample_offset :]
+                prompt_samples = [self.dataset[i] for i in range(self.sample_offset, len(self.dataset))]
                 num_samples -= len(prompt_samples)
                 self.epoch_id += 1
                 if self.args.rollout_shuffle:
                     self.dataset.shuffle(self.epoch_id)
-                prompt_samples += self.dataset.samples[:num_samples]
+                prompt_samples += [self.dataset[i] for i in range(num_samples)]
                 self.sample_offset = num_samples
         else:
             prompt_samples = [Sample() for _ in range(num_samples)]
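For intuition, here is a minimal sketch of the wrap-around behavior that `get_samples` implements with the new index-based access. `ToyDataset` and the standalone `wrap_samples` helper are hypothetical stand-ins, not the project's API; the real method also reshuffles the dataset between epochs:

```python
class ToyDataset:
    """Toy stand-in exposing only __len__ and __getitem__, the interface
    the PR switches to (no materialized .samples list)."""

    def __init__(self, items):
        self._items = items

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        return self._items[idx]


def wrap_samples(dataset, sample_offset, num_samples):
    """Return (samples, new_offset), wrapping into a new epoch when the
    current one is exhausted, mirroring the branch structure above."""
    if sample_offset + num_samples <= len(dataset):
        samples = [dataset[i] for i in range(sample_offset, sample_offset + num_samples)]
        return samples, sample_offset + num_samples
    # Take the tail of this epoch, then the head of the next.
    samples = [dataset[i] for i in range(sample_offset, len(dataset))]
    num_samples -= len(samples)
    samples += [dataset[i] for i in range(num_samples)]
    return samples, num_samples


print(wrap_samples(ToyDataset(list("abcdef")), 4, 4))  # (['e', 'f', 'a', 'b'], 2)
```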
4 changes: 3 additions & 1 deletion miles/rollout/sglang_rollout.py

@@ -509,6 +509,7 @@ async def eval_rollout_single_dataset(
             tool_key=dataset_cfg.tool_key,
             apply_chat_template=args.apply_chat_template,
             apply_chat_template_kwargs=args.apply_chat_template_kwargs,
+            dataset_num_proc=args.dataset_num_proc,
         )
         dataset = EVAL_PROMPT_DATASET[cache_key]

@@ -527,7 +528,8 @@ async def eval_rollout_single_dataset(
     tasks = []
     # do multiple samples for eval prompts
     sample_index = 0
-    for _i, prompt_sample in enumerate(dataset.samples):
+    for i in range(len(dataset)):
+        prompt_sample = dataset[i]
         for j in range(dataset_cfg.n_samples_per_eval_prompt):
             # use the same prompt for multiple samples
             sample = copy.deepcopy(prompt_sample)
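The eval path mirrors the same change: the dataset no longer exposes a materialized `.samples` list, so iteration goes through `__len__`/`__getitem__`, and each prompt is fanned out into `n_samples_per_eval_prompt` independent copies. A minimal sketch with hypothetical toy names:

```python
import copy


class LazyPrompts:
    """Toy stand-in: rows are produced on demand by __getitem__."""

    def __init__(self, rows):
        self._rows = rows

    def __len__(self):
        return len(self._rows)

    def __getitem__(self, idx):
        return {"prompt": self._rows[idx]}


dataset = LazyPrompts(["p0", "p1"])
n_samples_per_eval_prompt = 3

samples = []
for i in range(len(dataset)):
    prompt_sample = dataset[i]
    for _ in range(n_samples_per_eval_prompt):
        # deepcopy so each rollout can mutate its sample independently
        samples.append(copy.deepcopy(prompt_sample))

assert len(samples) == 6
```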
6 changes: 6 additions & 0 deletions miles/utils/arguments.py

@@ -615,6 +615,12 @@ def add_data_arguments(parser):
             "and should be set to a larger value than `max_tokens_per_gpu` if you want better performance. "
         ),
     )
+    parser.add_argument(
+        "--dataset-num-proc",
+        type=int,
+        default=8,
+        help="Number of processes for dataset initialization and filtering.",
+    )
     return parser


 def add_eval_arguments(parser):
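The new flag is consumed elsewhere as `args.dataset_num_proc`: argparse maps the dashed option name to an underscored attribute. A self-contained sketch (the project wires this up through `add_data_arguments` rather than building a parser directly):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset-num-proc",
    type=int,
    default=8,
    help="Number of processes for dataset initialization and filtering.",
)

# argparse exposes --dataset-num-proc as args.dataset_num_proc
args = parser.parse_args(["--dataset-num-proc", "16"])
print(args.dataset_num_proc)  # 16
```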
220 changes: 162 additions & 58 deletions miles/utils/data.py

@@ -2,9 +2,10 @@
 import json
 import logging
 import os
-import random
 import re
+from functools import partial

+import datasets
 import numpy as np
 import ray

@@ -21,6 +22,55 @@

 logger = logging.getLogger(__name__)

+_FILE_TYPE_MAP = {
+    ".jsonl": "json",
+    ".parquet": "parquet",
+}
+
+
+def _filter_func(
+    example,
+    tokenizer,
+    processor,
+    max_length,
+    prompt_key,
+    multimodal_keys,
+    apply_chat_template,
+    apply_chat_template_kwargs,
+    tool_key,
+):
+    as_conversation = apply_chat_template
+    prompt = _build_messages(example, prompt_key, as_conversation, multimodal_keys)
+
+    tools = None
+    if tool_key is not None and tool_key in example:
+        tools = example[tool_key]
+        if isinstance(tools, str):
+            tools = json.loads(tools)
+        elif isinstance(tools, np.ndarray):
+            tools = tools.tolist()
+        assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
+
+    if apply_chat_template:
+        formatted_prompt = tokenizer.apply_chat_template(
+            prompt,
+            tools=tools,
+            tokenize=False,
+            add_generation_prompt=True,
+            **(apply_chat_template_kwargs or {}),
+        )
+    else:
+        formatted_prompt = prompt
+
+    if processor:
+        from miles.utils.processing_utils import process_vision_info
+
+        multimodal_inputs = process_vision_info(prompt, processor)
+    else:
+        multimodal_inputs = None
+
+    return not _should_skip_prompt(formatted_prompt, tokenizer, processor, max_length, multimodal_inputs)
+
+
 def read_file(path):
     path, row_slice = _parse_generalized_path(path)
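`_filter_func` lives at module level so the `datasets` library can pickle it for multiprocess filtering, with its configuration bound via `functools.partial`. A runnable sketch of the same pattern, with a toy predicate standing in for the tokenizer-based length check:

```python
from functools import partial

import datasets


def keep_short(example, max_length):
    # Module-level predicate, picklable for num_proc > 1.
    return len(example["prompt"]) <= max_length


ds = datasets.Dataset.from_dict({"prompt": ["hi", "a rather long prompt", "ok"]})
ds = ds.filter(partial(keep_short, max_length=10), num_proc=2, desc="Filtering")
print(ds["prompt"])  # ['hi', 'ok']
```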
@@ -161,77 +211,131 @@ def __init__(
         seed=42,
         apply_chat_template=False,
         apply_chat_template_kwargs=None,
+        dataset_num_proc=8,
     ):
-        self.origin_samples = []
-        for data in read_file(path):
-            as_conversation = apply_chat_template
-            prompt = _build_messages(data, prompt_key, as_conversation, multimodal_keys)
-
-            metadata = data.get(metadata_key) or {}
-            tools = None
-            if tool_key is not None and tool_key in data:
-                tools = data[tool_key]
-                if isinstance(tools, str):
-                    tools = json.loads(tools)
-                elif isinstance(tools, np.ndarray):
-                    tools = tools.tolist()
-                assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
-                metadata["tools"] = tools
-
-            if apply_chat_template:
-                formatted_prompt = tokenizer.apply_chat_template(
-                    prompt,
-                    tools=tools,
-                    tokenize=False,
-                    add_generation_prompt=True,
-                    **(apply_chat_template_kwargs or {}),
-                )
-            else:
-                formatted_prompt = prompt
+        # 1. Store basic config
+        self.tokenizer = tokenizer
+        self.processor = processor
+        self.max_length = max_length
+        self.prompt_key = prompt_key
+        self.multimodal_keys = multimodal_keys
+        self.label_key = label_key
+        self.tool_key = tool_key
+        self.metadata_key = metadata_key
+        self.apply_chat_template = apply_chat_template
+        self.apply_chat_template_kwargs = apply_chat_template_kwargs or {}
+        self.seed = seed
+        self.epoch_id = -1

-            if processor:
-                from miles.utils.processing_utils import process_vision_info
+        # 2. Load and process dataset
+        self.hf_dataset = self._load_and_filter_dataset(path, dataset_num_proc)
+        self.origin_hf_dataset = self.hf_dataset

-                assert isinstance(
-                    prompt, list
-                ), f"prompt must be a list when processor is not None, got {type(prompt)} instead"
-                multimodal_inputs = process_vision_info(prompt, processor)
-            else:
-                multimodal_inputs = None
+    def _get_file_type(self, path: str) -> str:
+        _, ext = os.path.splitext(path)

-            # TODO: this is slow.
-            if _should_skip_prompt(formatted_prompt, tokenizer, processor, max_length, multimodal_inputs):
-                continue
+        try:
+            return _FILE_TYPE_MAP[ext]
+        except KeyError:
+            raise ValueError(f"Unsupported format: {ext}. Supported: {list(_FILE_TYPE_MAP.keys())}") from None

-            self.origin_samples.append(
-                Sample(
-                    prompt=formatted_prompt,
-                    label=data[label_key] if label_key is not None else None,
-                    metadata=metadata,
-                    multimodal_inputs=multimodal_inputs,
-                )
-            )
+    def _load_and_filter_dataset(self, path, dataset_num_proc):
+        raw_file_path, row_slice = _parse_generalized_path(path)
+
+        if not os.path.exists(raw_file_path):
+            raise FileNotFoundError(f"Prompt dataset path '{raw_file_path}' does not exist.")
+
+        logger.info(f"Loading dataset from {raw_file_path} using Hugging Face datasets.")
+
+        # Determine file type and load using datasets library for memory-mapped access
+        file_type = self._get_file_type(raw_file_path)
+        ds = datasets.load_dataset(file_type, data_files=raw_file_path, split="train")
+
+        # Apply row slicing if specified
+        if row_slice:
+            num_rows = len(ds)
+            indices = range(num_rows)[row_slice]
+            ds = ds.select(indices)
+            logger.info(f"Applied slice {row_slice}, dataset size: {len(ds)}")
+
+        filter_kwargs = {
+            "tokenizer": self.tokenizer,
+            "processor": self.processor,
+            "max_length": self.max_length,
+            "prompt_key": self.prompt_key,
+            "multimodal_keys": self.multimodal_keys,
+            "apply_chat_template": self.apply_chat_template,
+            "apply_chat_template_kwargs": self.apply_chat_template_kwargs,
+            "tool_key": self.tool_key,
+        }
+
+        original_size = len(ds)
+
+        ds = ds.filter(
+            partial(_filter_func, **filter_kwargs), num_proc=dataset_num_proc, desc="Filtering invalid samples"
+        )
+
+        new_size = len(ds)
+        logger.info(f"Filtered dataset from {original_size} to {new_size} samples.")
+
+        return ds
+
+    def __len__(self):
+        return len(self.hf_dataset)
+
+    def __getitem__(self, idx):
+        # The underlying HF dataset handles lazy fetching
+        data = self.hf_dataset[idx]
+
+        # Process the data using existing logic
+        as_conversation = self.apply_chat_template
+        prompt = _build_messages(data, self.prompt_key, as_conversation, self.multimodal_keys)
+
+        metadata = data.get(self.metadata_key) or {}
+        tools = None
+        if self.tool_key is not None and self.tool_key in data:
+            tools = data[self.tool_key]
+            if isinstance(tools, str):
+                tools = json.loads(tools)
+                # TODO (chenyang): If the JSON parsing is heavy, we might need
+                # to use hf_dataset.map() during init to pre-process these
+                # fields into a more efficient format (Arrow-native), rather
+                # than parsing raw strings on the fly.
+            elif isinstance(tools, np.ndarray):
+                tools = tools.tolist()
+            assert isinstance(tools, list), f"tools must be a list, got {type(tools)} instead"
+            metadata["tools"] = tools
+
+        if self.apply_chat_template:
+            formatted_prompt = self.tokenizer.apply_chat_template(
+                prompt, tools=tools, tokenize=False, add_generation_prompt=True, **self.apply_chat_template_kwargs
+            )
+        else:
+            formatted_prompt = prompt

-        self.epoch_id = -1
-        self.seed = seed
-        self.samples = self.origin_samples
+        multimodal_inputs = None
+        if self.processor:
+            from miles.utils.processing_utils import process_vision_info
+
+            multimodal_inputs = process_vision_info(prompt, self.processor)
+
+        sample = Sample(
+            prompt=formatted_prompt,
+            label=data.get(self.label_key) if self.label_key is not None else None,
+            metadata=metadata,
+            multimodal_inputs=multimodal_inputs,
+        )
+
+        return sample

     def shuffle(self, new_epoch_id):
         if self.epoch_id == new_epoch_id:
             return

-        random.seed(self.seed + new_epoch_id)
-        permutation = list(range(len(self.samples)))
-        random.shuffle(permutation)
-        self.samples = [self.origin_samples[i] for i in permutation]
         logger.info(f"Shuffling dataset for epoch {new_epoch_id} with seed {self.seed + new_epoch_id}")
+        self.hf_dataset = self.origin_hf_dataset.shuffle(seed=self.seed + new_epoch_id)
         self.epoch_id = new_epoch_id

-    def __getitem__(self, idx):
-        return self.samples[idx]
-
-    def __len__(self):
-        return len(self.samples)
-

 def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu):
     # use first fit to get the number of micro batches
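Taken together, the rewrite replaces an eager, fully materialized sample list with a memory-mapped Hugging Face dataset: rows are parsed and formatted lazily in `__getitem__`, and per-epoch shuffling delegates to `datasets.Dataset.shuffle` with a derived seed instead of permuting a Python list. A self-contained sketch of that pattern (temp-file data and the seed value are hypothetical):

```python
import json
import tempfile

import datasets

# Build a tiny JSONL file to stand in for the prompt dataset.
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    for i in range(5):
        f.write(json.dumps({"prompt": f"p{i}"}) + "\n")
    path = f.name

# Memory-mapped load; rows are only materialized on access.
origin = datasets.load_dataset("json", data_files=path, split="train")

seed = 42
for epoch_id in range(2):
    # Each epoch reshuffles the *original* dataset with a derived seed,
    # so the order is reproducible and epochs differ from one another.
    ds = origin.shuffle(seed=seed + epoch_id)
    print(epoch_id, [ds[i]["prompt"] for i in range(len(ds))])
```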