chatlearn/algorithm/grpo_utils/packing_utils.py (45 additions, 29 deletions)
@@ -19,6 +19,7 @@

import torch
import numpy as np
+from sortedcontainers import SortedList

def bin_packing(seq_len_list: List[int], max_train_token: int):
"""
@@ -53,34 +54,44 @@ def bin_packing(seq_len_list: List[int], max_train_token: int):
        bins_seqlen[best_bin_index].append(value)
    return list(bins_id), list(bins_seqlen)
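For intuition, a usage sketch of the packing step (illustrative only, not part of this diff; the exact grouping depends on the best-fit loop elided above):

    # Hypothetical inputs: four samples with token lengths 9, 7, 5 and 3,
    # packed into bins whose totals must stay within max_train_token.
    bins_id, bins_seqlen = bin_packing(seq_len_list=[9, 7, 5, 3], max_train_token=16)
    # bins_id holds original sample indices, bins_seqlen the matching lengths,
    # and every bin satisfies sum(lengths) <= max_train_token:
    for ids, lens in zip(bins_id, bins_seqlen):
        assert sum(lens) <= 16 and len(ids) == len(lens)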

-def bin_packing_fix_bin(seq_len_list: List[int], bin_size: int):
+def adjust_bins(bins_id, bins_num):
    """
-    Implementation of best fit decreasing bin packing algorithm with fix bin size
+    Adjust bins to balance total sequence length across bins.
+    First build a sorted list of (total_seq_len, -min_seq_len, sorted_bin_list) tuples,
+    so the last element is the bin with the largest total_seq_len and, on ties, the
+    smallest single-sample sequence length. Each round moves the smallest sample from
+    the largest bin (tail) into the smallest bin (head), shrinking the gap from
+    max_sum - min_sum to |max_sum - min_sum - 2 * s|; stop once no move improves balance.
    """
-    seqlen_id_mapping = dict(enumerate(seq_len_list))
-    sorted_mapping = dict(sorted(seqlen_id_mapping.items(), key=lambda item: item[1], reverse=True))
-    bins_id = [[] for i in range(bin_size)]
-    bins_seqlen = [[] for i in range(bin_size)]
-    for key, value in sorted_mapping.items():
-        min_sum = None
-        for id_, bin_ in enumerate(bins_seqlen):
-            bin_sum = value + sum(bin_)
-            if min_sum is None:
-                min_sum = bin_sum
-                best_bin_index = id_
-            else:
-                if bin_sum < min_sum:
-                    min_sum = bin_sum
-                    best_bin_index = id_
-        bins_id[best_bin_index].append(key)
-        bins_seqlen[best_bin_index].append(value)
-    # sort bins by seqlen in single bin
-    bins_seqlen_sum = [sum(bin_seqlen) for bin_seqlen in bins_seqlen]
-    sorted_bin = sorted(zip(bins_seqlen_sum, bins_id), reverse=True)
-    sorted_binseq = sorted(zip(bins_seqlen_sum, bins_seqlen), reverse=True)
-    _, bins_id = zip(*sorted_bin)
-    _, bins_seqlen = zip(*sorted_binseq)
-    return list(bins_id), list(bins_seqlen)
+    if len(bins_id) == 1:
+        return bins_id, bins_num
+    sorted_list = SortedList()
+    for i in range(len(bins_id)):
+        min_seq = bins_num[i][-1] if len(bins_num[i]) > 0 else 0
+        sorted_list.add((
+            sum(bins_num[i]),
+            -min_seq,
+            SortedList([(num, id_) for id_, num in zip(bins_id[i], bins_num[i])])))
+    # Balance sorted_list
+    stop = False
+    while not stop:
+        min_sum, _, min_bin = sorted_list.pop(0)
+        max_sum, _, max_bin = sorted_list.pop(-1)
+        smallest_num, smallest_id = max_bin.pop(0)
+        if abs(max_sum - min_sum - 2 * smallest_num) < max_sum - min_sum:
+            min_bin.add((smallest_num, smallest_id))
+            sorted_list.add((max_sum - smallest_num, -max_bin[0][0], max_bin))
+            sorted_list.add((min_sum + smallest_num, -min_bin[0][0], min_bin))
+        else:
+            stop = True
+            max_bin.add((smallest_num, smallest_id))
+            sorted_list.add((max_sum, -max_bin[0][0], max_bin))
+            # Guard: bins padded to max_bin_size may still be empty here
+            sorted_list.add((min_sum, -min_bin[0][0] if len(min_bin) > 0 else 0, min_bin))
+    bins_id = [[item[1] for item in list_[2]] for list_ in sorted_list]
+    bins_seq = [[item[0] for item in list_[2]] for list_ in sorted_list]
+    bins_id.reverse()
+    bins_seq.reverse()
+    return bins_id, bins_seq
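
A worked example of adjust_bins as implemented above (small numbers chosen by hand): start from two bins [[8, 2], [3]] for samples 0, 1, 2. The first round moves the length-2 sample from the heavy bin into the light one, shrinking the gap from |10 - 3| = 7 to |8 - 5| = 3; the next candidate move (the length-8 sample) would widen the gap again, so the loop stops:

    bins_id, bins_seq = adjust_bins([[0, 1], [2]], [[8, 2], [3]])
    # Bins come back in descending order of total length, and samples
    # within a bin in ascending order of length:
    # bins_id == [[0], [1, 2]], bins_seq == [[8], [2, 3]]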

def prepare_packing_attn_mask(total_seq_len_list: List[int], pad_size: int, dtype):
    total_seq_length = sum(total_seq_len_list) + pad_size
@@ -128,16 +139,21 @@ def regroup_data_packing(
        for data_b in data_list
    ]
    # Get bin_packing result
-    bins_id, _ = bin_packing(seq_len_list=total_token_length, max_train_token=max_train_token)
-    bin_size = torch.tensor(len(bins_id)).cuda()
+    bins_id, bins_seq = bin_packing(seq_len_list=total_token_length, max_train_token=max_train_token)
+    local_bin_size = len(bins_id)
+    bin_size = torch.tensor(local_bin_size).cuda()
    # Get max bin_size across all ranks in the same model replica
    # For Megatron, all_reduce along the mp group first and the emp group second
    # For FSDP, all_reduce along the default group
    if process_group_list is None:
        process_group_list = [None]
    for pg in process_group_list:
        torch.distributed.all_reduce(bin_size, op=torch.distributed.ReduceOp.MAX, group=pg)
-    bins_id, _ = bin_packing_fix_bin(seq_len_list=total_token_length, bin_size=bin_size.cpu().item())
+    max_bin_size = bin_size.cpu().item()
+    for _ in range(max_bin_size - local_bin_size):
+        bins_id.append([])
+        bins_seq.append([])
+    bins_id, bins_seq = adjust_bins(bins_id, bins_seq)
    # Prepare train data for each micro batch
    for micro_batch_id in bins_id:
        regroup_data_list.append([])
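
Putting the new regroup flow together: each rank packs locally, all ranks in a replica agree on the maximum bin count via the MAX all_reduce above, ranks that packed fewer bins pad with empty ones, and adjust_bins then spreads samples into them. A single-process sketch of the padding-plus-rebalance step (pad_and_balance is a hypothetical helper for this example, not part of the diff):

    def pad_and_balance(bins_id, bins_seq, max_bin_size):
        # Every rank must emit the same number of micro batches, so append
        # empty bins up to the replica-wide maximum before rebalancing.
        for _ in range(max_bin_size - len(bins_id)):
            bins_id.append([])
            bins_seq.append([])
        return adjust_bins(bins_id, bins_seq)

    # e.g. this rank packed 2 bins but another rank in the replica packed 3:
    bins_id, bins_seq = pad_and_balance([[0, 1], [2]], [[8, 2], [3]], max_bin_size=3)
    # bins_id == [[0], [2], [1]], bins_seq == [[8], [3], [2]]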