def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu):
    """Estimate how many micro batches are needed to pack the given sequence
    lengths into bins of capacity ``max_tokens_per_gpu``.

    Uses the best-fit-decreasing bin-packing heuristic: sequences are placed
    longest-first, and each one goes into the fullest open batch that still
    has room for it. Open-batch fills are kept in an ascending sorted list so
    the best-fit candidate is found with a binary search.

    Note: a sequence longer than ``max_tokens_per_gpu`` still gets a batch of
    its own (the bin then exceeds the nominal capacity).

    Args:
        total_lengths: iterable of per-sequence token counts.
        max_tokens_per_gpu: token capacity of one micro batch.

    Returns:
        The number of micro batches (bins) the heuristic opened.
    """
    # Ascending list of the current token count of every open micro batch.
    fills = []

    for seq_len in sorted(total_lengths, reverse=True):
        # A batch can accept seq_len iff its fill <= capacity - seq_len;
        # bisect_right gives the count of such batches.
        pos = bisect.bisect_right(fills, max_tokens_per_gpu - seq_len)
        if pos == 0:
            # Nothing has room (or seq_len alone exceeds capacity): new batch.
            bisect.insort(fills, seq_len)
        else:
            # Best fit = fullest batch that still fits, i.e. index pos - 1.
            # Pop and re-insert so `fills` stays sorted.
            bisect.insort(fills, fills.pop(pos - 1) + seq_len)

    return len(fills)