Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions miles/utils/data.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import bisect
import itertools
import json
import logging
Expand Down Expand Up @@ -256,17 +257,31 @@ def __len__(self):


def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu):
    """Estimate the number of micro batches needed to pack all sequences.

    Packs sequence lengths into "bins" of capacity ``max_tokens_per_gpu``
    using the Best-Fit Decreasing heuristic: lengths are processed in
    descending order and each one is placed into the existing bin whose
    remaining space is smallest but still sufficient. A binary search over
    the sorted list of bin fill levels finds that bin in O(log B).

    Args:
        total_lengths: Iterable of per-sequence token counts.
        max_tokens_per_gpu: Maximum token capacity of a single micro batch.

    Returns:
        int: The number of bins (micro batches) used. Returns 0 for empty
        input. Note this is a heuristic upper bound on the true optimum
        (bin packing is NP-hard), and a length exceeding the capacity
        still occupies its own (overfull) bin.
    """
    # Sort lengths in descending order: placing large items first is what
    # makes the Best-Fit heuristic effective.
    sorted_lengths = sorted(total_lengths, reverse=True)

    # Sorted list of current bin totals (filled capacities), kept ordered
    # so bisect can locate the best-fit bin in O(log B).
    bin_totals = []

    for length in sorted_lengths:
        # The best-fit bin is the one with the smallest remaining space
        # that still fits, i.e. the bin with the largest filled capacity
        # <= (capacity - length).
        threshold = max_tokens_per_gpu - length

        # Index just past the last bin total <= threshold.
        idx = bisect.bisect_right(bin_totals, threshold)

        if idx > 0:
            # Pop the chosen bin and re-insert its new total to keep
            # bin_totals sorted.
            current_fill = bin_totals.pop(idx - 1)
            bisect.insort(bin_totals, current_fill + length)
        else:
            # No existing bin fits the current sequence; open a new bin.
            bisect.insort(bin_totals, length)

    return len(bin_totals)


def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size):
Expand Down