perf: optimize micro-batch calculation #403
base: main
Conversation
Summary of Changes

Hello @ppraneth, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request introduces a critical performance optimization to the micro-batch calculation process, which is a known CPU bottleneck in large-scale Reinforcement Learning training. By transitioning from a quadratic-time First-Fit algorithm to a more efficient O(N log B) Max-Heap-based Best-Fit strategy, the change drastically reduces batch preparation times. This enhancement is crucial for improving training throughput, enabling better scaling of RL workflows, and freeing up CPU resources for other asynchronous tasks.

Highlights
Code Review
This is an excellent performance optimization. The pull request description clearly explains the problem with the previous O(N^2) implementation and the benefits of the new O(N log B) heap-based approach. The benchmark results effectively demonstrate the significant speedup. The code change is correct and implements the described logic well. I have a couple of minor suggestions to make the code a bit more concise.
Great work on identifying this CPU bottleneck! However, by always picking the bin with the most remaining space, you are effectively switching from First-Fit to Largest-Fit (Worst-Fit). First-Fit tends to pack bins tighter, resulting in fewer total batches; Worst-Fit tends to distribute load, potentially resulting in more total batches. Since the function is named `get_minimum_num_micro_batch_size`, please run your benchmark again and compare the output value (the number of resulting micro-batches) between the two approaches.
The best way is to find the "Best-Fit": the bin that offers the tightest fit (i.e., the one with the least available space that is still sufficient).
To clarify the algorithmic choices, let's look at the three standard heuristics:

1. First-Fit: place each item in the first existing bin that has room.
2. Largest-Fit (Worst-Fit): place each item in the bin with the most remaining space.
3. Best-Fit: place each item in the bin with the least remaining space that can still hold it.

Among these, Option 3 is actually what we want. It packs bins the "tightest," producing the minimum number of total micro-batches. Option 2 (Largest-Fit) tends to distribute load evenly, which often results in more half-empty bins and higher GPU kernel overhead.
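For a concrete toy illustration (the limit of 10 and the lengths below are made-up values, not from the PR's benchmark), the three heuristics can end up with different bin counts on the same input:

```python
# Toy comparison of the three packing heuristics on a made-up input.
# limit=10 and the lengths are illustrative, not the PR's benchmark data.
def first_fit(lengths, limit):
    bins = []
    for length in lengths:
        for i, total in enumerate(bins):
            if total + length <= limit:  # first bin with room
                bins[i] += length
                break
        else:
            bins.append(length)
    return bins

def worst_fit(lengths, limit):
    bins = []
    for length in lengths:
        candidates = [i for i, total in enumerate(bins) if total + length <= limit]
        if candidates:
            i = min(candidates, key=lambda i: bins[i])  # emptiest bin = most remaining space
            bins[i] += length
        else:
            bins.append(length)
    return bins

def best_fit(lengths, limit):
    bins = []
    for length in lengths:
        candidates = [i for i, total in enumerate(bins) if total + length <= limit]
        if candidates:
            i = max(candidates, key=lambda i: bins[i])  # fullest bin that still fits = tightest fit
            bins[i] += length
        else:
            bins.append(length)
    return bins

lengths, limit = [5, 6, 4, 5, 6, 4], 10
for name, fn in [("First-Fit", first_fit), ("Worst-Fit", worst_fit), ("Best-Fit", best_fit)]:
    print(name, fn(lengths, limit))
# First-Fit -> [9, 10, 5, 6]  (4 bins)
# Worst-Fit -> [9, 6, 9, 6]   (4 bins)
# Best-Fit  -> [10, 10, 10]   (3 bins)
```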
I agree with @zhaochenyang20, and we talked about the implementation details. Hope it helps!
```python
    return len(self.samples)


def get_minimum_num_micro_batch_size(total_lengths, max_tokens_per_gpu):
```
And I think for this case, writing a new function would be better.
@zhaochenyang20 @PopSoda2002 The major concern regarding increased batch counts is resolved. BFD generates ~2.1% fewer batches than the Previous (Worst-Fit) implementation.

```python
import time
import heapq
import random
import bisect
import numpy as np
from typing import List


# First-Fit (O(N^2))
def packing_original(total_lengths: List[int], limit: int):
    bins = []
    for length in total_lengths:
        for i in range(len(bins)):
            if bins[i] + length <= limit:
                bins[i] += length
                break
        else:
            bins.append(length)
    return bins


# Largest-Fit/Worst-Fit (O(N log B))
def packing_worst_fit(total_lengths: List[int], limit: int):
    remaining_capacities = []
    for length in total_lengths:
        if remaining_capacities and (-remaining_capacities[0] >= length):
            most_space = -heapq.heappop(remaining_capacities)
            heapq.heappush(remaining_capacities, -(most_space - length))
        else:
            heapq.heappush(remaining_capacities, -(limit - length))
    # We return [limit - remaining] to get the filled tokens per bin
    return [limit + res for res in remaining_capacities]


# Best-Fit Decreasing (O(N log N))
def packing_best_fit_decreasing(total_lengths: List[int], limit: int):
    sorted_lengths = sorted(total_lengths, reverse=True)
    bin_totals = []
    for length in sorted_lengths:
        threshold = limit - length
        idx = bisect.bisect_right(bin_totals, threshold)
        if idx > 0:
            val = bin_totals.pop(idx - 1)
            bisect.insort(bin_totals, val + length)
        else:
            bisect.insort(bin_totals, length)
    return bin_totals


def run_incremental_benchmark():
    LIMIT = 32768
    sample_sizes = [5000, 10000, 20000, 30000, 40000]
    methods = [
        ("Original (FF)", packing_original),
        ("Previous (WF)", packing_worst_fit),
        ("New (BFD)", packing_best_fit_decreasing),
    ]
    print(f"{'N':<7} | {'Method':<15} | {'Time':<10} | {'Bins':<6} | {'Util %':<8} | {'SD'}")
    print("-" * 65)
    for n in sample_sizes:
        lengths = [random.randint(512, 4096) for _ in range(n)]
        for name, func in methods:
            start = time.perf_counter()
            bins = func(lengths, LIMIT)
            duration = time.perf_counter() - start
            # Metrics
            count = len(bins)
            utils = [(b / LIMIT) * 100 for b in bins]
            avg_util = np.mean(utils)
            std_dev = np.std(utils)
            print(f"{n:<7} | {name:<15} | {duration:>8.4f}s | {count:>6} | {avg_util:>7.2f}% | {std_dev:>5.2f}")
        print("-" * 65)


if __name__ == "__main__":
    run_incremental_benchmark()
```

Benchmark Result:
Could you please check the updated code once and confirm it is OK?
I am worried about this analysis now:

So the improvement is pretty marginal compared to the initial version. @ppraneth cc @zhaochenyang20
@PopSoda2002 You're technically correct that maintaining a sorted list still results in O(N·B) theoretical complexity, because `bisect.insort` must shift list elements on insertion. That said, the practical performance improvement is meaningful for a few reasons:

- Interpreter overhead: the original implementation performs N×B comparisons in Python-level loops. The revised version moves the search and shifting into C-level `bisect` and list operations, so the constant factor is far smaller.
- Benchmark data: in isolation, at N = 40,000 samples, the original version takes ~3.57s while the revised version takes ~0.07s (≈ 50× faster). At N = 250,000 samples, the original version takes ~159.36s and produces 17,684 bins, while the revised version takes ~2.43s and produces 17,613 bins (≈ 65× faster, 71 fewer bins).
- Downstream efficiency: by sorting first (Best-Fit Decreasing), we achieve tighter packing. In the benchmark this reduces the number of micro-batches (2,815 vs 2,826), which can translate to fewer forward/backward launches and slightly better GPU utilization.

I agree this path is not currently the primary system bottleneck, and I'm not claiming a large end-to-end step-time reduction. The goal here is to reduce Python-level overhead and improve packing efficiency in micro-batch construction, while keeping this logic efficient and scalable as batch sizes grow.
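As a rough, hypothetical micro-benchmark (not part of the PR) of the interpreter-overhead point above, one can compare a pure-Python linear scan against `bisect.bisect_right` for finding an insertion point in the same sorted list:

```python
# Illustrative only: Python-level linear scan vs C-level bisect on a sorted list
# of bin totals. Sizes and values here are arbitrary assumptions.
import bisect
import random
import timeit

bin_totals = sorted(random.randint(0, 32768) for _ in range(2000))
target = 16384

def python_scan():
    # Pure-Python search for the first total greater than the target.
    for i, total in enumerate(bin_totals):
        if total > target:
            return i
    return len(bin_totals)

t_scan = timeit.timeit(python_scan, number=10_000)
t_bisect = timeit.timeit(lambda: bisect.bisect_right(bin_totals, target), number=10_000)
print(f"python scan: {t_scan:.4f}s  bisect: {t_bisect:.4f}s")
```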
Time complexity:
Tradeoffs:
I think you may need to recalculate the time complexity for First-Fit: it should be O(N·B), not O(N²), and usually we will not sample as many as 40,000. The end-to-end optimization should be benchmarked.
@PopSoda2002 I agree that the original First-Fit is O(N·B) rather than O(N²). While the revised BFD implementation technically shares this worst-case bound (since `bisect.insort` is O(B)), the constant factors are much lower because the search and shifting happen in C-level list operations, which is what the benchmarks above reflect.
This PR optimizes the bin-packing logic in `slime/utils/data.py` used to calculate the minimum number of micro-batches required for training. By replacing a linear search with a Max-Heap-based approach, we reduce the algorithmic complexity from $O(N^2)$ to $O(N \log B)$, where $N$ is the number of samples and $B$ is the number of resulting micro-batches.

The Problem
The current implementation of `get_minimum_num_micro_batch_size` uses a First-Fit algorithm. For every sequence length in the batch, it performs a linear scan through all existing micro-batches to find the first one with available capacity (see the sketch below). In large-scale RL training, where $N$ can reach tens of thousands of samples, the number of bins also grows. This nested loop leads to quadratic time complexity, causing a significant CPU bottleneck that stalls the GPU while the training process prepares the next batch.
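A minimal sketch of that First-Fit loop, assuming the signature shown in the diff above and a return value equal to the number of micro-batches (it mirrors the `packing_original` reference used in the benchmark):

```python
# Sketch of the existing First-Fit logic: O(N * B) Python-level scanning.
from typing import List

def get_minimum_num_micro_batch_size(total_lengths: List[int], max_tokens_per_gpu: int) -> int:
    batches: List[int] = []  # tokens currently assigned to each micro-batch
    for length in total_lengths:
        for i in range(len(batches)):
            # Linear scan: take the first micro-batch with enough remaining capacity.
            if batches[i] + length <= max_tokens_per_gpu:
                batches[i] += length
                break
        else:
            # No existing micro-batch fits; open a new one.
            batches.append(length)
    return len(batches)
```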
The Solution
This PR implements a Best-Fit strategy using a Max-Heap to track the remaining capacity of each micro-batch.
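A minimal sketch of the heap-based version, assuming the same signature (it mirrors the `packing_worst_fit` variant from the benchmark above; the review discussion later moved toward Best-Fit Decreasing):

```python
# Sketch of the heap-based rewrite: a max-heap (values negated) over the
# remaining capacity of each open micro-batch, O(N log B) instead of O(N * B).
import heapq
from typing import List

def get_minimum_num_micro_batch_size(total_lengths: List[int], max_tokens_per_gpu: int) -> int:
    remaining_capacities: List[int] = []  # negated remaining capacity per micro-batch
    for length in total_lengths:
        if remaining_capacities and -remaining_capacities[0] >= length:
            # The micro-batch with the most remaining space can take this sample.
            most_space = -heapq.heappop(remaining_capacities)
            heapq.heappush(remaining_capacities, -(most_space - length))
        else:
            # Otherwise open a new micro-batch.
            heapq.heappush(remaining_capacities, -(max_tokens_per_gpu - length))
    return len(remaining_capacities)
```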
Benchmark Results
The following results demonstrate that the optimization prevents the quadratic performance degradation seen in the original version:
Impact