From ff3349299eb26e92cb81278388a269ba65b2cbc4 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Thu, 10 Jul 2025 17:00:09 +0800 Subject: [PATCH 01/10] Force forward_step and train_step have same data order --- .../algorithm/grpo_utils/advantage_compute.py | 9 +++++++-- chatlearn/runtime/trainer.py | 17 +++++++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/chatlearn/algorithm/grpo_utils/advantage_compute.py b/chatlearn/algorithm/grpo_utils/advantage_compute.py index 13da2d64..9e62ede8 100644 --- a/chatlearn/algorithm/grpo_utils/advantage_compute.py +++ b/chatlearn/algorithm/grpo_utils/advantage_compute.py @@ -7,8 +7,11 @@ def compute_grpo_adv(episode_replay_buffers): buffers = episode_replay_buffers[-1].buffer queryids2samples = defaultdict(list) + sample_id = 0 for s in buffers: + s['sample_id'] = sample_id queryids2samples[hash(",".join(map(str, s["prompt_token_ids"])))].append(s) + sample_id += 1 res_buffers = [] for _, l in queryids2samples.items(): @@ -20,6 +23,8 @@ def compute_grpo_adv(episode_replay_buffers): for i, li in enumerate(l): li["advantages"] = (rewards[i] - mean) / (std + 1e-5) res_buffers.extend(l) - # Shuffle train sample to balance training workload across dp_ranks - random.shuffle(res_buffers) + + # Sort samples by original order in buffer + res_buffers.sort(key=lambda x: x["sample_id"]) + return res_buffers diff --git a/chatlearn/runtime/trainer.py b/chatlearn/runtime/trainer.py index 41c975f9..ba377a9b 100644 --- a/chatlearn/runtime/trainer.py +++ b/chatlearn/runtime/trainer.py @@ -99,10 +99,19 @@ def train(self, episode: int): future.wait(self._data_loader.shuffle.remote()) data_queues, out_queue = self.setup_queues() - for mb in range(_num_training_iteration * self.data_parallel_size): - batch = encode_data(mb, self.next_batch()) - for data_queue in data_queues: - data_queue.put(batch) + batch_list = [] + # get batch by iterate over dp first, iteration second + for dp_rank in range(self.data_parallel_size): + for iter_ in range(_num_training_iteration): + batch = encode_data(iter_ * self.data_parallel_size + dp_rank, self.next_batch()) + batch_list.append(batch) + + # put batch by iterate over iteration first, dp second + for iter_ in range(_num_training_iteration): + for dp_rank in range(self.data_parallel_size): + for data_queue in data_queues: + data_queue.put(batch_list[dp_rank * _num_training_iteration + iter_]) + self.compute_loop(out_queue, _num_training_iteration) self.iteration = self.iteration + _num_training_iteration logger.info(f"train episode: {episode+1}, epoch {epoch} num_step {_num_training_iteration} done") From 6c21df2c4df43736ddce9a46cd09ae5e05a28a5d Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Thu, 10 Jul 2025 20:59:23 +0800 Subject: [PATCH 02/10] Remove key after sort; code refine --- .../algorithm/grpo_utils/advantage_compute.py | 3 +- chatlearn/runtime/executor.py | 33 +++++++------------ 2 files changed, 13 insertions(+), 23 deletions(-) diff --git a/chatlearn/algorithm/grpo_utils/advantage_compute.py b/chatlearn/algorithm/grpo_utils/advantage_compute.py index 9e62ede8..26a812dc 100644 --- a/chatlearn/algorithm/grpo_utils/advantage_compute.py +++ b/chatlearn/algorithm/grpo_utils/advantage_compute.py @@ -26,5 +26,6 @@ def compute_grpo_adv(episode_replay_buffers): # Sort samples by original order in buffer res_buffers.sort(key=lambda x: x["sample_id"]) - + for data in res_buffers: + data.pop("sample_id") return res_buffers diff --git a/chatlearn/runtime/executor.py b/chatlearn/runtime/executor.py index b7c3c866..83e79461 100644 --- a/chatlearn/runtime/executor.py +++ b/chatlearn/runtime/executor.py @@ -297,34 +297,23 @@ def generate_step_one_model_internal(self, model_node, in_queue, step_num, repli replica_num = len(model.replicas) output = [] + last_step_start = max(self.num_iteration(model) - replica_num, 0) + is_last_batch = step_num >= last_step_start + kwargs["is_last_batch"] = is_last_batch + if is_eval is not None: + kwargs["is_eval"] = is_eval + if to_empty_cache is not None: + kwargs["to_empty_cache"] = to_empty_cache + if to_onload is not None: + kwargs["to_onload"] = to_onload + if to_offload is not None: + kwargs["to_offload"] = to_offload if isinstance(replica.model, VLLMModule): - last_step_start = max(self.num_iteration(model) - replica_num, 0) - is_last_batch = step_num >= last_step_start - kwargs["is_last_batch"] = is_last_batch - if is_eval is not None: - kwargs["is_eval"] = is_eval - if to_empty_cache is not None: - kwargs["to_empty_cache"] = to_empty_cache - if to_onload is not None: - kwargs["to_onload"] = to_onload - if to_offload is not None: - kwargs["to_offload"] = to_offload mb, query = self.get_next_data(in_queue, model_node, micro_batch_index) assert isinstance(query, list) ret = replica.call_actor_remote_func(replica.vllm_engine, func_name, *query, **kwargs) output.append((ret, mb)) else: - last_step_start = max(self.num_iteration(model) - replica_num, 0) - is_last_batch = step_num >= last_step_start - kwargs["is_last_batch"] = is_last_batch - if to_empty_cache is not None: - kwargs["to_empty_cache"] = to_empty_cache - if to_onload is not None: - kwargs["to_onload"] = to_onload - if to_offload is not None: - kwargs["to_offload"] = to_offload - if is_eval is not None: - kwargs["is_eval"] = is_eval for _, actors in replica.dp_rank_to_actors.items(): mb, query = self.get_next_data(in_queue, model_node, micro_batch_index) assert isinstance(query, list) From 6bd1172b70dc7b8df9ba461118c271e335694985 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Thu, 10 Jul 2025 21:05:36 +0800 Subject: [PATCH 03/10] Fix pylint --- chatlearn/algorithm/grpo_utils/advantage_compute.py | 1 - chatlearn/runtime/trainer.py | 7 ++++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/chatlearn/algorithm/grpo_utils/advantage_compute.py b/chatlearn/algorithm/grpo_utils/advantage_compute.py index 26a812dc..0f04b298 100644 --- a/chatlearn/algorithm/grpo_utils/advantage_compute.py +++ b/chatlearn/algorithm/grpo_utils/advantage_compute.py @@ -1,6 +1,5 @@ """compute advantage for grpo""" from collections import defaultdict -import random import torch diff --git a/chatlearn/runtime/trainer.py b/chatlearn/runtime/trainer.py index ba377a9b..1b879523 100644 --- a/chatlearn/runtime/trainer.py +++ b/chatlearn/runtime/trainer.py @@ -100,13 +100,14 @@ def train(self, episode: int): data_queues, out_queue = self.setup_queues() batch_list = [] - # get batch by iterate over dp first, iteration second + # Data will merge by dp_rank order after environment.make_experiences() + # For off-policy, we need to get batch by iterate over dp first, iteration second for dp_rank in range(self.data_parallel_size): for iter_ in range(_num_training_iteration): batch = encode_data(iter_ * self.data_parallel_size + dp_rank, self.next_batch()) batch_list.append(batch) - - # put batch by iterate over iteration first, dp second + + # After get all batches, put batch into trainer's input queue by iterate over iteration first, dp second for iter_ in range(_num_training_iteration): for dp_rank in range(self.data_parallel_size): for data_queue in data_queues: From 87882335085d2accbc38541a673be0a973497511 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Fri, 11 Jul 2025 14:58:19 +0800 Subject: [PATCH 04/10] Fix train script --- scripts/train_fsdp_vllm_qwen3_235b_a22b_grpo.sh | 2 +- scripts/train_fsdp_vllm_qwen3_30b_a3b_grpo.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/train_fsdp_vllm_qwen3_235b_a22b_grpo.sh b/scripts/train_fsdp_vllm_qwen3_235b_a22b_grpo.sh index 251271a1..f60818d7 100644 --- a/scripts/train_fsdp_vllm_qwen3_235b_a22b_grpo.sh +++ b/scripts/train_fsdp_vllm_qwen3_235b_a22b_grpo.sh @@ -13,7 +13,7 @@ python chatlearn/entrypoint.py grpo \ --config-file template/grpo_fsdp.yaml \ runtime_args.exp_name=${exp_name} \ runtime_args.data_path=${CHATLEARN}/dataset/MATH-lighteval/train.json \ - runtime_args.eval_data_path=${CHATLEARN}/dataset/MATH-lighteval/train.json \ + runtime_args.eval_data_path=${CHATLEARN}/dataset/MATH-lighteval/test.json \ runtime_args.output_dir=${CHATLEARN}/output/${exp_name} \ runtime_args.num_episode=200 \ runtime_args.sample_per_episode=2048 \ diff --git a/scripts/train_fsdp_vllm_qwen3_30b_a3b_grpo.sh b/scripts/train_fsdp_vllm_qwen3_30b_a3b_grpo.sh index 646a1a80..421b5743 100644 --- a/scripts/train_fsdp_vllm_qwen3_30b_a3b_grpo.sh +++ b/scripts/train_fsdp_vllm_qwen3_30b_a3b_grpo.sh @@ -13,7 +13,7 @@ python chatlearn/entrypoint.py grpo \ --config-file template/grpo_fsdp.yaml \ runtime_args.exp_name=${exp_name} \ runtime_args.data_path=${CHATLEARN}/dataset/MATH-lighteval/train.json \ - runtime_args.eval_data_path=${CHATLEARN}/dataset/MATH-lighteval/train.json \ + runtime_args.eval_data_path=${CHATLEARN}/dataset/MATH-lighteval/test.json \ runtime_args.output_dir=${CHATLEARN}/output/${exp_name} \ runtime_args.num_episode=200 \ runtime_args.sample_per_episode=512 \ From 840e0a45a6754daf409d3489d565e55ab40a7d05 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Fri, 11 Jul 2025 17:24:12 +0800 Subject: [PATCH 05/10] Support offline fsdp ckpt merge --- chatlearn/offline_ckpt_converter.py | 112 +++++++++++++++++++++++++++ scripts/fsdp_weight_converter_cpu.sh | 7 ++ 2 files changed, 119 insertions(+) create mode 100644 chatlearn/offline_ckpt_converter.py create mode 100755 scripts/fsdp_weight_converter_cpu.sh diff --git a/chatlearn/offline_ckpt_converter.py b/chatlearn/offline_ckpt_converter.py new file mode 100644 index 00000000..e6dee24b --- /dev/null +++ b/chatlearn/offline_ckpt_converter.py @@ -0,0 +1,112 @@ +import os +import json +import glob +import os +import shutil +import argparse +from typing import defaultdict, List + +import torch +import torch.distributed as dist +import torch.distributed.tensor +from concurrent.futures import ProcessPoolExecutor + + +from safetensors.torch import save_file + +from tqdm import tqdm + +def save_safetensor_item(safetensor_name, safetensor_data, save_dir): + save_file(safetensor_data, os.path.join(save_dir, safetensor_name)) + +def split_list(lst, n): + """Split list into n roughly equal chunks.""" + k, m = divmod(len(lst), n) + return [lst[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)] + +def qwen3_key_mapping(param, model_config): + part = param.split('.')[-1] + num_expert = model_config['num_experts'] + local_names = [param.replace('group_mlp', f"experts.{i}") for i in range(num_expert)] + return local_names + +def check_groupgemm_param(key): + return 'group_mlp' in key + +arch_mapping = { + 'Qwen3MoeForCausalLM': (check_groupgemm_param, qwen3_key_mapping) +} + +def convert_checkpoint_cpu(args): + iter_ = args.iter + hf_dir = args.hf_dir + dist_model_dir = os.path.join(args.ckpt_dir, str(iter_)) + save_dir = args.save_dir + os.makedirs(save_dir, exist_ok=True) + is_groupgemm = args.groupgemm + + safetensor_file = [] + other_file = [] + for file in os.listdir(hf_dir): + if file.endswith(".safetensors"): + safetensor_file.append(file) + else: + other_file.append(file) + + safetensor_config = json.load(open(os.path.join(hf_dir, "model.safetensors.index.json"))) + weight_map = safetensor_config['weight_map'] + model_config = json.load(open(os.path.join(hf_dir, "config.json"))) + arch = model_config['architectures'][0] + + # Check whether training using groupgemm + group_gemm = arch in arch_mapping and is_groupgemm + if group_gemm: + check_param_fn, key_convert_fn = arch_mapping[arch] + num_expert = model_config['num_experts'] + + dist_model_files = glob.glob(os.path.join(dist_model_dir, "model_world_size_*.pt")) + dist_model_state_dict = [] + for file in tqdm(dist_model_files, "Read Distributed Checkpoints"): + dist_model_state_dict.append(torch.load(file, map_location="cpu")) + param_list = list(dist_model_state_dict[0].keys()) + safetensor_dict = {key: {} for key in safetensor_file} + + for param in tqdm(param_list, desc="Merge Weights"): + global_tensor = torch.cat([state_dict.pop(param).to_local() for state_dict in dist_model_state_dict], dim=0) + if group_gemm: + if check_param_fn(param): + # Split param for groupgemm mlp weights + local_names = key_convert_fn(param, model_config) + num_expert = len(local_names) + global_tensor = torch.chunk(global_tensor, num_expert, dim=0) + safetensor_name = weight_map[local_names[0]] + for i in range(num_expert): + safetensor_dict[safetensor_name][local_names[i]] = global_tensor[i] + else: + safetensor_name = weight_map[param] + safetensor_dict[safetensor_name][param] = global_tensor + else: + safetensor_name = weight_map[param] + safetensor_dict[safetensor_name][param] = global_tensor + + # Save safetensor files + for name in tqdm(safetensor_dict, "Save Safetensor Files"): + save_safetensor_item(name, safetensor_dict[name], save_dir) + + # Copy other files + for file in other_file: + if not file.startswith('.'): + src_file = os.path.join(hf_dir, file) + dst_file = os.path.join(save_dir, file) + shutil.copy(src_file, dst_file) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Offline Checkpoint Converter") + parser.add_argument("--hf_dir", type=str, required=True, help="Directory to load hf config") + parser.add_argument("--ckpt_dir", type=str, required=True, help="Directory to load sharded checkpoint") + parser.add_argument("--save_dir", type=str, required=True, help="Directory to save converted hf checkpoint") + parser.add_argument("--groupgemm", action='store_true', help="Whether use groupgemm for training") + parser.add_argument("--iter", type=int, required=True, help="which iter to convert") + args = parser.parse_args() + convert_checkpoint_cpu(args) \ No newline at end of file diff --git a/scripts/fsdp_weight_converter_cpu.sh b/scripts/fsdp_weight_converter_cpu.sh new file mode 100755 index 00000000..a620dda5 --- /dev/null +++ b/scripts/fsdp_weight_converter_cpu.sh @@ -0,0 +1,7 @@ +export CHATLEARN=$(pwd) +python chatlearn/offline_ckpt_converter.py \ + --hf_dir ${CHATLEARN}/Qwen3-235B-A22B/ \ + --ckpt_dir ${CHATLEARN}/output/qwen3-grpo-235b-a22b/save_model/policy_trainer \ + --save_dir ${CHATLEARN}/output/qwen3-grpo-8b/save_model/huggingface/ \ + --iter 200 \ + --groupgemm \ No newline at end of file From 4c023af4d39fbcdd05d304535874d3f8e9e2cd01 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Fri, 11 Jul 2025 17:59:44 +0800 Subject: [PATCH 06/10] fix pylint --- chatlearn/offline_ckpt_converter.py | 31 +++++++++++++++++++---------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/chatlearn/offline_ckpt_converter.py b/chatlearn/offline_ckpt_converter.py index e6dee24b..e547008e 100644 --- a/chatlearn/offline_ckpt_converter.py +++ b/chatlearn/offline_ckpt_converter.py @@ -1,16 +1,26 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Offline FSDP checkpoint merge""" + import os import json import glob -import os import shutil import argparse -from typing import defaultdict, List import torch -import torch.distributed as dist -import torch.distributed.tensor -from concurrent.futures import ProcessPoolExecutor - from safetensors.torch import save_file @@ -25,7 +35,6 @@ def split_list(lst, n): return [lst[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n)] def qwen3_key_mapping(param, model_config): - part = param.split('.')[-1] num_expert = model_config['num_experts'] local_names = [param.replace('group_mlp', f"experts.{i}") for i in range(num_expert)] return local_names @@ -63,7 +72,7 @@ def convert_checkpoint_cpu(args): if group_gemm: check_param_fn, key_convert_fn = arch_mapping[arch] num_expert = model_config['num_experts'] - + dist_model_files = glob.glob(os.path.join(dist_model_dir, "model_world_size_*.pt")) dist_model_state_dict = [] for file in tqdm(dist_model_files, "Read Distributed Checkpoints"): @@ -72,7 +81,7 @@ def convert_checkpoint_cpu(args): safetensor_dict = {key: {} for key in safetensor_file} for param in tqdm(param_list, desc="Merge Weights"): - global_tensor = torch.cat([state_dict.pop(param).to_local() for state_dict in dist_model_state_dict], dim=0) + global_tensor = torch.cat([state_dict.pop(param).to_local() for state_dict in dist_model_state_dict], dim=0) if group_gemm: if check_param_fn(param): # Split param for groupgemm mlp weights @@ -92,7 +101,7 @@ def convert_checkpoint_cpu(args): # Save safetensor files for name in tqdm(safetensor_dict, "Save Safetensor Files"): save_safetensor_item(name, safetensor_dict[name], save_dir) - + # Copy other files for file in other_file: if not file.startswith('.'): @@ -109,4 +118,4 @@ def convert_checkpoint_cpu(args): parser.add_argument("--groupgemm", action='store_true', help="Whether use groupgemm for training") parser.add_argument("--iter", type=int, required=True, help="which iter to convert") args = parser.parse_args() - convert_checkpoint_cpu(args) \ No newline at end of file + convert_checkpoint_cpu(args) From 483af9927577b869544462d114bf910a8d9858f1 Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Mon, 14 Jul 2025 13:46:54 +0800 Subject: [PATCH 07/10] fix pylint --- chatlearn/offline_ckpt_converter.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/chatlearn/offline_ckpt_converter.py b/chatlearn/offline_ckpt_converter.py index e547008e..d0ae0448 100644 --- a/chatlearn/offline_ckpt_converter.py +++ b/chatlearn/offline_ckpt_converter.py @@ -21,6 +21,7 @@ import argparse import torch +import torch.distributed.tensor from safetensors.torch import save_file @@ -46,7 +47,7 @@ def check_groupgemm_param(key): 'Qwen3MoeForCausalLM': (check_groupgemm_param, qwen3_key_mapping) } -def convert_checkpoint_cpu(args): +def convert_checkpoint_cpu(args_input): iter_ = args.iter hf_dir = args.hf_dir dist_model_dir = os.path.join(args.ckpt_dir, str(iter_)) @@ -62,9 +63,11 @@ def convert_checkpoint_cpu(args): else: other_file.append(file) - safetensor_config = json.load(open(os.path.join(hf_dir, "model.safetensors.index.json"))) + with open(os.path.join(hf_dir, "model.safetensors.index.json")) as f: # pylint: disable=unspecified-encoding + safetensor_config = json.load(f) weight_map = safetensor_config['weight_map'] - model_config = json.load(open(os.path.join(hf_dir, "config.json"))) + with open(os.path.join(hf_dir, "config.json")) as f: # pylint: disable=unspecified-encoding + model_config = json.load(f) arch = model_config['architectures'][0] # Check whether training using groupgemm @@ -118,4 +121,4 @@ def convert_checkpoint_cpu(args): parser.add_argument("--groupgemm", action='store_true', help="Whether use groupgemm for training") parser.add_argument("--iter", type=int, required=True, help="which iter to convert") args = parser.parse_args() - convert_checkpoint_cpu(args) + convert_checkpoint_cpu(args_input=args) From 9f3dff6003373b5142cdb765805a0f102a2602fb Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Mon, 14 Jul 2025 15:18:34 +0800 Subject: [PATCH 08/10] fix pylint --- chatlearn/offline_ckpt_converter.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/chatlearn/offline_ckpt_converter.py b/chatlearn/offline_ckpt_converter.py index d0ae0448..163b7fff 100644 --- a/chatlearn/offline_ckpt_converter.py +++ b/chatlearn/offline_ckpt_converter.py @@ -48,12 +48,12 @@ def check_groupgemm_param(key): } def convert_checkpoint_cpu(args_input): - iter_ = args.iter - hf_dir = args.hf_dir - dist_model_dir = os.path.join(args.ckpt_dir, str(iter_)) - save_dir = args.save_dir + iter_ = args_input.iter + hf_dir = args_input.hf_dir + dist_model_dir = os.path.join(args_input.ckpt_dir, str(iter_)) + save_dir = args_input.save_dir os.makedirs(save_dir, exist_ok=True) - is_groupgemm = args.groupgemm + is_groupgemm = args_input.groupgemm safetensor_file = [] other_file = [] From c38baa3056fc7cd22d48b17698b699d3b4e0066f Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Mon, 14 Jul 2025 17:34:38 +0800 Subject: [PATCH 09/10] Add doc for ckpt transform --- chatlearn/offline_ckpt_converter.py | 4 ++-- docs/en/tutorial/tutorial_grpo_fsdp.md | 17 +++++++++++++++++ docs/zh/tutorial/tutorial_grpo_fsdp.md | 17 +++++++++++++++++ scripts/fsdp_weight_converter_cpu.sh | 7 ------- 4 files changed, 36 insertions(+), 9 deletions(-) delete mode 100755 scripts/fsdp_weight_converter_cpu.sh diff --git a/chatlearn/offline_ckpt_converter.py b/chatlearn/offline_ckpt_converter.py index 163b7fff..6bf85b9e 100644 --- a/chatlearn/offline_ckpt_converter.py +++ b/chatlearn/offline_ckpt_converter.py @@ -53,7 +53,7 @@ def convert_checkpoint_cpu(args_input): dist_model_dir = os.path.join(args_input.ckpt_dir, str(iter_)) save_dir = args_input.save_dir os.makedirs(save_dir, exist_ok=True) - is_groupgemm = args_input.groupgemm + is_groupgemm = args_input.groupgemm == 1 safetensor_file = [] other_file = [] @@ -118,7 +118,7 @@ def convert_checkpoint_cpu(args_input): parser.add_argument("--hf_dir", type=str, required=True, help="Directory to load hf config") parser.add_argument("--ckpt_dir", type=str, required=True, help="Directory to load sharded checkpoint") parser.add_argument("--save_dir", type=str, required=True, help="Directory to save converted hf checkpoint") - parser.add_argument("--groupgemm", action='store_true', help="Whether use groupgemm for training") + parser.add_argument("--groupgemm", type=int, choices=[0, 1], default=0, help="Whether use groupgemm for training") parser.add_argument("--iter", type=int, required=True, help="which iter to convert") args = parser.parse_args() convert_checkpoint_cpu(args_input=args) diff --git a/docs/en/tutorial/tutorial_grpo_fsdp.md b/docs/en/tutorial/tutorial_grpo_fsdp.md index 76423578..eb806c87 100644 --- a/docs/en/tutorial/tutorial_grpo_fsdp.md +++ b/docs/en/tutorial/tutorial_grpo_fsdp.md @@ -50,6 +50,23 @@ runtime_args.log_args_dict.enable_wandb=True runtime_args.log_args_dict.wandb_project="Your-Wandb-Project-Name" ``` +## 模型转化 +Saving FSDP models is time-consuming. Chatlearn provides an offline model conversion feature, which converts FSDP-saved sharded models back into HuggingFace models. The script is as follows: +```bash +export CHATLEARN=$(pwd) +python chatlearn/offline_ckpt_converter.py \ + --hf_dir ${CHATLEARN}/Qwen3-8B/ \ + --ckpt_dir ${CHATLEARN}/output/qwen3-grpo-8b/save_model/policy_trainer \ + --save_dir ${CHATLEARN}/output/qwen3-grpo-8b/save_model/huggingface/ \ + --iter 200 \ + --groupgemm 0 +``` +If you are training an MoE model with groupgemm, please make sure to set: +```bash + --groupgemm 1 +``` +This script will convert the final FSDP sharded model after training back into a HuggingFace model and save it in the path "${CHATLEARN}/output/qwen3-grpo-8b/save_model/huggingface/". + ## FAQ ### How to Speed Up PolicyTrainer Training? 1. Set models.policy_trainer.packing=True and configure models.policy_trainer.max_token_in_packing to the maximum token count that fits GPU memory. diff --git a/docs/zh/tutorial/tutorial_grpo_fsdp.md b/docs/zh/tutorial/tutorial_grpo_fsdp.md index 923c9e0d..2c55aa37 100644 --- a/docs/zh/tutorial/tutorial_grpo_fsdp.md +++ b/docs/zh/tutorial/tutorial_grpo_fsdp.md @@ -50,6 +50,23 @@ runtime_args.log_args_dict.enable_wandb=True runtime_args.log_args_dict.wandb_project="Your-Wandb-Project-Name" ``` +## 模型转化 +FSDP模型保存耗时较高,Chatlearn提供了离线模型转化功能,将FSDP保存的切片模型转化回huggingface模型。脚本如下: +```bash +export CHATLEARN=$(pwd) +python chatlearn/offline_ckpt_converter.py \ + --hf_dir ${CHATLEARN}/Qwen3-8B/ \ + --ckpt_dir ${CHATLEARN}/output/qwen3-grpo-8b/save_model/policy_trainer \ + --save_dir ${CHATLEARN}/output/qwen3-grpo-8b/save_model/huggingface/ \ + --iter 200 \ + --groupgemm 0 +``` +如果你使用groupgemm优化的moe模型训练,请确保设置: +```bash + --groupgemm 1 +``` +这段脚本会将训练完成后的最后一个FSDP切片模型转化回HF模型,并保存在"${CHATLEARN}/output/qwen3-grpo-8b/save_model/huggingface/"路径下 + ## FAQ ### 如何可以加快PolicyTrainer的训练速度? 1. 设置models.policy_trainer.packing=True,并设置models.policy_trainer.max_token_in_packing=可以打满显存的总token数。 diff --git a/scripts/fsdp_weight_converter_cpu.sh b/scripts/fsdp_weight_converter_cpu.sh deleted file mode 100755 index a620dda5..00000000 --- a/scripts/fsdp_weight_converter_cpu.sh +++ /dev/null @@ -1,7 +0,0 @@ -export CHATLEARN=$(pwd) -python chatlearn/offline_ckpt_converter.py \ - --hf_dir ${CHATLEARN}/Qwen3-235B-A22B/ \ - --ckpt_dir ${CHATLEARN}/output/qwen3-grpo-235b-a22b/save_model/policy_trainer \ - --save_dir ${CHATLEARN}/output/qwen3-grpo-8b/save_model/huggingface/ \ - --iter 200 \ - --groupgemm \ No newline at end of file From 977478ef526d04d36ee898b30698bc0159848f1d Mon Sep 17 00:00:00 2001 From: yytang0204 Date: Tue, 15 Jul 2025 14:42:44 +0800 Subject: [PATCH 10/10] support dapo_17k --- chatlearn/models/reward/rule_reward.py | 4 +- .../utils/rule_reward_score/math_dapo.py | 268 ++++++++++++++++++ 2 files changed, 271 insertions(+), 1 deletion(-) create mode 100644 chatlearn/utils/rule_reward_score/math_dapo.py diff --git a/chatlearn/models/reward/rule_reward.py b/chatlearn/models/reward/rule_reward.py index d4d9ce26..e35fa3b5 100644 --- a/chatlearn/models/reward/rule_reward.py +++ b/chatlearn/models/reward/rule_reward.py @@ -18,7 +18,7 @@ import torch from chatlearn import BaseModule -from chatlearn.utils.rule_reward_score import math +from chatlearn.utils.rule_reward_score import math, math_dapo class RuleReward(BaseModule): """rule reward""" @@ -68,5 +68,7 @@ def eval_forward(self, data: Dict) -> Dict: def select_rule_reward_score_fn(self, data_source: str): if data_source in ['openai/gsm8k', 'DigitalLearningGmbH/MATH-lighteval', 'aime24', 'aime25']: return math.compute_score + elif data_source in ['dapo_17k']: + return math_dapo.compute_score else: raise NotImplementedError diff --git a/chatlearn/utils/rule_reward_score/math_dapo.py b/chatlearn/utils/rule_reward_score/math_dapo.py new file mode 100644 index 00000000..29101375 --- /dev/null +++ b/chatlearn/utils/rule_reward_score/math_dapo.py @@ -0,0 +1,268 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Adapted from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/hendrycks_math/utils.py + +import re +from typing import Optional + + +def last_boxed_only_string(string: str) -> Optional[str]: + """Extract the last LaTeX boxed expression from a string. + + Args: + string: Input string containing LaTeX code + + Returns: + The last boxed expression or None if not found + """ + idx = string.rfind("\\boxed{") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + return string[idx : right_brace_idx + 1] if right_brace_idx is not None else None + + +def remove_boxed(s: str) -> str: + """Remove the LaTeX boxed command from a string. + + Args: + s: String with format "\\boxed{content}" + + Returns: + The content inside the boxed command + """ + left = "\\boxed{" + assert s[: len(left)] == left, f"box error: {s}" + assert s[-1] == "}", f"box error: {s}" + return s[len(left) : -1] + + +# Constants for normalization +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] + +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + + +def normalize_final_answer(final_answer: str) -> str: + """Normalize a final answer to a quantitative reasoning question. + + Args: + final_answer: The answer string to normalize + + Returns: + Normalized answer string + """ + final_answer = final_answer.split("=")[-1] + + # Apply substitutions and removals + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract and normalize LaTeX math + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize numbers + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer.strip() + + +def is_correct_minerva( + solution_str: str, gt: str, gt_need_extract: bool = False, answer_pattern: str = r"(?i)Answer\s*:\s*([^\n]+)" +) -> tuple[bool, str]: + """Check if the solution is correct according to Minerva criteria. + + Args: + solution_str: The solution string to check + gt: The ground truth answer + gt_need_extract: Whether the ground truth needs extraction + answer_pattern: Regex pattern to extract the answer + + Returns: + Tuple of (is_correct, normalized_prediction) + """ + # Extract answer from solution + match = re.findall(answer_pattern, solution_str) + extracted_answer = match[-1] if match else "[INVALID]" + pred = normalize_final_answer(extracted_answer) + + # Process ground truth + if gt_need_extract: + gt = normalize_final_answer(remove_boxed(last_boxed_only_string(gt))) + else: + gt = normalize_final_answer(gt) + + return (pred == gt), pred + + +def is_correct_strict_box( + pred: str, gt: str, pause_tokens_index: Optional[list[int]] = None +) -> tuple[int, Optional[str]]: + """Check if the prediction is correct using strict boxed answer criteria. + + Args: + pred: The prediction string + gt: The ground truth answer + pause_tokens_index: Indices of pause tokens + + Returns: + Tuple of (score, extracted_prediction) + """ + # Extract the relevant part of the prediction + if pause_tokens_index is not None: + assert len(pause_tokens_index) == 4 + pred = pred[pause_tokens_index[-1] - 100 :] + else: + pred = pred[-100:] + + # Extract and check the boxed answer + boxed_pred = last_boxed_only_string(pred) + extracted_pred = remove_boxed(boxed_pred) if boxed_pred is not None else None + + return 1 if (extracted_pred == gt) else -1, extracted_pred + + +def verify( + solution_str: str, answer: str, strict_box_verify: bool = False, pause_tokens_index: Optional[list[int]] = None +) -> bool: + """Verify if the solution is correct. + + Args: + solution_str: The solution string to verify + answer: The ground truth answer + strict_box_verify: Whether to use strict box verification + pause_tokens_index: Indices of pause tokens + + Returns: + True if the solution is correct, False otherwise + """ + if strict_box_verify: + correct, pred = is_correct_strict_box(solution_str, answer, pause_tokens_index) + return correct == 1, pred + + correct, pred = is_correct_minerva(solution_str, answer) + return correct, pred + + +def compute_score( + solution_str: str, + ground_truth: str, + strict_box_verify: bool = False, + pause_tokens_index: Optional[list[int]] = None, +) -> float: + """Compute the reward score for a solution. + + Args: + solution_str: The solution string + ground_truth: The ground truth answer + strict_box_verify: Whether to use strict box verification + pause_tokens_index: Indices of pause tokens + + Returns: + Reward score (1.0 for correct, -1.0 for incorrect) + """ + # Limit solution length for efficiency + solution_str = solution_str[-300:] # The longest answer in MATH-500 has 159 characters + + # Verify the solution + correct, pred = verify(solution_str, ground_truth, strict_box_verify, pause_tokens_index) + + reward = 1.0 if correct else 0.0 + acc = correct + + return reward