diff --git a/chatlearn/algorithm/grpo_utils/packing_utils.py b/chatlearn/algorithm/grpo_utils/packing_utils.py
index f6b75934..34ee9d1d 100644
--- a/chatlearn/algorithm/grpo_utils/packing_utils.py
+++ b/chatlearn/algorithm/grpo_utils/packing_utils.py
@@ -19,6 +19,7 @@
 import torch
 import numpy as np
+from sortedcontainers import SortedList

 def bin_packing(seq_len_list: List[int], max_train_token: int):
     """
@@ -53,34 +54,44 @@ def bin_packing(seq_len_list: List[int], max_train_token: int):
         bins_seqlen[best_bin_index].append(value)
     return list(bins_id), list(bins_seqlen)

-def bin_packing_fix_bin(seq_len_list: List[int], bin_size: int):
+def adjust_bins(bins_id, bins_num):
     """
-    Implementation of best fit decreasing bin packing algorithm with fix bin size
+    Adjust bins to balance the total sequence length per bin.
+    First build a sorted list of (total_seq_len, -min_seq_len, sorted_bin) tuples,
+    so the last element is the bin with the largest total sequence length and,
+    on ties, the smallest single-sample sequence length.
+    Each round pops that bin from the tail and moves its smallest sample to the bin at the head.
+    The loop stops once moving a sample no longer improves the balance.
     """
-    seqlen_id_mapping = dict(enumerate(seq_len_list))
-    sorted_mapping = dict(sorted(seqlen_id_mapping.items(), key=lambda item: item[1], reverse=True))
-    bins_id = [[] for i in range(bin_size)]
-    bins_seqlen = [[] for i in range(bin_size)]
-    for key, value in sorted_mapping.items():
-        min_sum = None
-        for id_, bin_ in enumerate(bins_seqlen):
-            bin_sum = value + sum(bin_)
-            if min_sum is None:
-                min_sum = bin_sum
-                best_bin_index = id_
-            else:
-                if bin_sum < min_sum:
-                    min_sum = bin_sum
-                    best_bin_index = id_
-        bins_id[best_bin_index].append(key)
-        bins_seqlen[best_bin_index].append(value)
-    # sort bins by seqlen in single bin
-    bins_seqlen_sum = [sum(bin_seqlen) for bin_seqlen in bins_seqlen]
-    sorted_bin = sorted(zip(bins_seqlen_sum, bins_id), reverse=True)
-    sorted_binseq = sorted(zip(bins_seqlen_sum, bins_seqlen), reverse=True)
-    _, bins_id = zip(*sorted_bin)
-    _, bins_seqlen = zip(*sorted_binseq)
-    return list(bins_id), list(bins_seqlen)
+    if len(bins_id) == 1:
+        return bins_id, bins_num
+    sorted_list = SortedList()
+    for i in range(len(bins_id)):
+        min_seq = bins_num[i][-1] if len(bins_num[i]) > 0 else 0
+        sorted_list.add((
+            sum(bins_num[i]),
+            -min_seq,
+            SortedList([(num, id_) for id_, num in zip(bins_id[i], bins_num[i])])))
+    # Repeatedly move the smallest sample from the fullest bin into the emptiest bin
+    stop = False
+    while not stop:
+        min_sum, _, min_bin = sorted_list.pop(0)
+        max_sum, _, max_bin = sorted_list.pop(-1)
+        smallest_num, smallest_id = max_bin.pop(0)
+        if abs(max_sum - min_sum - 2 * smallest_num) < max_sum - min_sum:  # the move narrows the max-min gap
+            min_bin.add((smallest_num, smallest_id))
+            sorted_list.add((max_sum - smallest_num, -max_bin[0][0], max_bin))
+            sorted_list.add((min_sum + smallest_num, -min_bin[0][0], min_bin))
+        else:
+            stop = True
+            max_bin.add((smallest_num, smallest_id))
+            sorted_list.add((max_sum, -max_bin[0][0], max_bin))
+            sorted_list.add((min_sum, -min_bin[0][0] if min_bin else 0, min_bin))  # min_bin may be an empty padding bin
+    bins_id = [[item[1] for item in list_[2]] for list_ in sorted_list]
+    bins_seq = [[item[0] for item in list_[2]] for list_ in sorted_list]
+    bins_id.reverse()
+    bins_seq.reverse()
+    return bins_id, bins_seq

 def prepare_packing_attn_mask(total_seq_len_list: List[int], pad_size: int, dtype):
     total_seq_length = sum(total_seq_len_list) + pad_size
@@ -128,8 +139,9 @@ def regroup_data_packing(
         for data_b in data_list
     ]
     # Get bin_packing result
-    bins_id, _ = bin_packing(seq_len_list=total_token_length, max_train_token=max_train_token)
-    bin_size = torch.tensor(len(bins_id)).cuda()
+    bins_id, bins_seq = bin_packing(seq_len_list=total_token_length, max_train_token=max_train_token)
+    local_bin_size = len(bins_id)
+    bin_size = torch.tensor(local_bin_size).cuda()
    # Get max_bin_size across all rank in same model replica
    # For megatron, all_reduce along mp group first and emp group second
    # For FSDP, all_reduce along default group
@@ -137,7 +149,11 @@
        process_group_list = [None]
    for pg in process_group_list:
        torch.distributed.all_reduce(bin_size, op=torch.distributed.ReduceOp.MAX, group=pg)
-    bins_id, _ = bin_packing_fix_bin(seq_len_list=total_token_length, bin_size=bin_size.cpu().item())
+    max_bin_size = bin_size.cpu().item()
+    for _ in range(max_bin_size - local_bin_size):  # pad to the replica-wide max bin count
+        bins_id.append([])
+        bins_seq.append([])
+    bins_id, bins_seq = adjust_bins(bins_id, bins_seq)
    # Prepare train data for each micro batch
    for micro_batch_id in bins_id:
        regroup_data_list.append([])
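
A quick way to sanity-check the new rebalancing step is to run `bin_packing` and `adjust_bins` together. The sketch below is illustrative only: the sequence lengths, the `max_train_token` value, and the single padded bin are made up, and it assumes this diff has been applied so that `adjust_bins` is importable.

```python
# Illustrative check of the rebalancing step; the sequence lengths and
# max_train_token below are made-up values, not taken from the PR.
from chatlearn.algorithm.grpo_utils.packing_utils import adjust_bins, bin_packing

seq_lens = [900, 700, 400, 300, 200, 100]
bins_id, bins_seq = bin_packing(seq_len_list=seq_lens, max_train_token=1024)
# e.g. bins_seq == [[900, 100], [700, 300], [400, 200]] -> totals [1000, 1000, 600]

# Simulate a rank whose all-reduced max bin count exceeds its local count:
# pad with one empty bin, as regroup_data_packing now does.
bins_id.append([])
bins_seq.append([])

bins_id, bins_seq = adjust_bins(bins_id, bins_seq)
print([sum(b) for b in bins_seq])  # totals move toward uniform, e.g. [900, 700, 600, 400]
```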
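The padding loop relies on every rank in a replica agreeing on the bin count. Below is a minimal single-process sketch of that synchronization, using the gloo backend on CPU rather than the `.cuda()`/NCCL path in the diff; the rendezvous address and the local bin count are placeholders.

```python
import torch
import torch.distributed as dist

# Single-process group for illustration; real runs use the replica's
# mp/ep groups (Megatron) or the default group (FSDP).
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

local_bin_size = 3  # stand-in for len(bins_id) on this rank
bin_size = torch.tensor(local_bin_size)
# MAX across ranks: every rank learns the largest local bin count, then pads
# its own bins up to it so all ranks run the same number of micro batches.
dist.all_reduce(bin_size, op=dist.ReduceOp.MAX)
max_bin_size = bin_size.item()
dist.destroy_process_group()
```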