From eb698e99220cbd927550682bb6e1be8c86e4e7f5 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Tue, 10 Sep 2024 17:28:44 +0800 Subject: [PATCH 01/26] Support MCore models --- chatlearn/models/vllm/vllm_model.py | 8 +- chatlearn/models/vllm_module.py | 4 +- chatlearn/tools/megatron_checkpoint_utils.py | 94 +++++- chatlearn/utils/megatron_import_helper.py | 8 +- chatlearn/utils/megatron_utils.py | 14 +- chatlearn/utils/vllm_utils.py | 313 +++++++++++++++++- examples/megatron/configs/gpt/base.yaml | 3 +- examples/megatron/configs/llama2/base.yaml | 3 +- examples/megatron/entry/train_reward.py | 27 +- examples/megatron/entry/train_sft.py | 65 +++- examples/megatron/models/base_trainer.py | 5 +- .../megatron/models/mcore_policy_model.py | 264 +++++++++++++++ .../megatron/models/mcore_reward_model.py | 220 ++++++++++++ examples/megatron/models/mcore_value_model.py | 154 +++++++++ .../megatron/models/old_policy_inference.py | 17 +- .../megatron/models/old_value_inference.py | 9 +- examples/megatron/models/policy_model.py | 2 +- examples/megatron/models/policy_trainer.py | 30 +- examples/megatron/models/reference.py | 9 +- examples/megatron/models/reward_inference.py | 43 ++- examples/megatron/models/reward_model.py | 15 - examples/megatron/models/utils.py | 6 +- examples/megatron/models/value_trainer.py | 33 +- examples/megatron/scripts/base_env.sh | 8 + .../scripts/convert_hf_to_megatron.sh | 11 +- .../scripts/convert_megatron_to_hf.sh | 32 +- .../megatron/scripts/train_reward_llama.sh | 13 +- examples/megatron/scripts/train_rlhf_llama.sh | 2 + examples/megatron/scripts/train_sft_llama.sh | 40 ++- .../megatron/tests/run_policy_generation.sh | 1 + 30 files changed, 1344 insertions(+), 109 deletions(-) create mode 100644 examples/megatron/models/mcore_policy_model.py create mode 100644 examples/megatron/models/mcore_reward_model.py create mode 100644 examples/megatron/models/mcore_value_model.py diff --git a/chatlearn/models/vllm/vllm_model.py b/chatlearn/models/vllm/vllm_model.py index 19582354..61bb4e2a 100644 --- a/chatlearn/models/vllm/vllm_model.py +++ b/chatlearn/models/vllm/vllm_model.py @@ -26,6 +26,7 @@ from chatlearn.utils.vllm_utils import ( convert_llama_state_dict_from_megatron_to_vllm, + convert_llama_state_dict_from_mcore_to_vllm, convert_qwen_state_dict_from_megatron_to_vllm, load_checkpoint ) @@ -49,10 +50,13 @@ def load_weights(self): load_checkpoint(self, None, None) torch.distributed.barrier() - def load_state_dict(self, state_dict, strict=True, assign=False): # pylint: disable=unused-argument + def load_state_dict(self, state_dict, strict=True, assign=False): qwen_version = None if isinstance(self.model, LlamaForCausalLM): - convert_state_dict_internal = convert_llama_state_dict_from_megatron_to_vllm + if self.model_args["use_legacy_models"]: + convert_state_dict_internal = convert_llama_state_dict_from_megatron_to_vllm + else: + convert_state_dict_internal = convert_llama_state_dict_from_mcore_to_vllm elif isinstance(self.model, QWenLMHeadModel): qwen_version = 1.0 convert_state_dict_internal = convert_qwen_state_dict_from_megatron_to_vllm diff --git a/chatlearn/models/vllm_module.py b/chatlearn/models/vllm_module.py index ad7cbf0d..ddde8e77 100644 --- a/chatlearn/models/vllm_module.py +++ b/chatlearn/models/vllm_module.py @@ -50,7 +50,7 @@ except ImportError: print("Cannot import addtional module for vllm 0.5.1, please install vllm 0.5.1 first.") -from chatlearn.utils.vllm_utils import initialize_vllm, Megatron2LlamaSyncMap, Megatron2QWenSyncMap +from 
chatlearn.utils.vllm_utils import initialize_vllm, Megatron2LlamaSyncMap, Megatron2QWenSyncMap, MCore2LlamaSyncMap from chatlearn.utils.vllm_utils import get_model, print_rank_0 from .torch_module import TorchModule @@ -536,7 +536,7 @@ def map_src_to_dst(self, src_names, num_src_pipeline_stage, src_pipe_stage): self._to_fix_qkv_ordering_func = fix_query_key_value_ordering sync_map = sync_map_cls(src_names, layer_offset, QwenVersion.v_2.value) elif isinstance(self.model.model, LlamaForCausalLM): - sync_map_cls = Megatron2LlamaSyncMap + sync_map_cls = Megatron2LlamaSyncMap if self.model_args["use_legacy_models"] else MCore2LlamaSyncMap from chatlearn.utils.vllm_utils import fix_qwen_query_key_value_ordering # pylint: disable=import-outside-toplevel self._to_fix_qkv_ordering_func = fix_qwen_query_key_value_ordering sync_map = sync_map_cls(src_names, layer_offset) diff --git a/chatlearn/tools/megatron_checkpoint_utils.py b/chatlearn/tools/megatron_checkpoint_utils.py index 88a1685f..680f1d9a 100644 --- a/chatlearn/tools/megatron_checkpoint_utils.py +++ b/chatlearn/tools/megatron_checkpoint_utils.py @@ -77,7 +77,7 @@ def repair_loader_put_reward(lines): """ return detect_and_insert_code(lines, pattern, new_code, 0, 0) -def repair_saver_put_reward(lines): +def repair_saver_get_reward(lines): pattern = 'if msg != "done":' new_code = \ """ @@ -100,6 +100,69 @@ def repair_saver_put_reward(lines): """ return detect_and_insert_code(lines, pattern, new_code, 0, 0) + +# MCore +def repair_mcore_block_key(source): + source = source.replace('"BERT" : "encoder",', '"BERT" : "encoder", "REWARD" : "decoder"') + return source + +def repair_import_utils(source): + source = source.replace('from utils import get_mcore_transformer_block_key, print_memory_usage', + 'from tools.checkpoint.utils import get_mcore_transformer_block_key, print_memory_usage') + return source + +def repair_loader_mcore_import_error(source): + return repair_import_utils(source) + +def repair_saver_mcore_import_error(source): + source = source.replace('from setter import ModelSetter', + 'from tools.checkpoint.setter import ModelSetter') + return repair_import_utils(source) + +def repair_loader_mcore_model_provider(lines): + # Insert before following code, so line_offset=-2 + # else: + # raise Exception(f'unrecognized model type: {args.model_type}') + pattern = 'unrecognized model type' + new_code = \ +""" +elif args.model_type == 'REWARD': + from examples.megatron.models.mcore_reward_model import model_provider + margs.model_type = ModelType.encoder_or_decoder +""" + indent = -4 + line_offset = -2 + return detect_and_insert_code(lines, pattern, new_code, indent, line_offset) + +def repair_saver_mcore_model_provider(lines): + return repair_loader_mcore_model_provider(lines) + +def repair_loader_mcore_put_reward(lines): + return repair_loader_put_reward(lines) + +def repair_saver_mcore_get_reward(lines): + pattern = 'if msg != "done":' + new_code = \ +""" +if msg != "done" and msg["name"] == "pooler_head": + if not hasattr(models[pp_rank][0][0], 'pooler_head'): + print("ERROR: got a pooler_head, but model does not have one") + exit(1) + print("received pooler_head") + head_weight1 = msg.pop("weight1") + head_bias1 = msg.pop("bias1") + head_weight2 = msg.pop("weight2") + head_bias2 = msg.pop("bias2") + for model in pp_local_models: + model.pooler_head.dense1.weight.data.copy_(head_weight1) + model.pooler_head.dense1.bias.data.copy_(head_bias1) + model.pooler_head.dense2.weight.data.copy_(head_weight2) + 
model.pooler_head.dense2.bias.data.copy_(head_bias2) + check_message(msg) + msg = queue_get() +""" + return detect_and_insert_code(lines, pattern, new_code, 0, 0) + def exist_checkpoint_util(): spec = importlib.util.find_spec('tools.checkpoint.util') return spec is not None @@ -131,8 +194,24 @@ def repair_code(self, source, module_name): elif module_name == 'saver_megatron': lines = source.split('\n') lines = repair_saver_model_provider(lines) - lines = repair_saver_put_reward(lines) + lines = repair_saver_get_reward(lines) source = '\n'.join(lines) + elif module_name == 'loader_mcore': + source = repair_loader_mcore_import_error(source) + lines = source.split('\n') + lines = repair_loader_mcore_model_provider(lines) + lines = repair_loader_mcore_put_reward(lines) + source = '\n'.join(lines) + elif module_name == 'saver_mcore': + source = repair_saver_mcore_import_error(source) + lines = source.split('\n') + lines = repair_saver_mcore_model_provider(lines) + lines = repair_saver_mcore_get_reward(lines) + source = '\n'.join(lines) + elif module_name == 'utils': + source = repair_mcore_block_key(source) + else: + raise RuntimeError(f"Unrecognized module_name {module_name}") return source def load_module(self, name): @@ -167,7 +246,10 @@ def load_module(self, name): # put the loaded module into sys.modules so that if the module is imported # again it could be found. sys.modules[name] = module - if 'loader_megatron' in name or 'saver_megatron' in name: + if ('loader_megatron' in name + or 'saver_megatron' in name + or 'loader_mcore' in name + or 'saver_mcore' in name): sys.modules[module_name] = module # return the module itself so that it could be used @@ -182,8 +264,12 @@ def load_module(self, name): util.main() else: sys.meta_path.insert(-1, CheckpointUtilsImporter('tools.checkpoint.convert', \ - 'tools.checkpoint.loader_megatron', 'tools.checkpoint.saver_megatron')) + 'tools.checkpoint.loader_megatron', 'tools.checkpoint.saver_megatron', \ + 'tools.checkpoint.loader_mcore', 'tools.checkpoint.saver_mcore', \ + 'tools.checkpoint.utils')) from tools.checkpoint import loader_megatron, saver_megatron # pylint: disable=unused-import + from tools.checkpoint import utils # pylint: disable=unused-import + from tools.checkpoint import loader_mcore, saver_mcore # pylint: disable=unused-import from tools.checkpoint import convert convert.main() # pylint: enable=wildcard-import,exec-used diff --git a/chatlearn/utils/megatron_import_helper.py b/chatlearn/utils/megatron_import_helper.py index 7cecaace..50a74bfe 100644 --- a/chatlearn/utils/megatron_import_helper.py +++ b/chatlearn/utils/megatron_import_helper.py @@ -30,13 +30,17 @@ except ImportError: from megatron.training import arguments from megatron.training import get_args - from megatron.training import get_num_microbatches from megatron.training import get_timers from megatron.training import get_tokenizer from megatron.training import is_last_rank from megatron.training import print_rank_0 from megatron.training import print_rank_last - from megatron.training import update_num_microbatches + try: + from megatron.training import get_num_microbatches + from megatron.training import update_num_microbatches + except ImportError: + from megatron.core.num_microbatches_calculator import get_num_microbatches + from megatron.core.num_microbatches_calculator import update_num_microbatches # megatron.arguments.* try: diff --git a/chatlearn/utils/megatron_utils.py b/chatlearn/utils/megatron_utils.py index 6ad99e0a..29894536 100644 --- 
a/chatlearn/utils/megatron_utils.py +++ b/chatlearn/utils/megatron_utils.py @@ -63,7 +63,9 @@ def build_pipeline_layer_name_mapping(src_layers_per_stage, src_rank, map_interv if requires_grad: if not partition_param.requires_grad: continue - if src_name.endswith("word_embeddings.weight") and "language_model" not in src_name: + if src_name.endswith("word_embeddings.weight") \ + and "language_model" not in src_name \ + and hasattr(model, "language_model"): # See comment in MegatronModule.initialize_word_embeddings() if not tgt_last_stage: tgt_name = src_name.replace("word_embeddings.weight", "language_model.embedding.word_embeddings.weight") @@ -182,8 +184,14 @@ def load_checkpoint(*_args, **kwargs): if 'pooler_head' in key: model_type = "REWARD" break - cmd = f"python {script_path} --model-type {model_type} --load-dir {args.load} " + \ - f"--save-dir {save_dir} --target-tensor-parallel-size {target_tp} --target-pipeline-parallel-size {target_pp}" + if args.use_legacy_models: + cmd = f"python {script_path} --model-type {model_type} --load-dir {args.load} " + \ + f"--save-dir {save_dir} --target-tensor-parallel-size {target_tp} " + \ + f"--target-pipeline-parallel-size {target_pp}" + else: + cmd = f"python {script_path} --model-type {model_type} --loader mcore --load-dir {args.load} " + \ + f"--saver mcore --save-dir {save_dir} --target-tensor-parallel-size {target_tp} " + \ + f"--target-pipeline-parallel-size {target_pp}" logger.info(f"Transforming checkpoint for new parallel strategies {cmd}") subprocess.run(cmd, shell=True, check=True) torch.distributed.barrier() diff --git a/chatlearn/utils/vllm_utils.py b/chatlearn/utils/vllm_utils.py index 7ff91ff6..239dc6f3 100644 --- a/chatlearn/utils/vllm_utils.py +++ b/chatlearn/utils/vllm_utils.py @@ -98,6 +98,13 @@ } +mcore_to_transformers = { + "self_attention.linear_proj": ".self_attn.o_proj.", + "mlp.linear_fc1": ".mlp.gate_up_proj.", + "mlp.linear_fc2": ".mlp.down_proj.", +} + + class ParameterSyncMap: """Base ParameterSyncMap.""" def __init__(self, src_names, layer_offset): @@ -150,7 +157,7 @@ def __init__(self, src_names, layer_offset): src_prefix = "module.module.language_model" dst_prefix = "model.model" # The regex to extract layer names. - self.layer_re = re.compile(f"{src_prefix}.encoder.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") # pylint: disable=anomalous-backslash-in-string + self.layer_re = re.compile(rf"{src_prefix}.encoder.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") self.src_prefix = src_prefix self.dst_prefix = dst_prefix self._embedding_sync_map = { @@ -230,6 +237,94 @@ def map_src_to_dst(self): self._dst_names.append(layer_name + out_name) +class MCore2LlamaSyncMap(ParameterSyncMap): + """sync map:megatron-core to llama transformer""" + def __init__(self, src_names, layer_offset): + src_prefix = "module.module" + dst_prefix = "model.model" + # The regex to extract layer names. + self.layer_re = re.compile(rf"{src_prefix}.decoder.layers\.(\d+)\.([a-z0-9_.]+)[\._]([a-z]+)") + self.src_prefix = src_prefix + self.dst_prefix = dst_prefix + # vLLM skips loading rotary_pos_emb and re-initializes it. Thus, we don't synchronize it from MCore to vllm. 
+ self._embedding_sync_map = { + f"{src_prefix}.embedding.word_embeddings.weight": f"{dst_prefix}.embed_tokens.weight", + } + self._layer_sync_map = { + "self_attention.linear_proj": ".self_attn.o_proj.", + "mlp.linear_fc1": ".mlp.gate_up_proj.", + "mlp.linear_fc2": ".mlp.down_proj.", + } + self._final_layer_sync_map = { + f"{src_prefix}.decoder.final_layernorm.weight": f"{dst_prefix}.norm.weight", + f"{src_prefix}.output_layer.weight": "model.lm_head.weight" + } + self._concat_params_dict = None + self._to_fix_act_ordering_dict = None + self._to_fix_qkv_ordering_dict = { + "modules": [ + "self_attention.linear_qkv", + ], + "layer_re": self.layer_re + } + super().__init__(src_names, layer_offset) + + def map_src_to_dst(self): + for src_name in self.src_names: + # convert word embeddings. + if src_name in self.embedding_sync_map: + self._dst_names.append(self.get_dst_name(self.embedding_sync_map, src_name)) + continue + + # final layer + if src_name in self.final_layer_sync_map: + self._dst_names.append(self.get_dst_name(self.final_layer_sync_map, src_name)) + continue + + m = self.layer_re.match(src_name) + # Stop if that's not a layer + if m is None: + raise RuntimeError(f"expect src_name ({src_name}) to be a layer") + + # The index of the layer. + layer_idx = int(m.group(1)) + self.layer_offset + + # The name of the operation. + op_name = m.group(2) + # Is it a weight or a bias? + weight_or_bias = m.group(3) + # The name of the layer. + layer_name = f"{self.dst_prefix}.layers.{layer_idx}" + + # For layernorm(s), simply store the layer norm. + if op_name.endswith("layer_norm") and weight_or_bias == 'weight': + if op_name == "self_attention.linear_qkv.layer_norm": + ln_name = "input_layernorm" + elif op_name == "mlp.linear_fc1.layer_norm": + ln_name = "post_attention_layernorm" + else: + assert False, f"expect op_name ({op_name}) to be layer norm" + self._dst_names.append(layer_name + "." + ln_name + "." + weight_or_bias) + + # Transpose the QKV matrix. + elif op_name == "self_attention.linear_qkv" and weight_or_bias == 'weight': + self._dst_names.append(layer_name + ".self_attn.qkv_proj.weight") + + # Transpose the weights. + elif weight_or_bias == "weight": + out_name = self.get_dst_name(self.layer_sync_map, op_name) + self._dst_names.append(layer_name + out_name + "weight") + + # Ignore biases and extra_states. + elif weight_or_bias in ["bias", "_extra_state"]: + pass + + # Copy the rest. + else: + out_name = self.get_dst_name(self.layer_sync_map, op_name) + self._dst_names.append(layer_name + out_name) + + class Megatron2QWenSyncMap(ParameterSyncMap): """sync map:megatron to qwen transformer""" def __init__(self, src_names, layer_offset, qwen_version=QwenVersion.v_1.value): @@ -255,7 +350,7 @@ def __init__(self, src_names, layer_offset, qwen_version=QwenVersion.v_1.value): raise RuntimeError(f"Unsupported qwen version {qwen_version}, only 1.0 or 2.0 for now.") # The regex to extract layer names. - self.layer_re = re.compile(f"{src_prefix}.encoder.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") # pylint: disable=anomalous-backslash-in-string + self.layer_re = re.compile(rf"{src_prefix}.encoder.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") self.src_prefix = src_prefix self.dst_prefix = dst_prefix self._embedding_sync_map = { @@ -573,6 +668,22 @@ def get_megatron_sharded_states(args, tp_size, pp_size, pp_rank): return tp_state_dicts +def load_llama_state_dict(args): + # Load original state dict from Megatron-LM checkpoint. 
+ possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000"] + + for root, dirnames, _ in os.walk(args["load"]): + for dirname in dirnames: + if dirname in possible_sub_dirs: + rank0_checkpoint_name = glob.glob(os.path.join(root, dirname) + "/model*.pt") + args["load"] = root + rank0_checkpoint_path = rank0_checkpoint_name[0] + + print(f"Loading Megatron checkpoint arguments from: {rank0_checkpoint_path}") + state_dict = torch.load(rank0_checkpoint_path, map_location="cpu") + return state_dict + + def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version=None): """Convert NVIDIA Megatron-LM state_dict to vLLM llama state_dict. @@ -580,21 +691,12 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version args (argparse.Namespace): the arguments to the script """ assert qwen_version is None, f"Expect qwen_version is None for Llama, while {qwen_version}" - # Load original state dict from Megatron-LM checkpoint. tp_rank = mpu.get_tensor_model_parallel_rank() pp_rank = get_pipeline_model_parallel_rank() assert pp_rank == 0 - possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000"] - for root, dirnames, _ in os.walk(args["load"]): - for dirname in dirnames: - if dirname in possible_sub_dirs: - rank0_checkpoint_name = glob.glob(os.path.join(root, dirname) + "/model*.pt") - args["load"] = root - rank0_checkpoint_path = rank0_checkpoint_name[0] + state_dict = load_llama_state_dict(args) - print(f"Loading Megatron-LM checkpoint arguments from: {rank0_checkpoint_path}") - state_dict = torch.load(rank0_checkpoint_path, map_location="cpu") megatron_args = state_dict.get("args", None) if "checkpoint_version" in state_dict.keys(): checkpoint_version = state_dict["checkpoint_version"] @@ -620,7 +722,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version hidden_size_per_head = hf_config.hidden_size // hf_config.num_attention_heads # The regex to extract layer names. - layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") # pylint: disable=anomalous-backslash-in-string + layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") # Convert. print("Start to convert...") @@ -755,6 +857,179 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version return output_state_dict +def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=None): + """Convert NVIDIA Megatron-Core state_dict to vLLM llama state_dict. + + Args: + args (argparse.Namespace): the arguments to the script + """ + assert qwen_version is None, f"Expect qwen_version is None for Llama, while {qwen_version}" + tp_rank = mpu.get_tensor_model_parallel_rank() + pp_rank = get_pipeline_model_parallel_rank() + assert pp_rank == 0 + + state_dict = load_llama_state_dict(args) + + megatron_args = state_dict.get("args", None) + if "checkpoint_version" in state_dict.keys(): + checkpoint_version = state_dict["checkpoint_version"] + else: + checkpoint_version = 0.0 + if megatron_args is None: + raise ValueError( + "Megatron-LM checkpoint does not contain arguments. This utility only supports Megatron-LM checkpoints" + " containing all the megatron arguments. This is because it loads all config related to model" + " architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to" + " manually specify all the details. Please save Megatron-LM checkpoint along with all the megatron" + " arguments to use this utility." 
+ ) + + output_state_dict = {} + + tp_size = megatron_args.tensor_model_parallel_size + pp_size = megatron_args.pipeline_model_parallel_size + assert pp_size == 1 + # The number of heads. + heads = hf_config.num_attention_heads // tp_size + # The hidden_size per head. + hidden_size_per_head = hf_config.hidden_size // hf_config.num_attention_heads + + # The regex to extract layer names. + layer_re = re.compile(r"decoder.layers\.(\d+)\.([a-z0-9_.]+)[\._]([a-z]+)") + + # Convert. + print("Start to convert...") + + # Embeddings + print("Converting embeddings") + tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0) + # tp_state_dicts: list of state dict for each tp rank + # tp_state_dicts[0]: a state dict for tp rank 0 + # |-keys: dict_keys(['args', 'checkpoint_version', 'iteration', 'model', ...]) + # |-tp_state_dicts[0]['model'] + # |-keys: ['embedding.word_embeddings.weight', + # 'decoder.layers.0.self_attention.core_attention.fused_attention._extra_state', + # 'decoder.layers.0.self_attention.linear_proj.weight', + # 'decoder.layers.0.self_attention.linear_proj._extra_state', + # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight', + # 'decoder.layers.0.self_attention.linear_qkv.weight', + # 'decoder.layers.0.self_attention.linear_qkv._extra_state', + # 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight', + # 'decoder.layers.0.mlp.linear_fc1.weight', + # 'decoder.layers.0.mlp.linear_fc1._extra_state', + # 'decoder.layers.0.mlp.linear_fc2.weight', + # 'decoder.layers.0.mlp.linear_fc2._extra_state', + # ... + # 'decoder.final_layernorm.weight', + # 'output_layer.weight', + # 'output_layer._extra_state' + # Convert and store the position embeddings. + position_embeddings = tp_state_dicts[0]['model'].get("embedding.position_embeddings.weight", None) + if position_embeddings: + output_state_dict["transformer.position_embeddings.weight"] = position_embeddings.to(hf_config.torch_dtype) + + # Convert and store the word embeddings. + word_embeddings = tp_state_dicts[tp_rank]['model'].get("embedding.word_embeddings.weight", None) + word_embeddings = word_embeddings.to(hf_config.torch_dtype) + output_state_dict["model.model.embed_tokens.weight"] = word_embeddings + # Reset the vocab size + hf_config.vocab_size = word_embeddings.shape[0] + + # Transformer Layers + print("Converting transformer layers") + + for key, val in tp_state_dicts[tp_rank]['model'].items(): + if val is None: + assert 'extra_state' in key, "weight/bias shouldn't be None except for extra_state in mcore" + continue + if "_extra_state" in key: + continue + + # Match the name + layer_match_res = layer_re.match(key) + # Skip if that's not a layer + if layer_match_res is None: + continue + # The index of the layer + layer_idx = int(layer_match_res.group(1)) + # The name of the operation. + op_name = layer_match_res.group(2) + # Is it a weight or a bias? + weight_or_bias = layer_match_res.group(3) + # The name of the layer + layer_name = f"model.model.layers.{layer_idx}" + + params = val.to(hf_config.torch_dtype) + + # For layernorm(s), simply store the layer norm. + if op_name.endswith("layer_norm") and weight_or_bias == 'weight': + if op_name == "self_attention.linear_qkv.layer_norm": + ln_name = "input_layernorm" + elif op_name == "mlp.linear_fc1.layer_norm": + ln_name = "post_attention_layernorm" + else: + assert False, f"Unrecognized op_name {op_name} for layer norm" + output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params + + # Transpose the QKV matrix. 
+ elif op_name == "self_attention.linear_qkv" and weight_or_bias == 'weight': + input_shape = params.size() + shape = (heads, hidden_size_per_head, 3) + input_shape[1:] + division = reduce(operator.mul, shape, 1) + num_elements = params.numel() + if num_elements != division: + # model with gqa dont need to fix qkv ordering. + output_state_dict[layer_name + ".self_attn.qkv_proj.weight"] = params + else: + out_val = fix_qwen_query_key_value_ordering( + params, checkpoint_version, 3, heads, hidden_size_per_head + ) + # Megatron stores (3*D) x D but transformers-GPT2 expects D x 3*D. + out_val = out_val.contiguous() + # Store. + output_state_dict[layer_name + ".self_attn.qkv_proj.weight"] = out_val + + # Transpose the bias. + elif op_name == "self_attention.linear_qkv" and weight_or_bias == "bias": + out_val = fix_qwen_query_key_value_ordering( + params, checkpoint_version, 3, heads, hidden_size_per_head + ) + # Store. No change of shape. + output_state_dict[layer_name + ".self_attn.qkv_proj.bias"] = out_val + + # Transpose the weights. + elif weight_or_bias == "weight": + out_name = mcore_to_transformers[op_name] + output_state_dict[layer_name + out_name + "weight"] = params + + # Copy the bias. + # Ignore them + elif weight_or_bias == "bias": + pass + + # Copy the Rotary Embedding + else: + out_name = mcore_to_transformers[op_name] + output_state_dict[layer_name + out_name] = params + + if hf_config.num_hidden_layers != (layer_idx + 1): + raise ValueError(f"Expected {hf_config.num_hidden_layers} layers but found {layer_idx + 1}") + + # The final layernorm. + print("Converting final layernorm") + final_norm_weight = tp_state_dicts[0]['model'].get("decoder.final_layernorm.weight", None) + output_state_dict["model.model.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) + + # For LM head, transformers' wants the matrix to weight embeddings. + print("Converting LM head") + params = tp_state_dicts[tp_rank]['model'].get('output_layer.weight', None) + output_state_dict["model.lm_head.weight"] = params.to(hf_config.torch_dtype) + + # It should be done! + print("Conversion from Megatron-Core to Transformers is done!") + + return output_state_dict + def convert_qwen_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version=QwenVersion.v_1.value): # The converted output model. @@ -843,7 +1118,7 @@ def convert_qwen_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version= n_positions = hf_config.max_position_embeddings # The regex to extract layer names. - layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") # pylint: disable=anomalous-backslash-in-string + layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)") # Extract the layers. 
gate_up_proj = {} @@ -1061,8 +1336,14 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri if not os.path.exists(save_dir): if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): model_type = "GPT" - cmd = f"python {script_path} --model-type {model_type} --load-dir {load_dir} " + \ - f"--save-dir {save_dir} --target-tensor-parallel-size {target_tp} --target-pipeline-parallel-size {target_pp}" + if args.get("use_legacy_models"): + cmd = f"python {script_path} --model-type {model_type} --load-dir {args.get('load')} " + \ + f"--save-dir {save_dir} --target-tensor-parallel-size {target_tp} " + \ + f"--target-pipeline-parallel-size {target_pp}" + else: + cmd = f"python {script_path} --model-type {model_type} --loader mcore --load-dir {args.get('load')} " + \ + f"--saver mcore --save-dir {save_dir} --target-tensor-parallel-size {target_tp} " + \ + f"--target-pipeline-parallel-size {target_pp}" subprocess.run(cmd, shell=True, check=True) torch.distributed.barrier() args[load_arg] = save_dir diff --git a/examples/megatron/configs/gpt/base.yaml b/examples/megatron/configs/gpt/base.yaml index ab94a153..1a1de235 100644 --- a/examples/megatron/configs/gpt/base.yaml +++ b/examples/megatron/configs/gpt/base.yaml @@ -48,4 +48,5 @@ adaptive_parallel_strategy_on_checkpoint: True trainer_engine: ${trainer_engine:rlhf} attention_softmax_in_fp32: True -transformer_impl: local +transformer_impl: ${transformer_impl:local} +use_legacy_models: ${use_legacy_models:True} diff --git a/examples/megatron/configs/llama2/base.yaml b/examples/megatron/configs/llama2/base.yaml index 176492e0..cc598bbe 100644 --- a/examples/megatron/configs/llama2/base.yaml +++ b/examples/megatron/configs/llama2/base.yaml @@ -12,7 +12,7 @@ group_query_attention: ${group_query_attention:False} add_bias_linear: False swiglu: True attention_softmax_in_fp32: True -transformer_impl: local +transformer_impl: ${transformer_impl:local} bf16: True @@ -70,3 +70,4 @@ adaptive_parallel_strategy_on_checkpoint: True log_interval: ${log_interval:10} distributed_timeout_minutes: 30 make_vocab_size_divisible_by: 32 +use_legacy_models: ${use_legacy_models:True} diff --git a/examples/megatron/entry/train_reward.py b/examples/megatron/entry/train_reward.py index 9809a9c6..af8fb228 100644 --- a/examples/megatron/entry/train_reward.py +++ b/examples/megatron/entry/train_reward.py @@ -30,7 +30,32 @@ from megatron.training.utils import get_ltor_masks_and_position_ids from examples.megatron.data.reward_dataset import build_train_valid_test_datasets_for_rm -from examples.megatron.models.reward_model import model_provider +from examples.megatron.models.reward_model import RewardModel as LegacyRewardModel +from examples.megatron.models.mcore_reward_model import MCoreRewardModel + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building GPT model ...') + args = get_args() + if args.use_legacy_models: + model = LegacyRewardModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + score_dimension=1, + ) + else: + model = MCoreRewardModel( + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + score_dimension=1, + ) + + return model def get_tensor_shapes_reward( # pylint: disable=unused-argument diff --git a/examples/megatron/entry/train_sft.py b/examples/megatron/entry/train_sft.py index d8c653ce..f6e74670 100644 --- a/examples/megatron/entry/train_sft.py +++ 
b/examples/megatron/entry/train_sft.py @@ -23,10 +23,16 @@ from megatron.training import print_rank_0 from megatron.core import tensor_parallel from megatron.core.enums import ModelType -from megatron.legacy.model import GPTModel +from megatron.legacy.model import GPTModel as LegacyGPTModel +from megatron.core.models.gpt import GPTModel as MCoreGPTModel from megatron.training import pretrain from megatron.training.utils import average_losses_across_data_parallel_group from megatron.training.utils import get_ltor_masks_and_position_ids +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) +from megatron.core.transformer.spec_utils import import_module from examples.megatron.data.sft_dataset import build_train_valid_test_datasets from examples.megatron.models.utils import has_config_in_args @@ -36,23 +42,54 @@ def model_provider(pre_process=True, post_process=True): """Build the model.""" print_rank_0('building GPT model ...') - if has_config_in_args(GPTModel): + args = get_args() + + if args.use_legacy_models: + if has_config_in_args(LegacyGPTModel): + from megatron.training.arguments import core_transformer_config_from_args # pylint: disable=import-outside-toplevel + + config = core_transformer_config_from_args(args) + model = LegacyGPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + else: + model = LegacyGPTModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process + ) + else: + # MCore models consistantly have `config` from megatron.training.arguments import core_transformer_config_from_args # pylint: disable=import-outside-toplevel - args = get_args() + config = core_transformer_config_from_args(args) - model = GPTModel( - config, - num_tokentypes=0, - parallel_output=True, + + if args.spec is not None: + transformer_layer_spec = import_module(args.spec) + else: + if args.transformer_impl == "transformer_engine": + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm) + else: + transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm) + + model = MCoreGPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, pre_process=pre_process, - post_process=post_process - ) - else: - model = GPTModel( - num_tokentypes=0, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, parallel_output=True, - pre_process=pre_process, - post_process=post_process + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base ) return model diff --git a/examples/megatron/models/base_trainer.py b/examples/megatron/models/base_trainer.py index 7dc10c49..427b097e 100644 --- a/examples/megatron/models/base_trainer.py +++ b/examples/megatron/models/base_trainer.py @@ -19,7 +19,10 @@ from megatron.training import get_args from megatron.training import get_timers from megatron.training import get_tokenizer -from megatron.training import get_num_microbatches +try: + from megatron.training import get_num_microbatches +except ImportError: + from megatron.core.num_microbatches_calculator import get_num_microbatches from 
megatron.training import print_rank_0 from megatron.core.enums import ModelType try: diff --git a/examples/megatron/models/mcore_policy_model.py b/examples/megatron/models/mcore_policy_model.py new file mode 100644 index 00000000..538ebbf3 --- /dev/null +++ b/examples/megatron/models/mcore_policy_model.py @@ -0,0 +1,264 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""policy model""" + +import torch + +from megatron.training import get_args +from megatron.core import tensor_parallel +from megatron.training.global_vars import get_tokenizer +from megatron.core.models.gpt import GPTModel as MCoreGPTModel +from megatron.legacy.model.language_model import parallel_lm_logits +from megatron.training.arguments import core_transformer_config_from_args +from megatron.core.transformer.spec_utils import import_module +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) + +from chatlearn.models.megatron.ops.policy_gradient import tensor_decomp_pg_loss +from .utils import get_advantages_and_returns, get_eos_id +from .constants import TrainerEngine +from .constants import select_actions_from_right_padded + + +class MCorePolicyModel(MCoreGPTModel): + """PolicyModel for MCore""" + + def __init__(self, + parallel_output=True, + pre_process=True, + post_process=True, + stats=None): + self.args = get_args() + use_te = self.args.transformer_impl == "transformer_engine" + config = core_transformer_config_from_args(self.args) + + if self.args.spec is not None: + transformer_layer_spec = import_module(self.args.spec) + else: + if use_te: + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + self.args.num_experts, + self.args.moe_grouped_gemm, + self.args.qk_layernorm + ) + else: + transformer_layer_spec = get_gpt_layer_local_spec( + self.args.num_experts, + self.args.moe_grouped_gemm, + self.args.qk_layernorm + ) + + super().__init__( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=self.args.padded_vocab_size, + max_sequence_length=self.args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=self.args.fp16_lm_cross_entropy, + parallel_output=parallel_output, + share_embeddings_and_output_weights=not self.args.untie_embeddings_and_output_weights, + position_embedding_type=self.args.position_embedding_type, + rotary_percent=self.args.rotary_percent, + rotary_base=self.args.rotary_base + ) + + self.tokenizer = get_tokenizer() + self.stats = stats + + def forward_lm(self, input_ids, position_ids, attention_mask, inference_params=None): + if self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + else: + # intermediate stage of pipeline + decoder_input = None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + 
rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # Run decoder. + lm_output = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) + + return lm_output + + def forward(self, all_token_ids, all_position_ids, all_token_attention_mask, training_inputs=None, + inference_params=None, inference_config=None): + hiddens = self.forward_lm(all_token_ids, all_position_ids, + all_token_attention_mask, inference_params=inference_params) # [b, s, v] + # note in middle pipeline, this all_token_logits is just a hidden + if self.post_process: + # is last pipeline stage, if inference return the last logits. if training, return the loss + use_parallel_output = inference_config["parallel_output"] if inference_config is not None and \ + "parallel_output" in inference_config else self.parallel_output + + if inference_config is not None and "DPO_labels" in inference_config: + assert get_args().trainer_engine in [TrainerEngine.DPO, TrainerEngine.ONLINE_DPO] + if training_inputs is None: + training_inputs = {} + training_inputs["labels"] = inference_config["DPO_labels"] + if get_args().trainer_engine == TrainerEngine.DPO: + assert "prompt_id_lens" in inference_config + assert "orig_mask" in inference_config + training_inputs["prompt_id_lens"] = inference_config["prompt_id_lens"] + all_token_attention_mask = inference_config["orig_mask"] + use_parallel_output = False + return self.post_language_model_processing( + hiddens, training_inputs, + self.shared_embedding_or_output_weight() if self.share_embeddings_and_output_weights else self.output_layer.weight, + use_parallel_output, + attention_mask=all_token_attention_mask) + else: + return hiddens + + def post_process_rlhf(self, training_inputs, all_token_logits): + old_logprobs = training_inputs['action_logprobs'] # [b, responses size] + old_values = training_inputs['action_values'] # [b, responses size] + old_rewards = training_inputs['action_rewards'] # [b, responses size] + response_length = old_rewards.shape[1] + + all_token_ids = training_inputs["all_token_ids_right_padded"] + + # For a proper positional encoding in case of left padding + advantages, returns = get_advantages_and_returns(self.args, + old_values, old_rewards, response_length + ) + assert advantages.size(1) == returns.size(1) == response_length + # start = query_tensors.shape[1] - 1 for left padded + # end = action_start + response_length + # Note the token logits to get loss is only the actions. query doesn't have loss. 
+ + # all_token_ids = [pad, q1, q2, q3, a1, a2, a3, pad, pad] + # [pad, q1, q2, q3, a1, a2, a3, a4, a5] + # start = 4-1 = 3, end = 3 + 5 = 8 + # action_loss_mask = notpad(q3, a1, a2, a3, pad,]), notpad([q3, a1, a2, a3, a4], ) + # action_token_logits = logits(q3, a1, a2, a3, pad), logits(q3, a1, a2, a3, a4) + # action_ids = [a1, a2, a3, pad, pad], [a1, a2, a3, a4, a5] + + action_loss_mask = select_actions_from_right_padded(ts=training_inputs["all_token_loss_mask"], + action_starts=training_inputs["action_starts"] - 1, + # because align iwth logits index + response_size=response_length, + pad_value=0, dim=-1).contiguous() + + # because we want the logits from the previous token + # because it's -1 at top and then action -1 it hsould remain in bound + action_token_logits = select_actions_from_right_padded(ts=all_token_logits[:, :-1, :], + action_starts=training_inputs["action_starts"] - 1, + response_size=response_length, + pad_value=1.0, dim=-2).contiguous() + action_ids = select_actions_from_right_padded(ts=all_token_ids, + action_starts=training_inputs["action_starts"], + response_size=response_length, + pad_value=get_eos_id(self.tokenizer), dim=-1).contiguous() + + loss = tensor_decomp_pg_loss(self.args, + action_token_logits=action_token_logits, # [b,response size] + action_ids=action_ids, # [b, response size] + action_loss_mask=action_loss_mask, # [b, response size] + old_logprobs=old_logprobs, # [b, response size] + advantages=advantages, # [b, response size] + stats=self.stats) # [b, response_size] remove last logit because it's EOS + + self.approx_kl = self.stats["policy/approx_kl"] # Update kl controller stats + return loss.contiguous() # [b,response_size] + + + def post_process_dpo(self, logits, training_inputs, attention_mask, average_log_prob=False): + assert "labels" in training_inputs and training_inputs['labels'] is not None + labels = training_inputs['labels'] + prompt_id_lens = training_inputs['prompt_id_lens'] + assert logits.shape[:-1] == labels.shape, \ + f"Mismatch tensor shape between logits.shape[:-1] ({logits.shape[:-1]}) and labels.shape ({labels.shape})" + loss_masks = attention_mask.clone().bool() + loss_masks = loss_masks.squeeze(1) + for mask, source_len in zip(loss_masks, prompt_id_lens): + mask[:source_len] = False + labels[loss_masks == False] = 0 # pylint: disable=singleton-comparison + + loss_masks = loss_masks[:, 1:] + logits = logits[:, 1:, :] + labels = labels[:, 1:] + + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + + if average_log_prob: + return (per_token_logps * loss_masks).sum(-1) / loss_masks.sum(-1) + else: + return (per_token_logps * loss_masks).sum(-1) + + def post_process_online_dpo(self, sbv_all_token_logits, training_inputs): + assert "labels" in training_inputs and training_inputs['labels'] is not None + CE_loss = self.cross_entropy_loss(sbv_all_token_logits, + training_inputs['labels'], + self.args.fp16_lm_cross_entropy) + return CE_loss + + def post_language_model_processing(self, hiddens, training_inputs, logit_weights, + parallel_output, attention_mask=None): + + # is last pipeline stage, if inference return the last logits. if training, return the loss + + inference_only = training_inputs is None + + # Output. 
Format [s b v] + all_token_logits = parallel_lm_logits( + hiddens, + logit_weights, + parallel_output) + + sbv_all_token_logits = all_token_logits + all_token_logits = all_token_logits.transpose(0, 1).contiguous() + + if inference_only: + # [s b h] => [b s h] + # TODO do we need to transpose???? + if self.args.trainer_engine == TrainerEngine.DPO: + return self.post_process_dpo(all_token_logits, training_inputs, attention_mask) + return all_token_logits + else: + if self.args.trainer_engine == TrainerEngine.DPO: + return self.post_process_dpo(all_token_logits, training_inputs, attention_mask) + elif self.args.trainer_engine == TrainerEngine.RLHF.value: + return self.post_process_rlhf(training_inputs, all_token_logits) + elif self.args.trainer_engine == TrainerEngine.ONLINE_DPO: + return self.post_process_online_dpo(sbv_all_token_logits, training_inputs) + + def cross_entropy_loss(self, sbv_all_token_logits, labels, fp16_lm_cross_entropy): + #all_token_logits is [s,b,vp] + labels = labels.transpose(0, 1).contiguous() #[s,b] + # if flash_cross_entropy is not None: + # loss = flash_cross_entropy(output.flatten(0, 1), labels.flatten()).view(*labels.size()) + + if fp16_lm_cross_entropy: + assert sbv_all_token_logits.dtype == sbv_all_token_logits.half + loss = tensor_parallel.vocab_parallel_cross_entropy(sbv_all_token_logits, labels) + else: + loss = tensor_parallel.vocab_parallel_cross_entropy(sbv_all_token_logits.float(), labels) + + # [s b] => [b, s] + loss = loss.transpose(0, 1).contiguous() + return loss diff --git a/examples/megatron/models/mcore_reward_model.py b/examples/megatron/models/mcore_reward_model.py new file mode 100644 index 00000000..f0a44bda --- /dev/null +++ b/examples/megatron/models/mcore_reward_model.py @@ -0,0 +1,220 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""reward model""" + +import torch +from megatron.training import get_args +from megatron.training import print_rank_0 +from megatron.training.arguments import core_transformer_config_from_args +from megatron.core import tensor_parallel +from megatron.core.models.gpt import GPTModel as MCoreGPTModel +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.utils import get_linear_layer +from megatron.core.transformer.spec_utils import import_module +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) + + +class LinearPooler(MegatronModule): + """Pooler layer. + + Pool hidden states of a specific token (for example start of the + sequence) and add a linear transformation followed by a tanh. + + Arguments: + hidden_size: hidden size + init_method: weight initialization method for the linear layer. + bias is set to zero. 
+ """ + + def __init__(self, config, score_dimensions): + super().__init__(config=config) + args = get_args() + hidden_size = config.hidden_size + init_method = config.init_method + self.dense1 = get_linear_layer(hidden_size, hidden_size, init_method, args.perform_initialization) + self.dense2 = get_linear_layer(hidden_size, score_dimensions, init_method, args.perform_initialization) + self.sequence_parallel = args.sequence_parallel + + def forward(self, hidden_states, sequence_indices=None): + # hidden_states: [s, b, h] + # sequence_index: index of the token to pool. + + # gather data along sequence dimensions + # same pooler is run on all tensor parallel nodes + if self.sequence_parallel: + hidden_states = tensor_parallel.gather_from_sequence_parallel_region( + hidden_states, # [s, b, h] + tensor_parallel_output_grad=False) + + if sequence_indices is not None: + selected_hidden = torch.index_select(hidden_states, 0, sequence_indices) + selected_hidden = selected_hidden.diagonal(dim1=0, dim2=1).T + pooled = self.dense2(torch.nn.functional.relu(self.dense1(selected_hidden))) + else: + selected_hidden = hidden_states # [s, b, h] + pooled = self.dense2(torch.nn.functional.relu(self.dense1(selected_hidden))).squeeze(2) # [s, b, scoredim] + + return pooled + + +class MCoreRewardModel(MCoreGPTModel): + """RewardModel for MCore""" + + def __init__(self, + parallel_output=True, + pre_process=True, + post_process=True, + pooler_head=LinearPooler, + score_dimension=1): + self.args = get_args() + use_te = self.args.transformer_impl == "transformer_engine" + self.config = core_transformer_config_from_args(self.args) + + if self.args.spec is not None: + transformer_layer_spec = import_module(self.args.spec) + else: + if use_te: + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + self.args.num_experts, + self.args.moe_grouped_gemm, + self.args.qk_layernorm + ) + else: + transformer_layer_spec = get_gpt_layer_local_spec( + self.args.num_experts, + self.args.moe_grouped_gemm, + self.args.qk_layernorm + ) + + super().__init__( + config=self.config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=self.args.padded_vocab_size, + max_sequence_length=self.args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=self.args.fp16_lm_cross_entropy, + parallel_output=parallel_output, + share_embeddings_and_output_weights=not self.args.untie_embeddings_and_output_weights, + position_embedding_type=self.args.position_embedding_type, + rotary_percent=self.args.rotary_percent, + rotary_base=self.args.rotary_base + ) + + # Output + if post_process: + self.pooler_head = pooler_head(self.config, + score_dimensions=score_dimension) + self._pooler_head_key = 'pooler_head' + else: + self._pooler_head_key = None + + + def _language_model_forward(self, input_ids=None, position_ids=None, attention_mask=None, + inference_params=None): + if self.pre_process: + decoder_input = self.embedding(input_ids=input_ids, position_ids=position_ids) + else: + # TODO: CHECK! + # intermediate stage of pipeline + # decoder will get hidden_states from encoder.input_tensor? + decoder_input = None + + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + if self.position_embedding_type == 'rope': + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_params, self.decoder, decoder_input, self.config + ) + rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + + # Run decoder. 
+ lm_output = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_params=inference_params, + rotary_pos_emb=rotary_pos_emb, + ) + return lm_output + + def forward(self, input_ids=None, position_ids=None, attention_mask=None, + labels=None, inference_params=None, + pooling_sequence_index=None, + inference_config=None): + lm_output = self._language_model_forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inference_params=inference_params + ) + + if self.post_process: + if inference_config is not None and "batch_encode" in inference_config: + print('GPTrewrad model batch encoding, give the transformers encodings') + if get_args().sequence_parallel: + lm_output = tensor_parallel.gather_from_sequence_parallel_region( + lm_output, # [s, b, h] + tensor_parallel_output_grad=False) + return lm_output + assert labels is None, "assume labels is None in reward model" + return self.pooler_head(lm_output, pooling_sequence_index) + # [b x score_dim] + return lm_output + + def load_state_dict(self, state_dict, strict=True):# pylint: disable=unused-argument + """Customized load.""" + # Directly utilize super().load_state_dict(state_dict, strict=True) causes exceptions. This is + # because if the base torch.nn.Module method load_state_dict is invoked, a strict key + # matching mechanism will be enforced across all parameters. It will become + # problematic when training a reward model derived from a SFT checkpoint, as + # these checkpoints typically lack the state_dict for pooler_head in the reward model. + incompatible_keys = super().load_state_dict(state_dict, strict=False) + if len(incompatible_keys.missing_keys) == 0 and len(incompatible_keys.unexpected_keys) == 0: + print_rank_0("load reward model pooler_head success") + return + elif self.post_process: + if all(missing_key.startswith(self._pooler_head_key) for missing_key in incompatible_keys.missing_keys): + print_rank_0("cannot load reward model pooler_head, init from random") + if all(unexpected_key.startswith("output_layer") for unexpected_key in incompatible_keys.unexpected_keys): + print_rank_0("neglect output_layer weight for reward model") + return + else: + error_msgs: List[str] = [] + if len(incompatible_keys.unexpected_keys) > 0: + error_msgs.insert( + 0, 'Unexpected key(s) in state_dict: {}. '.format( + ', '.join(f'"{k}"' for k in incompatible_keys.unexpected_keys))) + if len(incompatible_keys.missing_keys) > 0: + error_msgs.insert( + 0, 'Missing key(s) in state_dict: {}. '.format( + ', '.join(f'"{k}"' for k in incompatible_keys.missing_keys))) + raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( + self.__class__.__name__, "\n\t".join(error_msgs))) + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building GPT model ...') + model = MCoreRewardModel( + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + score_dimension=1, + ) + return model diff --git a/examples/megatron/models/mcore_value_model.py b/examples/megatron/models/mcore_value_model.py new file mode 100644 index 00000000..6669ced0 --- /dev/null +++ b/examples/megatron/models/mcore_value_model.py @@ -0,0 +1,154 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""value model""" + +import torch +import torch.distributed as dist +from megatron.core import mpu +from megatron.training.global_vars import get_tokenizer + +from .utils import get_advantages_and_returns +from .constants import select_actions_from_right_padded +from .mcore_reward_model import MCoreRewardModel + + +class MCoreValueModel(MCoreRewardModel): + """ValueModel for MCore""" + + def __init__(self, + parallel_output=True, + pre_process=True, + post_process=True, + stats=None, + buffer=None): + + super().__init__(parallel_output, pre_process, post_process) + self.tokenizer = get_tokenizer() + self.stats = stats + self.buffer = buffer + + # pylint: disable=arguments-differ + def forward(self, all_token_ids, all_position_ids, all_token_attention_mask, training_inputs=None, + inference_params=None, inference_config=None): + lm_output = self._language_model_forward( + input_ids=all_token_ids, + position_ids=all_position_ids, + attention_mask=all_token_attention_mask, + inference_params=inference_params + ) + + # note in middle pipeline, this all_token_logits is just a hidden + if self.post_process: + values = self.pooler_head(lm_output) + # is last pipeline stage, if inference return the last logits. if training, return the loss + return self.post_language_model_processing(training_inputs, values) + else: + return lm_output + + def post_language_model_processing(self, training_inputs, values): + + # is last pipeline stage, if inference return the last logits. 
if training, return the loss + + inference_only = training_inputs is None + + if inference_only: + # [s b] => [b s ] + return values.transpose(0, 1).contiguous() # [b,responses_szie] + else: + values_pred = values.transpose(0, 1).contiguous() # [b, all_token size] + + old_values = training_inputs['action_values'] + old_rewards = training_inputs['action_rewards'] + response_length = old_rewards.shape[1] + all_token_loss_mask = training_inputs["all_token_loss_mask"] + # For a proper positional encoding in case of left padding + advantages, returns = get_advantages_and_returns(self.args, + old_values, old_rewards, response_length) + + advantages_nonzero_for_log = advantages.view(-1)[advantages.view(-1).nonzero()] + + if self.args.log_interval > 0: + self.stats["value/advantages_mean"] = advantages_nonzero_for_log.mean() + self.stats["value/advantages_min"] = advantages_nonzero_for_log.min() + self.stats["value/advantages_max"] = advantages_nonzero_for_log.max() + self.stats["value/advantages_std"] = advantages_nonzero_for_log.std() + + returns_nonzero_for_log = returns.view(-1)[returns.view(-1).nonzero()] + + self.stats["value/returns_mean"] = returns_nonzero_for_log.mean() + self.stats["value/returns_min"] = returns_nonzero_for_log.min() + self.stats["value/returns_max"] = returns_nonzero_for_log.max() + + values_pred = values_pred[:, :-1] # remove the last token since we want the head(hidden(state)) + + values_pred = select_actions_from_right_padded(ts=values_pred, + action_starts=training_inputs["action_starts"] - 1, + response_size=response_length, + pad_value=0.0, dim=-1).contiguous() + mask = select_actions_from_right_padded(ts=all_token_loss_mask, + action_starts=training_inputs["action_starts"] - 1, + response_size=response_length, + pad_value=0, dim=-1).contiguous() + + values_clipped = torch.clamp( + values_pred, + old_values - self.args.cliprange_value, + old_values + self.args.cliprange_value, + ) + n = mask.sum() + + vf_loss1 = (values_pred - returns) ** 2 + vf_loss2 = (values_clipped - returns) ** 2 + + if self.args.clipped_value_only: + vf_loss = (0.5 * vf_loss2) * mask + vf_clipfrac = torch.sum((values_pred != values_clipped).float() * mask) / n + + else: + vf_loss = (0.5 * torch.max(vf_loss1, vf_loss2)) * mask + vf_clipfrac = torch.sum((vf_loss2 > vf_loss1).float() * mask) / n + + if self.args.log_interval > 0: + self.stats["value_model/vf_clipfrac"] = vf_clipfrac + + # forexplained var + gathered_returns = [ + torch.zeros(returns.size(0), self.args.max_position_embeddings).to(torch.cuda.current_device()) for + _ in + range(torch.distributed.get_world_size(group=mpu.get_data_parallel_group()))] + padded_returns = torch.zeros(returns.size(0), self.args.max_position_embeddings).to( + torch.cuda.current_device()) + padded_returns[:, :returns.size(1)] = returns + + dist.all_gather(gathered_returns, padded_returns, group=mpu.get_data_parallel_group()) + + gathered_value_preds = [ + torch.zeros(returns.size(0), self.args.max_position_embeddings).to(torch.cuda.current_device()) for + _ in + range(torch.distributed.get_world_size(group=mpu.get_data_parallel_group()))] + padded_value_pred = torch.zeros(returns.size(0), self.args.max_position_embeddings).to( + torch.cuda.current_device()) + padded_value_pred[:, :values_pred.size(1)] = values_pred + + dist.all_gather(gathered_value_preds, padded_value_pred, group=mpu.get_data_parallel_group()) + + # RL related stats: global + if torch.distributed.is_initialized(): + if torch.distributed.get_rank() == ( + 
torch.distributed.get_world_size() - 1): + self.buffer["value/returns"].extend(gathered_returns) + self.buffer["value/value_preds"].extend(gathered_value_preds) + + return vf_loss.contiguous() # [b,response_size] diff --git a/examples/megatron/models/old_policy_inference.py b/examples/megatron/models/old_policy_inference.py index f105fb33..5098137a 100644 --- a/examples/megatron/models/old_policy_inference.py +++ b/examples/megatron/models/old_policy_inference.py @@ -34,7 +34,8 @@ from chatlearn.utils import to_device from chatlearn.utils.megatron_utils import load_checkpoint from examples.megatron.data.prompt_dataset import PromptPipeline -from .policy_model import PolicyModel +from .policy_model import PolicyModel as LegacyPolicyModel +from .mcore_policy_model import MCorePolicyModel from .utils import tensorboard_scalar_dict, get_loss_mask, get_eos_id @@ -99,7 +100,19 @@ def model_provider(self, pre_process=True, post_process=True): """Build the model.""" print_rank_0('building GPT model ...') - model = PolicyModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process) + if self.args.use_legacy_models: + model = LegacyPolicyModel( + num_tokentypes=0, + parallel_output=False, + pre_process=pre_process, + post_process=post_process + ) + else: + model = MCorePolicyModel( + parallel_output=False, + pre_process=pre_process, + post_process=post_process + ) return model diff --git a/examples/megatron/models/old_value_inference.py b/examples/megatron/models/old_value_inference.py index acfd75d6..160c9148 100644 --- a/examples/megatron/models/old_value_inference.py +++ b/examples/megatron/models/old_value_inference.py @@ -23,6 +23,7 @@ from chatlearn import MegatronModule from chatlearn.utils import to_device from .value_model import ValueModel +from .mcore_value_model import MCoreValueModel from .constants import get_ltor_masks_and_position_ids_rlhf from .forward_step import forward_step_helper @@ -46,8 +47,12 @@ def model_provider(self, pre_process=True, post_process=True): """Build the model.""" print_rank_0('building GPT model ...') - model = ValueModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process, - stats=self.stats, buffer=self.buffer) + if self.args.use_legacy_models: + model = ValueModel(num_tokentypes=0, parallel_output=False, pre_process=pre_process, post_process=post_process, + stats=self.stats, buffer=self.buffer) + else: + model = MCoreValueModel(parallel_output=False, pre_process=pre_process, post_process=post_process, + stats=self.stats, buffer=self.buffer) return model diff --git a/examples/megatron/models/policy_model.py b/examples/megatron/models/policy_model.py index aae8da5a..5bdc622d 100644 --- a/examples/megatron/models/policy_model.py +++ b/examples/megatron/models/policy_model.py @@ -68,7 +68,7 @@ def forward(self, all_token_ids, all_position_ids, all_token_attention_mask, tra "parallel_output" in inference_config else self.parallel_output if inference_config is not None and "DPO_labels" in inference_config: - assert get_args().trainer_engine in [TrainerEngine.DPO.value, TrainerEngine.ONLINE_DPO.value] + assert get_args().trainer_engine in [TrainerEngine.DPO, TrainerEngine.ONLINE_DPO] if training_inputs is None: training_inputs = {} training_inputs["labels"] = inference_config["DPO_labels"] diff --git a/examples/megatron/models/policy_trainer.py b/examples/megatron/models/policy_trainer.py index d5a97e69..872b56c4 100644 --- a/examples/megatron/models/policy_trainer.py +++ 
b/examples/megatron/models/policy_trainer.py @@ -21,7 +21,10 @@ import torch.nn.functional as F from megatron.training import get_args -from megatron.training import get_num_microbatches +try: + from megatron.training import get_num_microbatches +except ImportError: + from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.training import get_tokenizer from megatron.training import print_rank_0 from megatron.core import mpu @@ -30,7 +33,8 @@ from megatron.training.utils import get_ltor_masks_and_position_ids from chatlearn.utils import to_device -from .policy_model import PolicyModel +from .policy_model import PolicyModel as LegacyPolicyModel +from .mcore_policy_model import MCorePolicyModel from .utils import training_log, get_eos_id, get_padding_length, pad_to_length from .base_trainer import BaseTrainer from .constants import TrainerEngine @@ -85,13 +89,21 @@ def model_provider(self, pre_process=True, post_process=True): """Build the model.""" print_rank_0('building GPT model ...') - model = PolicyModel( - num_tokentypes=0, - parallel_output=True, - pre_process=pre_process, - post_process=post_process, - stats=self.stats - ) + if self.args.use_legacy_models: + model = LegacyPolicyModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + stats=self.stats + ) + else: + model = MCorePolicyModel( + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + stats=self.stats + ) if self.module_args.lora.enable_lora: from chatlearn.models.megatron.lora import convert_layer_to_lora # pylint: disable=import-outside-toplevel model = convert_layer_to_lora(model) diff --git a/examples/megatron/models/reference.py b/examples/megatron/models/reference.py index 59974c91..1f112f28 100644 --- a/examples/megatron/models/reference.py +++ b/examples/megatron/models/reference.py @@ -26,6 +26,7 @@ from chatlearn.utils.megatron_utils import load_checkpoint from examples.megatron.data.prompt_dataset import DPOPromptPipeline from .policy_model import PolicyModel +from .mcore_policy_model import MCorePolicyModel from .utils import get_eos_id, get_padding_length, pad_to_length from .constants import get_ltor_masks_and_position_ids_rlhf from .constants import TrainerEngine @@ -47,8 +48,12 @@ def model_provider(self, pre_process=True, post_process=True): print_rank_0("enable parallel output") else: self._parallel_output = False - model = PolicyModel(num_tokentypes=0, parallel_output=self._parallel_output, pre_process=pre_process, - post_process=post_process) + if args.use_legacy_models: + model = PolicyModel(num_tokentypes=0, parallel_output=self._parallel_output, pre_process=pre_process, + post_process=post_process) + else: + model = MCorePolicyModel(parallel_output=self._parallel_output, pre_process=pre_process, + post_process=post_process) return model diff --git a/examples/megatron/models/reward_inference.py b/examples/megatron/models/reward_inference.py index b672bbfa..ffb0ec4c 100644 --- a/examples/megatron/models/reward_inference.py +++ b/examples/megatron/models/reward_inference.py @@ -33,7 +33,8 @@ from chatlearn import MegatronModule from chatlearn.utils import to_device from chatlearn.utils.megatron_utils import load_checkpoint -from .reward_model import batch_padded_tokenize_data, model_provider +from .reward_model import RewardModel as LegacyRewardModel +from .mcore_reward_model import MCoreRewardModel from .utils import tensorboard_scalar_dict, get_eos_id from .constants import RunningMoments, 
get_running_stats, reset_running_stats from .forward_step import forward_step_helper @@ -60,7 +61,45 @@ def save_list_str(list_strs, iteration): for prompt, response in list_strs: k = {"query": prompt, "responses": [response], "iteration": iteration} res.append(k) - dump_jsonl_chinese(res, inference_output_path, mode="a") + dump_jsonl_chinese(res, inference_output_path, mode="w") + + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building GPT model ...') + args = get_args() + if args.use_legacy_models: + model = LegacyRewardModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + score_dimension=1, + ) + else: + model = MCoreRewardModel( + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + score_dimension=1, + ) + + return model + + +def batch_padded_tokenize_data(list_strs, tokenizer, max_length): + processed_dict = [preprocess(tokenizer.tokenize(line[0]), tokenizer.tokenize(line[1]), max_length, tokenizer) for + line in list_strs] + input_ids, input_lengths = [], [] + for item in processed_dict: + input_ids.append(torch.tensor(item['ids'])) + input_lengths.append(item['length']) + max_l = min(max(input_lengths), max_length) + input_ids = torch.stack(input_ids, dim=0)[:, :max_l] + input_eos_tok = torch.tensor(input_lengths) - 1 + + return input_ids, input_eos_tok class RewardInference(MegatronModule): diff --git a/examples/megatron/models/reward_model.py b/examples/megatron/models/reward_model.py index a88fd3a2..36dff59d 100644 --- a/examples/megatron/models/reward_model.py +++ b/examples/megatron/models/reward_model.py @@ -22,24 +22,9 @@ from megatron.legacy.model.module import MegatronModule from megatron.legacy.model.utils import get_linear_layer -from examples.megatron.data.reward_dataset import preprocess from .utils import has_config_in_args -def batch_padded_tokenize_data(list_strs, tokenizer, max_length): - processed_dict = [preprocess(tokenizer.tokenize(line[0]), tokenizer.tokenize(line[1]), max_length, tokenizer) for - line in list_strs] - input_ids, input_lengths = [], [] - for item in processed_dict: - input_ids.append(torch.tensor(item['ids'])) - input_lengths.append(item['length']) - max_l = min(max(input_lengths), max_length) - input_ids = torch.stack(input_ids, dim=0)[:, :max_l] - input_eos_tok = torch.tensor(input_lengths) - 1 - - return input_ids, input_eos_tok - - class LinearPooler(MegatronModule): """Pooler layer. 
diff --git a/examples/megatron/models/utils.py b/examples/megatron/models/utils.py index 6c9e73d4..fda8f92d 100644 --- a/examples/megatron/models/utils.py +++ b/examples/megatron/models/utils.py @@ -29,8 +29,12 @@ import torch import torch.distributed as dist import torch.nn.functional as F -from megatron.training import print_rank_last, is_last_rank, get_num_microbatches, get_args, get_timers +from megatron.training import print_rank_last, is_last_rank, get_args, get_timers from megatron.core import mpu +try: + from megatron.training import get_num_microbatches +except ImportError: + from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.training.global_vars import get_tensorboard_writer from megatron.training.training import print_datetime from torchtyping import TensorType diff --git a/examples/megatron/models/value_trainer.py b/examples/megatron/models/value_trainer.py index fbe73d6f..999b9509 100644 --- a/examples/megatron/models/value_trainer.py +++ b/examples/megatron/models/value_trainer.py @@ -18,7 +18,10 @@ import torch from megatron.core import mpu -from megatron.training import get_num_microbatches +try: + from megatron.training import get_num_microbatches +except ImportError: + from megatron.core.num_microbatches_calculator import get_num_microbatches from megatron.training import get_timers from megatron.training import get_tokenizer from megatron.training import print_rank_0 @@ -27,7 +30,8 @@ from megatron.training.utils import calc_params_l2_norm from chatlearn.utils import to_device -from .value_model import ValueModel +from .value_model import ValueModel as LegacyValueModel +from .mcore_value_model import MCoreValueModel from .utils import tensorboard_scalar_dict, training_log, get_eos_id from .base_trainer import BaseTrainer from .constants import get_ltor_masks_and_position_ids_rlhf, select_actions_from_right_padded, pad_to_max_len @@ -40,14 +44,23 @@ def model_provider(self, pre_process=True, post_process=True): """Build the model.""" print_rank_0('building GPT model ...') - model = ValueModel( - num_tokentypes=0, - parallel_output=True, - pre_process=pre_process, - post_process=post_process, - stats=self.stats, - buffer=self.buffer - ) + if self.args.use_legacy_models: + model = LegacyValueModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + stats=self.stats, + buffer=self.buffer + ) + else: + model = MCoreValueModel( + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + stats=self.stats, + buffer=self.buffer + ) if self.module_args.lora.enable_lora: from chatlearn.models.megatron.lora import convert_layer_to_lora # pylint: disable=import-outside-toplevel model = convert_layer_to_lora(model) diff --git a/examples/megatron/scripts/base_env.sh b/examples/megatron/scripts/base_env.sh index c1f1d07a..aa1bccce 100644 --- a/examples/megatron/scripts/base_env.sh +++ b/examples/megatron/scripts/base_env.sh @@ -134,3 +134,11 @@ else echo "unsupported model_size ${model_size}, please set your own model config" exit 1 fi + +if [[ ${USE_LEGACY_MODELS:-"True"} = "False" ]]; then + export transformer_impl=transformer_engine + export use_legacy_models=False +else + export transformer_impl=local + export use_legacy_models=True +fi diff --git a/examples/megatron/scripts/convert_hf_to_megatron.sh b/examples/megatron/scripts/convert_hf_to_megatron.sh index f02bacfd..4110b309 100644 --- a/examples/megatron/scripts/convert_hf_to_megatron.sh +++ 
b/examples/megatron/scripts/convert_hf_to_megatron.sh @@ -5,7 +5,9 @@ set -x # model config # can be `gpt_llama' for GPT or Llama, or `mixtral' for Mixtral -model=${MODEL:-'gpt_llama'} +model=${MODEL:-'gpt_llama'} +# Whether to use legacy models, default: True +use_legacy_models=${USE_LEGACY_MODELS:-"True"} # parallel config tp=${TP:-1} @@ -28,12 +30,17 @@ START_TIME=$SECONDS if [[ ${model} == 'gpt_llama' ]]; then cd ${megatron} + if [[ ${use_legacy_models} = "False" ]]; then + saver="mcore" + else + saver="megatron" + fi python tools/checkpoint/convert.py \ --model-type GPT \ --loader llama_mistral \ --checkpoint-type hf \ --model-size ${model_size} \ - --saver megatron \ + --saver ${saver} \ --target-tensor-parallel-size ${tp} \ --target-pipeline-parallel-size ${pp} \ --load-dir ${load_dir} \ diff --git a/examples/megatron/scripts/convert_megatron_to_hf.sh b/examples/megatron/scripts/convert_megatron_to_hf.sh index a95326dc..e18c8cd4 100644 --- a/examples/megatron/scripts/convert_megatron_to_hf.sh +++ b/examples/megatron/scripts/convert_megatron_to_hf.sh @@ -1,6 +1,6 @@ #!/bin/bash # Convert LLaMA model from megatron format to huggingface format. -set -x +set -ex # config chatlearn=${CHATLEARN} @@ -11,21 +11,33 @@ vocab_path=${VOCAB_PATH} target_params_dtype=${target_params_dtype:-bf16} temp_path=${save_path}/temp +# Whether to use legacy models, default: True +use_legacy_models=${USE_LEGACY_MODELS:-"True"} + +if [[ ${use_legacy_models} = "False" ]]; then + ckpt_format="mcore" +else + ckpt_format="megatron" +fi + set +x # convert parallel strategy START_TIME=$SECONDS cd ${megatron} -python tools/checkpoint/convert.py \ - --model-type GPT \ - --loader megatron \ - --saver megatron \ - --target-tensor-parallel-size 1 \ - --target-pipeline-parallel-size 1 \ - --load-dir ${load_path} \ - --save-dir ${temp_path} \ - --megatron-path ${megatron} + +if [[ ! 
-d "${temp_path}" ]]; then + python tools/checkpoint/convert.py \ + --model-type GPT \ + --loader ${ckpt_format} \ + --saver "megatron" \ + --target-tensor-parallel-size 1 \ + --target-pipeline-parallel-size 1 \ + --load-dir ${load_path} \ + --save-dir ${temp_path} \ + --megatron-path ${megatron} +fi # convert to hf format cd ${chatlearn} diff --git a/examples/megatron/scripts/train_reward_llama.sh b/examples/megatron/scripts/train_reward_llama.sh index bf3fb8e4..877a9934 100644 --- a/examples/megatron/scripts/train_reward_llama.sh +++ b/examples/megatron/scripts/train_reward_llama.sh @@ -82,9 +82,18 @@ MODEL_ARGS=" --normalization RMSNorm \ --no-position-embedding \ --no-masked-softmax-fusion \ ---transformer-impl local \ --attention-softmax-in-fp32 " +use_legacy_models=${USE_LEGACY_MODELS:-"True"} + +if [[ ${use_legacy_models} = "False" ]]; then + MCORE_ARGS="--transformer-impl transformer_engine " +else + MCORE_ARGS=" + --use-legacy-models \ + --transformer-impl local " +fi + mkdir -p $CHECKPOINT_PATH log_file=$CHECKPOINT_PATH/stderr_$NODE_RANK.log @@ -138,4 +147,4 @@ torchrun $DISTRIBUTED_ARGS \ --use-flash-attn \ --sequence-parallel \ --finetune \ - $MODEL_ARGS 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} + $MODEL_ARGS $MCORE_ARGS 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} diff --git a/examples/megatron/scripts/train_rlhf_llama.sh b/examples/megatron/scripts/train_rlhf_llama.sh index 89eebc9f..a6f417a4 100644 --- a/examples/megatron/scripts/train_rlhf_llama.sh +++ b/examples/megatron/scripts/train_rlhf_llama.sh @@ -34,6 +34,7 @@ export trainer_engine=rlhf [ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} [ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ [ -z "$sample_per_episode" ] && sample_per_episode=1024 +[ -z "$num_episode" ] && num_episode=100 [ -z "$tokenizer_load" ] && export tokenizer_load=path-to-hf-tokenizer-for-vllm-backend output_dir=${output_dir}/${exp_name} @@ -105,6 +106,7 @@ num_gpu=${num_gpu} \ data_path=${DATASET_PATH} \ eval_data_path=${EVAL_DATASET_PATH} \ sample_per_episode=${sample_per_episode} \ +num_episode=${num_episode} \ python entry/train_rlhf.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} diff --git a/examples/megatron/scripts/train_sft_llama.sh b/examples/megatron/scripts/train_sft_llama.sh index 97333d6d..5d9b2bac 100644 --- a/examples/megatron/scripts/train_sft_llama.sh +++ b/examples/megatron/scripts/train_sft_llama.sh @@ -24,6 +24,7 @@ DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE} \ export PYTHONPATH=${PYTHONPATH}:${MEGATRON}:${CHATLEARN}/examples/megatron:${CHATLEARN} [ -z "$model_size" ] && export model_size=llama2-7B +[ -z "$TOKENIZER_TYPE" ] && export TOKENIZER_TYPE=Llama2Tokenizer if [ $model_size = llama2-7B ]; then NUM_LAYERS=32 @@ -48,6 +49,24 @@ elif [ $model_size = llama2-70B ]; then pp=4 mb=2 gbs=64 +elif [ $model_size = llama3-8B ]; then + NUM_LAYERS=32 + HIDDEN_SIZE=4096 + NUM_ATTN_HEADS=32 + INTERMEDIATE_SIZE=14336 + tp=4 + pp=1 + TOKENIZER_TYPE=Llama3Tokenizer +elif [ $model_size = llama3-70B ]; then + NUM_LAYERS=80 + HIDDEN_SIZE=8192 + NUM_ATTN_HEADS=64 + INTERMEDIATE_SIZE=28672 + tp=8 + pp=4 + mb=2 + gbs=64 + TOKENIZER_TYPE=Llama3Tokenizer fi [ -z "$mb" ] && mb=8 @@ -67,13 +86,13 @@ dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp)) gbs=$(($gbs * $dp)) -[ -z "$CHECKPOINT_PATH" ] && CHECKPOINT_PATH=${CHATLEARN}/output/sft/hh_sft_$(date +%F)_gpt_${model_size}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_mb${mb}_seqlen${seq_len} +[ -z "$CHECKPOINT_PATH" ] && 
CHECKPOINT_PATH=${CHATLEARN}/output/sft/hh_sft_$(date +%F)_gpt_${model_size}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_mb${mb}_seqlen${seq_len}
 
 mkdir -p $CHECKPOINT_PATH
 
 MODEL_ARGS="
 --max-position-embeddings 4096 \
---tokenizer-type Llama2Tokenizer \
+--tokenizer-type ${TOKENIZER_TYPE} \
 --tokenizer-model ${TOKENIZER_MODEL} \
 --exit-on-missing-checkpoint \
 --use-checkpoint-args \
@@ -84,9 +103,22 @@ MODEL_ARGS="
 --normalization RMSNorm \
 --no-position-embedding \
 --no-masked-softmax-fusion \
---transformer-impl local \
 --attention-softmax-in-fp32 "
 
+use_legacy_models=${USE_LEGACY_MODELS:-"True"}
+
+if [[ ${use_legacy_models} = "False" ]]; then
+    MCORE_ARGS="--transformer-impl transformer_engine "
+else
+    if [[ ${model_size} = "llama3-8B" || ${model_size} = "llama3-70B" ]]; then
+        echo "Llama3 models are not supported with USE_LEGACY_MODELS=True"
+        exit 1
+    fi
+    MCORE_ARGS="
+    --use-legacy-models \
+    --transformer-impl local "
+fi
+
 log_file=$CHECKPOINT_PATH/stderr_$NODE_RANK.log
 
 export CUDA_DEVICE_MAX_CONNECTIONS=1
@@ -138,4 +170,4 @@ torchrun $DISTRIBUTED_ARGS \
 --sequence-parallel \
 --finetune \
 --distributed-timeout-minutes 60 \
-    $MODEL_ARGS 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]}
+    $MODEL_ARGS $MCORE_ARGS 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]}
diff --git a/examples/megatron/tests/run_policy_generation.sh b/examples/megatron/tests/run_policy_generation.sh
index 11214269..ae55720d 100644
--- a/examples/megatron/tests/run_policy_generation.sh
+++ b/examples/megatron/tests/run_policy_generation.sh
@@ -8,6 +8,7 @@ set -x
 [ -z "$LOAD" ] && export LOAD=path-to-ckpt
 [ -z "$DATASET_PATH" ] && export DATASET_PATH=path-to-dataset-json
 [ -z "$model_size" ] && export model_size=llama2-13B
+[ -z "$tokenizer_load" ] && export tokenizer_load=path-to-hf-tokenizer-for-vllm-backend
 
 cd $CHATLEARN/examples/megatron
 
From b4706221633881195f63cef132313dfd55ebed08 Mon Sep 17 00:00:00 2001
From: "baodong.lh"
Date: Tue, 10 Sep 2024 17:42:16 +0800
Subject: [PATCH 02/26] Fix Megatron version

---
 docs/en/installation.md                             | 2 +-
 docs/zh/installation.md                             | 2 +-
 examples/megatron/scripts/convert_hf_to_megatron.sh | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/en/installation.md b/docs/en/installation.md
index d6358e09..ecb73f0b 100644
--- a/docs/en/installation.md
+++ b/docs/en/installation.md
@@ -21,5 +21,5 @@ git clone https://github.com/alibaba/ChatLearn.git
 ```
 # Clone Megatron-LM
 git clone https://github.com/NVIDIA/Megatron-LM.git
-git checkout 5161b1689
+git checkout core_r0.8.0
 ```
diff --git a/docs/zh/installation.md b/docs/zh/installation.md
index a2bcb9bc..494ab44f 100644
--- a/docs/zh/installation.md
+++ b/docs/zh/installation.md
@@ -21,5 +21,5 @@ git clone https://github.com/alibaba/ChatLearn.git
 ```
 # 下载 Megatron-LM
 git clone https://github.com/NVIDIA/Megatron-LM.git
-git checkout 5161b1689
+git checkout core_r0.8.0
 ```
diff --git a/examples/megatron/scripts/convert_hf_to_megatron.sh b/examples/megatron/scripts/convert_hf_to_megatron.sh
index 4110b309..4127c6f2 100644
--- a/examples/megatron/scripts/convert_hf_to_megatron.sh
+++ b/examples/megatron/scripts/convert_hf_to_megatron.sh
@@ -48,7 +48,7 @@ if [[ ${model} == 'gpt_llama' ]]; then
         --tokenizer-model ${tokenizer_model}
 elif [[ ${model} == 'mixtral' ]]; then
     # Mixtral can only be converted to mcore models.
-    # Require Megatron-LM commit id >= c7a1f82.
+    # Require Megatron-Core 0.8.0 or later.
cd ${megatron} python tools/checkpoint/convert.py \ --model-type GPT \ From 41db596a489a8096142051feb0ed1c6822fad09b Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Tue, 10 Sep 2024 17:45:03 +0800 Subject: [PATCH 03/26] fix pylint --- chatlearn/models/vllm/vllm_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatlearn/models/vllm/vllm_model.py b/chatlearn/models/vllm/vllm_model.py index 61bb4e2a..62cfbebc 100644 --- a/chatlearn/models/vllm/vllm_model.py +++ b/chatlearn/models/vllm/vllm_model.py @@ -50,7 +50,7 @@ def load_weights(self): load_checkpoint(self, None, None) torch.distributed.barrier() - def load_state_dict(self, state_dict, strict=True, assign=False): + def load_state_dict(self, state_dict, strict=True, assign=False): # pylint: disable=unused-argument qwen_version = None if isinstance(self.model, LlamaForCausalLM): if self.model_args["use_legacy_models"]: From 636141ca5332408fe4bb90e67ae9b0c760537ae5 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Wed, 11 Sep 2024 01:40:05 +0000 Subject: [PATCH 04/26] fix missing import in reward_inference --- examples/megatron/models/reward_inference.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/megatron/models/reward_inference.py b/examples/megatron/models/reward_inference.py index ffb0ec4c..7023e93d 100644 --- a/examples/megatron/models/reward_inference.py +++ b/examples/megatron/models/reward_inference.py @@ -26,6 +26,7 @@ from megatron.training import get_args from megatron.training import get_model from megatron.training import get_tokenizer +from megatron.training import print_rank_0 from megatron.training.global_vars import get_tensorboard_writer from megatron.training.utils import get_ltor_masks_and_position_ids @@ -33,6 +34,7 @@ from chatlearn import MegatronModule from chatlearn.utils import to_device from chatlearn.utils.megatron_utils import load_checkpoint +from examples.megatron.data.reward_dataset import preprocess from .reward_model import RewardModel as LegacyRewardModel from .mcore_reward_model import MCoreRewardModel from .utils import tensorboard_scalar_dict, get_eos_id From 004211c98904b8b7be7e5404a8e9f28f2e91eb0d Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Wed, 11 Sep 2024 02:05:20 +0000 Subject: [PATCH 05/26] fix vllm --- chatlearn/models/vllm/vllm_model.py | 5 ++++- chatlearn/models/vllm_module.py | 5 ++++- examples/megatron/models/vllm_policy_inference.py | 1 + 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/chatlearn/models/vllm/vllm_model.py b/chatlearn/models/vllm/vllm_model.py index 62cfbebc..33597168 100644 --- a/chatlearn/models/vllm/vllm_model.py +++ b/chatlearn/models/vllm/vllm_model.py @@ -53,7 +53,10 @@ def load_weights(self): def load_state_dict(self, state_dict, strict=True, assign=False): # pylint: disable=unused-argument qwen_version = None if isinstance(self.model, LlamaForCausalLM): - if self.model_args["use_legacy_models"]: + use_legacy_models = self.model_args.get("use_legacy_models") + if use_legacy_models is None: + raise RuntimeError("Please specify use_legacy_models (True or False) for VLLMModel, but not None.") + if use_legacy_models: convert_state_dict_internal = convert_llama_state_dict_from_megatron_to_vllm else: convert_state_dict_internal = convert_llama_state_dict_from_mcore_to_vllm diff --git a/chatlearn/models/vllm_module.py b/chatlearn/models/vllm_module.py index ddde8e77..a0824b0b 100644 --- a/chatlearn/models/vllm_module.py +++ b/chatlearn/models/vllm_module.py @@ -536,7 +536,10 @@ def map_src_to_dst(self, src_names, 
num_src_pipeline_stage, src_pipe_stage):
             self._to_fix_qkv_ordering_func = fix_query_key_value_ordering
             sync_map = sync_map_cls(src_names, layer_offset, QwenVersion.v_2.value)
         elif isinstance(self.model.model, LlamaForCausalLM):
-            sync_map_cls = Megatron2LlamaSyncMap if self.model_args["use_legacy_models"] else MCore2LlamaSyncMap
+            use_legacy_models = self.model_args.get("use_legacy_models")
+            if use_legacy_models is None:
+                raise RuntimeError("Please specify use_legacy_models (True or False) for VLLMModule, but not None.")
+            sync_map_cls = Megatron2LlamaSyncMap if use_legacy_models else MCore2LlamaSyncMap
             from chatlearn.utils.vllm_utils import fix_qwen_query_key_value_ordering # pylint: disable=import-outside-toplevel
             self._to_fix_qkv_ordering_func = fix_qwen_query_key_value_ordering
             sync_map = sync_map_cls(src_names, layer_offset)
diff --git a/examples/megatron/models/vllm_policy_inference.py b/examples/megatron/models/vllm_policy_inference.py
index 70a1e0b9..642a7c67 100644
--- a/examples/megatron/models/vllm_policy_inference.py
+++ b/examples/megatron/models/vllm_policy_inference.py
@@ -15,6 +15,7 @@
 """vllm policy inference"""
 
 import copy
+import random
 
 import torch
 import torch.nn.functional as F
From 1d6470799afbb459cfcc1922e052a4dd48c8d2ed Mon Sep 17 00:00:00 2001
From: "baodong.lh"
Date: Wed, 11 Sep 2024 03:42:41 +0000
Subject: [PATCH 06/26] refine hint

---
 examples/megatron/scripts/train_sft_llama.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/megatron/scripts/train_sft_llama.sh b/examples/megatron/scripts/train_sft_llama.sh
index 5d9b2bac..1144e4ce 100644
--- a/examples/megatron/scripts/train_sft_llama.sh
+++ b/examples/megatron/scripts/train_sft_llama.sh
@@ -111,7 +111,7 @@ if [[ ${use_legacy_models} = "False" ]]; then
     MCORE_ARGS="--transformer-impl transformer_engine "
 else
     if [[ ${model_size} = "llama3-8B" || ${model_size} = "llama3-70B" ]]; then
-        echo "Llama3 models are not supported with USE_LEGACY_MODELS=True"
+        echo "Expect USE_LEGACY_MODELS to be False for Llama3 models, but got True."
exit 1 fi MCORE_ARGS=" From b7367a22cea08ed39a8c8cdc1955941ce93295c8 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Wed, 11 Sep 2024 16:50:38 +0800 Subject: [PATCH 07/26] fix Makefile --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 3df84528..e88b0cc5 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -PYTHON ?= python +PYTHON ?= python3 ADDITIONAL_DEPS ?= current_dir := $(dir $(abspath $(firstword $(MAKEFILE_LIST)))) @@ -15,7 +15,7 @@ test: $(LIB) .PHONY: lint lint: git config --global --add safe.directory $(current_dir) - pip install pylint==2.16.1 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com + @$(PYTHON) -m pip install pylint==2.16.1 -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com @$(PYTHON) -m pylint \ --rcfile=.pylintrc --output-format=parseable --jobs=8 \ $(shell git ls-tree --full-tree --name-only -r HEAD chatlearn | grep \.py$) \ From 7ef59ca72853576725bc3719d337dde47d6a5597 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Thu, 19 Sep 2024 17:16:20 +0800 Subject: [PATCH 08/26] Support Mixtral MoE --- chatlearn/models/base_module.py | 13 +- .../megatron/memory_manager/base_trainer.py | 69 ++- .../megatron/memory_manager/trainer_v1v2.py | 204 +++---- .../megatron/memory_manager/trainer_v3.py | 127 ++--- chatlearn/models/megatron_module.py | 25 +- chatlearn/runtime/decorator.py | 4 +- chatlearn/runtime/environment.py | 6 + chatlearn/runtime/executor.py | 27 +- chatlearn/runtime/parameter_sync.py | 74 ++- chatlearn/tools/convert.py | 79 +++ chatlearn/tools/loader_mcore_mixtral.py | 517 ++++++++++++++++++ chatlearn/tools/megatron_to_hf.py | 272 ++++++++- chatlearn/utils/arguments.py | 17 +- chatlearn/utils/megatron_import_helper.py | 2 + chatlearn/utils/megatron_utils.py | 9 +- examples/megatron/configs/mixtral/base.yaml | 83 +++ .../configs/mixtral/base_inference.yaml | 16 + .../megatron/configs/mixtral/base_train.yaml | 10 + examples/megatron/configs/mixtral/dpo.yaml | 46 ++ .../configs/mixtral/old_policy_inference.yaml | 15 + .../configs/mixtral/old_value_inference.yaml | 5 + .../megatron/configs/mixtral/online_dpo.yaml | 62 +++ .../configs/mixtral/policy_shared.yaml | 14 + .../megatron/configs/mixtral/ppo_policy.yaml | 39 ++ .../megatron/configs/mixtral/ppo_value.yaml | 30 + .../megatron/configs/mixtral/reference.yaml | 6 + .../configs/mixtral/reward_inference.yaml | 6 + .../configs/mixtral/reward_shared.yaml | 18 + examples/megatron/configs/mixtral/rlhf.yaml | 81 +++ examples/megatron/data/sft_dataset.py | 4 + examples/megatron/models/policy_model.py | 13 +- examples/megatron/models/reward_inference.py | 5 - examples/megatron/scripts/base_env.sh | 17 + .../scripts/convert_hf_to_megatron.sh | 6 +- .../scripts/convert_megatron_to_hf.sh | 76 ++- .../megatron/scripts/train_dpo_mixtral.sh | 53 ++ .../scripts/train_online_dpo_mixtral.sh | 78 +++ .../megatron/scripts/train_reward_mixtral.sh | 163 ++++++ .../megatron/scripts/train_rlhf_mixtral.sh | 85 +++ .../megatron/scripts/train_sft_mixtral.sh | 159 ++++++ .../megatron/tests/run_policy_generation.sh | 23 +- tests/run_tests.sh | 7 +- tests/test_data_dp_ep.py | 219 ++++++++ 43 files changed, 2524 insertions(+), 260 deletions(-) create mode 100644 chatlearn/tools/convert.py create mode 100644 chatlearn/tools/loader_mcore_mixtral.py create mode 100644 examples/megatron/configs/mixtral/base.yaml create mode 100644 examples/megatron/configs/mixtral/base_inference.yaml create mode 100644 
examples/megatron/configs/mixtral/base_train.yaml create mode 100644 examples/megatron/configs/mixtral/dpo.yaml create mode 100644 examples/megatron/configs/mixtral/old_policy_inference.yaml create mode 100644 examples/megatron/configs/mixtral/old_value_inference.yaml create mode 100644 examples/megatron/configs/mixtral/online_dpo.yaml create mode 100644 examples/megatron/configs/mixtral/policy_shared.yaml create mode 100644 examples/megatron/configs/mixtral/ppo_policy.yaml create mode 100644 examples/megatron/configs/mixtral/ppo_value.yaml create mode 100644 examples/megatron/configs/mixtral/reference.yaml create mode 100644 examples/megatron/configs/mixtral/reward_inference.yaml create mode 100644 examples/megatron/configs/mixtral/reward_shared.yaml create mode 100644 examples/megatron/configs/mixtral/rlhf.yaml create mode 100644 examples/megatron/scripts/train_dpo_mixtral.sh create mode 100644 examples/megatron/scripts/train_online_dpo_mixtral.sh create mode 100644 examples/megatron/scripts/train_reward_mixtral.sh create mode 100644 examples/megatron/scripts/train_rlhf_mixtral.sh create mode 100644 examples/megatron/scripts/train_sft_mixtral.sh create mode 100644 tests/test_data_dp_ep.py diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index c56c6cb9..628c8699 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -70,7 +70,12 @@ def __init__(self, name, args=None, replica_id=0): self._is_colocate = False if self.total_gpu > 0: - self._num_gpu_per_replica = args.tensor_model_parallel_size * args.pipeline_model_parallel_size * args.zero_size + self._num_gpu_per_replica = ( + args.tensor_model_parallel_size + * args.pipeline_model_parallel_size + * args.expert_model_parallel_size + * args.zero_size + ) assert self._num_gpu_per_replica <= self.total_gpu assert self.total_gpu % self._num_gpu_per_replica == 0 if not self.trainable: @@ -857,6 +862,12 @@ def tensor_model_parallel_size(self): """ return self.module_args.tensor_model_parallel_size + def expert_model_parallel_size(self): + """ + :meta private: + """ + return self.module_args.expert_model_parallel_size + def num_layers(self): """ :meta private: diff --git a/chatlearn/models/megatron/memory_manager/base_trainer.py b/chatlearn/models/megatron/memory_manager/base_trainer.py index 05502f28..7ac921d7 100644 --- a/chatlearn/models/megatron/memory_manager/base_trainer.py +++ b/chatlearn/models/megatron/memory_manager/base_trainer.py @@ -19,6 +19,7 @@ import torch +from chatlearn.models.megatron.memory_manager.base import BaseMemoryManager from chatlearn.utils.flat_tensors import BucketizedFlatTensors from chatlearn.utils.logger import log_rank_0 from chatlearn.utils.megatron_import_memory_helper import MegatronVersion, get_megatron_version @@ -27,6 +28,7 @@ MixedPrecisionOptimizer, DistributedOptimizer, Float16OptimizerWithFloat16Params, + ChainedOptimizer, ) @@ -87,32 +89,53 @@ def __init__( self._use_distributed_optimizer = use_distributed_optimizer self._bucket_size_mb = bucket_size_mb + def sanity_check(single_optimizer): + assert isinstance( + single_optimizer, (MixedPrecisionOptimizer,) + ), f'Only support optimizer type MixedPrecisionOptimizer and its subclasses, current type is {str(type(optimizer))}.' 
+ + if self._use_distributed_optimizer: + assert isinstance(single_optimizer, DistributedOptimizer) + else: + log_rank_0('Current optimizer is Float16OptimizerWithFloat16Params') + assert isinstance(single_optimizer, Float16OptimizerWithFloat16Params) + assert isinstance( model, (DistributedDataParallel,) ), f'Only support model type DistributedDataParallel, current type is {str(type(model))}.' - assert isinstance( - optimizer, (MixedPrecisionOptimizer,) - ), f'Only support optimizer type MixedPrecisionOptimizer and its subclasses, current type is {str(type(optimizer))}.' - - # sanity check - if self._use_distributed_optimizer: - assert isinstance(optimizer, DistributedOptimizer) + if isinstance(optimizer, ChainedOptimizer): + for single_optimizer in optimizer.chained_optimizers: + sanity_check(single_optimizer) + self._is_chained_optimizer = True else: - log_rank_0('Current optimizer is Float16OptimizerWithFloat16Params') - assert isinstance(optimizer, Float16OptimizerWithFloat16Params) + sanity_check(optimizer) + self._is_chained_optimizer = False self._main_weights_offloaded = False self._group_flat_main_weights: Optional[List[BucketizedFlatTensors]] = None self._megatron_version = get_megatron_version() - def _optimizer_load_state_bucket_into_device(self, device): + def get_optimizer_list(self): + if self._is_chained_optimizer: + optimizer_list = self._optimizer.chained_optimizers + else: + optimizer_list = [self._optimizer] + return optimizer_list + + def _optimizer_load_state_bucket_into_device(self, device, optimizer=None): """put the state bucket onto a device""" - state_dict = self._optimizer.optimizer.state_dict() - for tensors in state_dict['state'].values(): - keys = list(tensors.keys()) - for key in keys: - tensors[key] = tensors[key].to(device=device, non_blocking=True) + if optimizer is not None: + optimizer_list = [optimizer] + else: + optimizer_list = self.get_optimizer_list() + + for single_optimizer in optimizer_list: + state_dict = single_optimizer.optimizer.state_dict() + for tensors in state_dict['state'].values(): + keys = list(tensors.keys()) + for key in keys: + tensors[key] = tensors[key].to(device=device, non_blocking=True) # make sure the loading is finished before returning torch.cuda.synchronize() @@ -147,12 +170,16 @@ def offload_main_weights(self): return if self._group_flat_main_weights is None: - if self._use_distributed_optimizer: - self._group_flat_main_weights = self._flat_param_groups( - [self._optimizer.shard_fp32_from_float16_groups] - ) - else: - self._group_flat_main_weights = self._flat_param_groups([self._optimizer.fp32_from_float16_groups]) + self._group_flat_main_weights = [] + optimizer_list = self.get_optimizer_list() + + for optimizer in optimizer_list: + if self._use_distributed_optimizer: + self._group_flat_main_weights.extend(self._flat_param_groups( + [optimizer.shard_fp32_from_float16_groups] + )) + else: + self._group_flat_main_weights.extend(self._flat_param_groups([optimizer.fp32_from_float16_groups])) for flat_main_weights in self._group_flat_main_weights: flat_main_weights.copy_to_primary_store() diff --git a/chatlearn/models/megatron/memory_manager/trainer_v1v2.py b/chatlearn/models/megatron/memory_manager/trainer_v1v2.py index b390552f..e2ad9c75 100644 --- a/chatlearn/models/megatron/memory_manager/trainer_v1v2.py +++ b/chatlearn/models/megatron/memory_manager/trainer_v1v2.py @@ -90,27 +90,31 @@ def offload_weights(self): log_rank_0('Call offload_weights when already offloaded. 
Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() if self._use_distributed_optimizer: - optimizer.shard_float16_groups.clear() - optimizer.shard_fp32_groups.clear() + for optimizer in optimizer_list: + optimizer.shard_float16_groups.clear() + optimizer.shard_fp32_groups.clear() if self._group_flat_weights is None: - if self._use_distributed_optimizer: - self._group_flat_weights = self._flat_param_groups( - [ - optimizer.model_float16_groups, - optimizer.model_fp32_groups, - ], - ) - else: - self._group_flat_weights = self._flat_param_groups( - [ - optimizer.float16_groups, - optimizer.fp32_from_fp32_groups, - ], - ) + self._group_flat_weights = [] + + for optimizer in optimizer_list: + if self._use_distributed_optimizer: + self._group_flat_weights.extend(self._flat_param_groups( + [ + optimizer.model_float16_groups, + optimizer.model_fp32_groups, + ], + )) + else: + self._group_flat_weights.extend(self._flat_param_groups( + [ + optimizer.float16_groups, + optimizer.fp32_from_fp32_groups, + ], + )) for flat_weights in self._group_flat_weights: flat_weights.copy_to_primary_store() @@ -127,7 +131,7 @@ def onload_weights(self): log_rank_0('Call onload_weights when already onloaded. Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() for flat_weights in self._group_flat_weights: flat_weights.copy_to_gpu_buffer() @@ -151,55 +155,56 @@ def onload_weights(self): self._weights_offloaded = False return - shard_float16_groups = optimizer.shard_float16_groups - shard_fp32_groups = optimizer.shard_fp32_groups - param_gbuf_map = optimizer.model_param_gbuf_map - opt_group_ranges = optimizer.opt_group_ranges - model_gbuf_ranges = optimizer.model_gbuf_ranges - - # Rebuild shard_float16_groups and shard_fp32_groups, - # see Megatron DistributedOptimizer#build_model_and_main_param_groups. - for _, group_range in enumerate(opt_group_ranges): - shard_float16_params_this_group = [] - shard_fp32_params_this_group = [] - shard_float16_groups.append(shard_float16_params_this_group) - shard_fp32_groups.append(shard_fp32_params_this_group) - - for model_param in group_range["params"]: - assert model_param.requires_grad - if self._megatron_version == MegatronVersion.V2: - model_index, dtype, bucket_index = param_gbuf_map[model_param] - gbuf_range = model_gbuf_ranges[model_index][dtype][bucket_index] - param_range = gbuf_range["param_map"][model_param]["param"] - elif self._megatron_version == MegatronVersion.V1: - model_index, dtype = param_gbuf_map[model_param] - gbuf_range = model_gbuf_ranges[model_index][dtype] - param_range = gbuf_range["param_map"][model_param]["param"] - - # fp16, bf16 params. - if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']: - shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end] - tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) - if hasattr(model_param, 'shared'): - shard_model_param.shared = model_param.shared - - shard_float16_params_this_group.append(shard_model_param) - - # fp32 params. 
- elif model_param.type() == 'torch.cuda.FloatTensor': - shard_model_param = model_param.view(-1)[param_range.start : param_range.end] - shard_fp32_params_this_group.append(shard_model_param) - tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) - if hasattr(model_param, 'shared'): - shard_model_param.shared = model_param.shared - else: - raise TypeError( - 'Wrapped parameters must be one of ' - 'torch.cuda.FloatTensor, ' - 'torch.cuda.HalfTensor, or ' - 'torch.cuda.BFloat16Tensor. ' - 'Received {}'.format(model_param.type()) - ) + for optimizer in optimizer_list: + shard_float16_groups = optimizer.shard_float16_groups + shard_fp32_groups = optimizer.shard_fp32_groups + param_gbuf_map = optimizer.model_param_gbuf_map + opt_group_ranges = optimizer.opt_group_ranges + model_gbuf_ranges = optimizer.model_gbuf_ranges + + # Rebuild shard_float16_groups and shard_fp32_groups, + # see Megatron DistributedOptimizer#build_model_and_main_param_groups. + for _, group_range in enumerate(opt_group_ranges): + shard_float16_params_this_group = [] + shard_fp32_params_this_group = [] + shard_float16_groups.append(shard_float16_params_this_group) + shard_fp32_groups.append(shard_fp32_params_this_group) + + for model_param in group_range["params"]: + assert model_param.requires_grad + if self._megatron_version == MegatronVersion.V2: + model_index, dtype, bucket_index = param_gbuf_map[model_param] + gbuf_range = model_gbuf_ranges[model_index][dtype][bucket_index] + param_range = gbuf_range["param_map"][model_param]["param"] + elif self._megatron_version == MegatronVersion.V1: + model_index, dtype = param_gbuf_map[model_param] + gbuf_range = model_gbuf_ranges[model_index][dtype] + param_range = gbuf_range["param_map"][model_param]["param"] + + # fp16, bf16 params. + if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']: + shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end] + tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + + shard_float16_params_this_group.append(shard_model_param) + + # fp32 params. + elif model_param.type() == 'torch.cuda.FloatTensor': + shard_model_param = model_param.view(-1)[param_range.start : param_range.end] + shard_fp32_params_this_group.append(shard_model_param) + tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + else: + raise TypeError( + 'Wrapped parameters must be one of ' + 'torch.cuda.FloatTensor, ' + 'torch.cuda.HalfTensor, or ' + 'torch.cuda.BFloat16Tensor. ' + 'Received {}'.format(model_param.type()) + ) self._weights_offloaded = False @@ -211,16 +216,17 @@ def free_grad_buffers(self): log_rank_0('Call free_grad_buffers when already freed. Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() grad_dtype_to_params = self._grad_dtype_to_params - # This is necessary, but don't know why. - optimizer.zero_grad(True) + for optimizer in optimizer_list: + # This is necessary, but don't know why. + optimizer.zero_grad(True) - if self._use_distributed_optimizer: - # Release param_buffers because they share storage with grad_buffers. - # Note: param_buffers are only available in DistributedOptimizer. 
- optimizer.param_buffers.clear() + if self._use_distributed_optimizer: + # Release param_buffers because they share storage with grad_buffers. + # Note: param_buffers are only available in DistributedOptimizer. + optimizer.param_buffers.clear() # Release grad_buffers, including buckets in GradBuffer for newer Megatron version. # Release `main_grad` of parameters. @@ -252,7 +258,7 @@ def build_grad_buffers(self): log_rank_0('Call build_grad_buffers when already built. Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() params_dtype = self._params_dtype grad_dtype_to_params = self._grad_dtype_to_params @@ -286,31 +292,33 @@ def build_grad_buffers(self): return # Re-allocate param_buffers, see Megatron DistributedOptimizer#__init__. - optimizer.param_buffers = [] - for _, _ in enumerate(optimizer.models): - current_param_buffers = {} - for dtype, grad_buffer in self.get_grad_buffers().items(): - current_param_buffers[dtype] = [] - if self._megatron_version == MegatronVersion.V2: - for bucket in grad_buffer.buckets: + # pylint: disable=too-many-nested-blocks + for optimizer in optimizer_list: + optimizer.param_buffers = [] + for _, _ in enumerate(optimizer.models): + current_param_buffers = {} + for dtype, grad_buffer in self.get_grad_buffers().items(): + current_param_buffers[dtype] = [] + if self._megatron_version == MegatronVersion.V2: + for bucket in grad_buffer.buckets: + try: + storage = bucket.data.storage()._untyped() + # pylint: disable-next=bare-except + except: + storage = bucket.data.storage().untyped() + + param_buffer = torch.tensor([], dtype=params_dtype, device=bucket.data.device).set_(storage) + param_buffer = param_buffer[bucket.offset : bucket.offset + bucket.data.numel()] + current_param_buffers[dtype].append(param_buffer) + elif self._megatron_version == MegatronVersion.V1: try: - storage = bucket.data.storage()._untyped() + storage = grad_buffer.data.storage()._untyped() # pylint: disable-next=bare-except except: - storage = bucket.data.storage().untyped() - - param_buffer = torch.tensor([], dtype=params_dtype, device=bucket.data.device).set_(storage) - param_buffer = param_buffer[bucket.offset : bucket.offset + bucket.data.numel()] - current_param_buffers[dtype].append(param_buffer) - elif self._megatron_version == MegatronVersion.V1: - try: - storage = grad_buffer.data.storage()._untyped() - # pylint: disable-next=bare-except - except: - storage = grad_buffer.data.storage().untyped() - param_buffer = torch.tensor([], dtype=params_dtype, device=grad_buffer.data.device).set_(storage) - param_buffer = param_buffer[: grad_buffer.numel_padded] - current_param_buffers[dtype] = param_buffer - optimizer.param_buffers.append(current_param_buffers) + storage = grad_buffer.data.storage().untyped() + param_buffer = torch.tensor([], dtype=params_dtype, device=grad_buffer.data.device).set_(storage) + param_buffer = param_buffer[: grad_buffer.numel_padded] + current_param_buffers[dtype] = param_buffer + optimizer.param_buffers.append(current_param_buffers) self._grad_buffers_freed = False diff --git a/chatlearn/models/megatron/memory_manager/trainer_v3.py b/chatlearn/models/megatron/memory_manager/trainer_v3.py index 7d640b29..e700aa45 100644 --- a/chatlearn/models/megatron/memory_manager/trainer_v3.py +++ b/chatlearn/models/megatron/memory_manager/trainer_v3.py @@ -81,16 +81,15 @@ def offload_weights(self): log_rank_0('Call offload_weights when already offloaded. 
Ignore it.') return - optimizer = self._optimizer - - # TODO(jiqi): support expert parallel params + optimizer_list = self.get_optimizer_list() # In the V3 version, when distributed optimizer is used, parameter data are managed together with # gradients in buffers. if self._use_distributed_optimizer: - optimizer.shard_float16_groups.clear() - optimizer.shard_fp32_groups.clear() - optimizer.pbuf_view_items.clear() + for optimizer in optimizer_list: + optimizer.shard_float16_groups.clear() + optimizer.shard_fp32_groups.clear() + optimizer.pbuf_view_items.clear() if self._group_flat_weights is None: self._group_flat_weights = [] @@ -112,12 +111,14 @@ def offload_weights(self): bucket.param_data = None else: if self._group_flat_weights is None: - self._group_flat_weights = self._flat_param_groups( - [ - optimizer.float16_groups, - optimizer.fp32_from_fp32_groups, - ], - ) + self._group_flat_weights = [] + for optimizer in optimizer_list: + self._group_flat_weights.extend(self._flat_param_groups( + [ + optimizer.float16_groups, + optimizer.fp32_from_fp32_groups, + ], + )) # Offload param_data of buffers for flat_weights in self._group_flat_weights: @@ -135,7 +136,7 @@ def onload_weights(self): log_rank_0('Call onload_weights when already onloaded. Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() # Onload param_data of buffers for flat_weights in self._group_flat_weights: @@ -175,52 +176,53 @@ def onload_weights(self): self._weights_offloaded = False return - optimizer.pbuf_view_items = optimizer._get_model_param_buffer_dp_views() - - shard_float16_groups = optimizer.shard_float16_groups - shard_fp32_groups = optimizer.shard_fp32_groups - param_gbuf_map = optimizer.model_param_gbuf_map - opt_group_ranges = optimizer.opt_group_ranges - model_gbuf_ranges = optimizer.gbuf_ranges - - # Rebuild shard_float16_groups and shard_fp32_groups, - # see Megatron DistributedOptimizer#build_model_and_main_param_groups. - for _, group_range in enumerate(opt_group_ranges): - shard_float16_params_this_group = [] - shard_fp32_params_this_group = [] - shard_float16_groups.append(shard_float16_params_this_group) - shard_fp32_groups.append(shard_fp32_params_this_group) - - for model_param in group_range["params"]: - assert model_param.requires_grad - gbuf_index, dtype, bucket_index = param_gbuf_map[model_param] - gbuf_range = model_gbuf_ranges[gbuf_index][dtype][bucket_index] - param_range = gbuf_range["param_map"][model_param]["param"] - - # fp16, bf16 params. - if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']: - shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end] - tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) - if hasattr(model_param, 'shared'): - shard_model_param.shared = model_param.shared - - shard_float16_params_this_group.append(shard_model_param) - - # fp32 params. - elif model_param.type() == 'torch.cuda.FloatTensor': - shard_model_param = model_param.view(-1)[param_range.start : param_range.end] - shard_fp32_params_this_group.append(shard_model_param) - tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) - if hasattr(model_param, 'shared'): - shard_model_param.shared = model_param.shared - else: - raise TypeError( - 'Wrapped parameters must be one of ' - 'torch.cuda.FloatTensor, ' - 'torch.cuda.HalfTensor, or ' - 'torch.cuda.BFloat16Tensor. 
' - 'Received {}'.format(model_param.type()) - ) + for optimizer in optimizer_list: + optimizer.pbuf_view_items = optimizer._get_model_param_buffer_dp_views() + + shard_float16_groups = optimizer.shard_float16_groups + shard_fp32_groups = optimizer.shard_fp32_groups + param_gbuf_map = optimizer.model_param_gbuf_map + opt_group_ranges = optimizer.opt_group_ranges + model_gbuf_ranges = optimizer.gbuf_ranges + + # Rebuild shard_float16_groups and shard_fp32_groups, + # see Megatron DistributedOptimizer#build_model_and_main_param_groups. + for _, group_range in enumerate(opt_group_ranges): + shard_float16_params_this_group = [] + shard_fp32_params_this_group = [] + shard_float16_groups.append(shard_float16_params_this_group) + shard_fp32_groups.append(shard_fp32_params_this_group) + + for model_param in group_range["params"]: + assert model_param.requires_grad + gbuf_index, dtype, bucket_index = param_gbuf_map[model_param] + gbuf_range = model_gbuf_ranges[gbuf_index][dtype][bucket_index] + param_range = gbuf_range["param_map"][model_param]["param"] + + # fp16, bf16 params. + if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']: + shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end] + tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + + shard_float16_params_this_group.append(shard_model_param) + + # fp32 params. + elif model_param.type() == 'torch.cuda.FloatTensor': + shard_model_param = model_param.view(-1)[param_range.start : param_range.end] + shard_fp32_params_this_group.append(shard_model_param) + tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + else: + raise TypeError( + 'Wrapped parameters must be one of ' + 'torch.cuda.FloatTensor, ' + 'torch.cuda.HalfTensor, or ' + 'torch.cuda.BFloat16Tensor. ' + 'Received {}'.format(model_param.type()) + ) self._weights_offloaded = False @@ -232,10 +234,11 @@ def free_grad_buffers(self): log_rank_0('Call free_grad_buffers when already freed. Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() - # This is necessary, but don't know why. - optimizer.zero_grad(True) + for optimizer in optimizer_list: + # This is necessary, but don't know why. 
+ optimizer.zero_grad(True) # Remove references from params for p, buffer in self._model.param_to_buffer.items(): diff --git a/chatlearn/models/megatron_module.py b/chatlearn/models/megatron_module.py index 13a65704..6adc165d 100644 --- a/chatlearn/models/megatron_module.py +++ b/chatlearn/models/megatron_module.py @@ -14,6 +14,7 @@ # ============================================================================== """Megatron module""" import inspect +import torch.distributed as dist try: from chatlearn.utils.megatron_import_helper import get_args @@ -138,6 +139,14 @@ def tensor_model_parallel_size(self): """ return self.megatron_args.tensor_model_parallel_size + def expert_model_parallel_size(self): + """ + get expert_model_parallel_size + + :meta private: + """ + return self.megatron_args.expert_model_parallel_size + @property def data_parallel_size(self): """ @@ -164,6 +173,12 @@ def tensor_parallel_rank(self): """ return mpu.get_tensor_model_parallel_rank() + def expert_parallel_rank(self): + """ + :meta private: + """ + return mpu.get_expert_model_parallel_rank() + def num_layers(self): """ :meta private: @@ -205,8 +220,14 @@ def get_local_param_ranks(self): """ :meta private: """ - data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS) - return data_parallel_global_ranks, mpu.get_data_parallel_rank() + if self.expert_model_parallel_size() == 1: + data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS) + return data_parallel_global_ranks, mpu.get_data_parallel_rank() + else: + # Get data parallel modulo expert parallel ranks + data_modulo_expert_parallel_group = mpu.get_data_modulo_expert_parallel_group() + data_modulo_expert_parallel_ranks = dist.get_process_group_ranks(data_modulo_expert_parallel_group) + return data_modulo_expert_parallel_ranks, mpu.get_data_modulo_expert_parallel_rank() def save_checkpoint(self, iteration): """ diff --git a/chatlearn/runtime/decorator.py b/chatlearn/runtime/decorator.py index 9b3ccb22..605914fe 100644 --- a/chatlearn/runtime/decorator.py +++ b/chatlearn/runtime/decorator.py @@ -156,7 +156,7 @@ def get_kwarg(key): self._iteration += 1 ret = utils.to_device('cpu', ret) results.append(ret) - # for model with DP, we need to return results from all ranks + # for model with DP/EP, we need to return results from all ranks # for model with TP/PP, only return the results from last rank if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1: final_results = concat_along_batch(results) @@ -167,7 +167,7 @@ def get_kwarg(key): ret = utils.to_device('cpu', ret) self._iteration += 1 final_results = None - # for model with DP, we need to return results from all ranks + # for model with DP/EP, we need to return results from all ranks # for model with TP/PP, only return the results from last rank if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1: final_results = ret diff --git a/chatlearn/runtime/environment.py b/chatlearn/runtime/environment.py index f1ac93f0..8f142233 100644 --- a/chatlearn/runtime/environment.py +++ b/chatlearn/runtime/environment.py @@ -108,6 +108,12 @@ def num_iteration(self): if self.models[0].module_args.zero_size > 1: assert self.batch_per_episode % self.models[0].module_args.zero_size == 0 return self.batch_per_episode // self.models[0].module_args.zero_size + elif self.models[0].module_args.expert_model_parallel_size > 1: + assert self.batch_per_episode % self.models[0].module_args.expert_model_parallel_size == 0, ( + f"batch per episode 
({self.batch_per_episode}) must be divisible by expert model parallel " + f"size ({self.models[0].module_args.expert_model_parallel_size})." + ) + return self.batch_per_episode // self.models[0].module_args.expert_model_parallel_size else: return self.batch_per_episode diff --git a/chatlearn/runtime/executor.py b/chatlearn/runtime/executor.py index 43d6d8ab..a80cb4b0 100644 --- a/chatlearn/runtime/executor.py +++ b/chatlearn/runtime/executor.py @@ -162,7 +162,7 @@ def execute_offload(self, model_node): refs = model.offload() future.wait(refs) - def generate_step_one_model_internal(self, model, in_queue, step_num, replica=None, func_name="forward_step", to_empty_cache=None, + def generate_step_one_model_internal(self, model, in_queue, step_num, replica, func_name="forward_step", to_empty_cache=None, is_eval=False, to_onload=None, to_offload=None): """ Args: @@ -173,8 +173,6 @@ def generate_step_one_model_internal(self, model, in_queue, step_num, replica=No func_name: str to_empty_cache: None or boolean """ - if replica is None: - replica = self._get_model(model) def get_next_data(): if isinstance(in_queue, list): @@ -225,14 +223,27 @@ def generate_step_one_model(self, model, in_queue, out_queue, step_num, func_nam to_empty_cache: None or boolean """ # output is a list of tuple, each tuple is (remote_refs, mb) - output = self.generate_step_one_model_internal(model, in_queue, step_num, None, func_name, to_empty_cache, + replica = self._get_model(model) + output = self.generate_step_one_model_internal(model, in_queue, step_num, replica, func_name, to_empty_cache, is_eval, to_onload, to_offload) - # If tp > 1 or pp > 1 for current model, its `output` will be a list whose - # length is the number of Actors. In this case, all members in the list - # are the same, and we choose output[-1] to put into out_queue. if model.module_args.zero_size == 1: - result = [output[-1]] + # If (tp > 1 or pp > 1) and ep = 1 for current model, its `output` will be a list whose + # length is the number of Actors. In this case, all members in the list + # are the same, and we choose output[-1] to put into out_queue. + # If (tp > 1 or pp > 1) and ep > 1, we choose last output for each dp rank to put into + # out_queue. + if model.module_args.expert_model_parallel_size == 1: + result = [output[-1]] + else: + num_dp_rank = len(replica.dp_rank_to_actors) + num_output = len(output) + assert num_output % num_dp_rank == 0, ( + f"The number of outputs ({num_output}) must be divisible by " + f"the number of dp_ranks ({num_dp_rank}) in a replica." 
+ ) + interval = num_output // num_dp_rank + result = [output[i] for i in range(interval - 1, num_output, interval)] else: result = output if isinstance(out_queue, list): diff --git a/chatlearn/runtime/parameter_sync.py b/chatlearn/runtime/parameter_sync.py index e274178b..75c6225b 100644 --- a/chatlearn/runtime/parameter_sync.py +++ b/chatlearn/runtime/parameter_sync.py @@ -56,11 +56,14 @@ def __init__(self, src_model, dst_model, group_name, frequency, error_signal): self._num_dst_pipeline_stage = None self._num_src_tensor_parallel = None self._num_dst_tensor_parallel = None + self._num_src_expert_parallel = None + self._num_dst_expert_parallel = None self._dst_prefix = None self._src_prefix = None self._send_recv_param_names = {} self._actor2pipe = {} self._actor2tp = {} + self._actor2ep = {} self._actor2dp = {} self._validate_params = {} self._comm_type = get_args().runtime_args.param_sync_comm_type @@ -131,6 +134,18 @@ def num_dst_tensor_parallel(self): self._num_dst_tensor_parallel = future.get(self.dst_model.replicas[0].all_actors[0].tensor_model_parallel_size.remote()) return self._num_dst_tensor_parallel + @property + def num_src_expert_parallel(self): + if self._num_src_expert_parallel is None: + self._num_src_expert_parallel = future.get(self.src_model.replicas[0].all_actors[0].expert_model_parallel_size.remote()) + return self._num_src_expert_parallel + + @property + def num_dst_expert_parallel(self): + if self._num_dst_expert_parallel is None: + self._num_dst_expert_parallel = future.get(self.dst_model.replicas[0].all_actors[0].expert_model_parallel_size.remote()) + return self._num_dst_expert_parallel + def setup_collective_group(self): refs = [] # we put src_model first, so we don't need to change the rank of training model @@ -165,57 +180,69 @@ def add_recv_actor(self, src_rank, dst_rank): dst_actor = self.dst_model.get_actor(dst_rank) self.actor2rank[dst_actor] = dst_rank - src_gpu = self.get_or_cache(src_actor, "get_visible_gpus") - dst_gpu = self.get_or_cache(dst_actor, "get_visible_gpus") + src_gpu = future.get(src_actor.get_visible_gpus.remote()) + dst_gpu = future.get(dst_actor.get_visible_gpus.remote()) src_tp_rank = self.get_actor_tp_rank(src_actor) dst_tp_rank = self.get_actor_tp_rank(dst_actor) src_pp_rank = self.get_actor_pipe_rank(src_actor) dst_pp_rank = self.get_actor_pipe_rank(dst_actor) - logger.debug(f"build rank mapping from {src_rank} to {dst_rank}, from gpu {src_gpu} to {dst_gpu}, " + \ - f"from pipe_stage {src_pp_rank} to {dst_pp_rank}, " + \ - f"from tp rank {src_tp_rank} to {dst_tp_rank}") + src_ep_rank = self.get_actor_ep_rank(src_actor) + dst_ep_rank = self.get_actor_ep_rank(dst_actor) + logger.debug(f"build rank mapping from {src_rank} to {dst_rank}, from gpu {src_gpu} to {dst_gpu}, " + + f"from pipe_stage {src_pp_rank} to {dst_pp_rank}, " + + f"from tp rank {src_tp_rank} to {dst_tp_rank}, " + + f"from ep rank {src_ep_rank} to {dst_ep_rank}.") assert src_tp_rank == dst_tp_rank, f"src_tp_rank {src_tp_rank} should be same as dst_tp_rank {dst_tp_rank}" + assert src_ep_rank == dst_ep_rank, f"src_ep_rank {src_ep_rank} should be same as dst_ep_rank {dst_ep_rank}" self.send_recv_actor_mappings[src_actor].append(dst_actor) self.recv_send_actor_mappings[dst_actor].append(src_actor) def build_rank_mapping(self): # setup rank mapping for src parameter and dst parameter # get rank for one src_model, without model replicas - dst_ranks = self.dst_model.all_ranks + dst_dp_ranks = self.dst_model.all_ranks local_src_ranks = 
future.get(self.src_model.replicas[0].get_local_param_ranks()) - if local_src_ranks[0] is None or dst_ranks is None: + if local_src_ranks[0] is None or dst_dp_ranks is None: if self._debug: logger.warning( - f"DEBUG MODE! src_ranks {local_src_ranks} or dst_ranks: {dst_ranks} is None, make sure they have values in real application.") + f"DEBUG MODE! src_dp_ranks {local_src_ranks} or dst_dp_ranks: {dst_dp_ranks} is None, " + "make sure they have values in real application.") return else: - raise Exception(f"src_ranks {local_src_ranks} or dst_ranks {dst_ranks} should not be None") + raise Exception(f"src_dp_ranks {local_src_ranks} or dst_dp_ranks {dst_dp_ranks} should not be None") dp_rank_to_ranks = defaultdict(list) for local_ranks, dp_rank in local_src_ranks: dp_rank_to_ranks[dp_rank].append(local_ranks[dp_rank]) - src_ranks = [i[1] for i in sorted(dp_rank_to_ranks.items())] + src_dp_ranks = [i[1] for i in sorted(dp_rank_to_ranks.items())] - assert len(src_ranks[0]) % len(dst_ranks[0]) == 0, \ - f"src training model ranks should be times of dst ranks, but got {len(src_ranks[0])} and {len(dst_ranks[0])}" + assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, \ + f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}" if self.src_model.colocate_with(self.dst_model) and self.num_src_tensor_parallel % 2 == 1: - replica_rank_iter = cycle(reversed(src_ranks)) + replica_rank_iter = cycle(reversed(src_dp_ranks)) else: - replica_rank_iter = cycle(iter(src_ranks)) - logger.debug(f"src_ranks: {src_ranks}") - logger.debug(f"dst_ranks: {dst_ranks}") + replica_rank_iter = cycle(iter(src_dp_ranks)) + logger.debug(f"src_dp_ranks: {src_dp_ranks}") + logger.debug(f"dst_dp_ranks: {dst_dp_ranks}") assert self.num_src_tensor_parallel == self.num_dst_tensor_parallel, \ "currently we require the tensor_model_parallel_size to be the same between " + \ f"src model {self.src_model.name}(TP={self.num_src_tensor_parallel}) and " + \ f"dst model {self.dst_model.name}(TP={self.num_dst_tensor_parallel})" + assert self.num_src_expert_parallel == self.num_dst_expert_parallel, \ + "currently we require the expert_model_parallel_size to be the same between " + \ + f"src model {self.src_model.name}(EP={self.num_src_expert_parallel}) and " + \ + f"dst model {self.dst_model.name}(EP={self.num_dst_expert_parallel})" assert self.num_src_pipeline_stage % self.num_dst_pipeline_stage == 0 - def split_ranks_by_tp_size(ranks, tp_size): - return [ranks[i:i + tp_size] for i in range(0, len(ranks), tp_size)] + def split_ranks_by_tp_and_ep_size(ranks, + tp_size : int = 1, + ep_size : int = 1): + tp_size_mul_ep_size = tp_size * ep_size + return [ranks[i:i + tp_size_mul_ep_size] for i in range(0, len(ranks), tp_size_mul_ep_size)] - for dst_replica_ranks in dst_ranks: + for dst_replica_ranks in dst_dp_ranks: src_replica_ranks = next(replica_rank_iter) - src_replica_ranks_group = split_ranks_by_tp_size(src_replica_ranks, self.num_src_tensor_parallel) - dst_replica_ranks_group = split_ranks_by_tp_size(dst_replica_ranks, self.num_src_tensor_parallel) + src_replica_ranks_group = split_ranks_by_tp_and_ep_size(src_replica_ranks, self.num_src_tensor_parallel, self.num_src_expert_parallel) + dst_replica_ranks_group = split_ranks_by_tp_and_ep_size(dst_replica_ranks, self.num_dst_tensor_parallel, self.num_dst_expert_parallel) pipe_map_interval = self.num_src_pipeline_stage // self.num_dst_pipeline_stage for i, src_tp_group in enumerate(src_replica_ranks_group): j = i // pipe_map_interval 
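
For reference, a minimal sketch (plain Python, not part of the diff) of how the split_ranks_by_tp_and_ep_size helper above groups a replica's ranks; it assumes ranks are laid out with TP contiguous inside each EP group, which is what build_rank_mapping relies on:

    # Sketch only: mirrors split_ranks_by_tp_and_ep_size from the hunk above.
    def split_ranks_by_tp_and_ep_size(ranks, tp_size=1, ep_size=1):
        group = tp_size * ep_size
        return [ranks[i:i + group] for i in range(0, len(ranks), group)]

    # Example: a replica with 8 ranks, TP=2, EP=2 -> two groups of 4 ranks each.
    # Groups are then paired between src and dst replicas (per the pipeline-stage
    # mapping below) when parameters are synchronized.
    print(split_ranks_by_tp_and_ep_size(list(range(8)), tp_size=2, ep_size=2))
    # [[0, 1, 2, 3], [4, 5, 6, 7]]
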
@@ -330,6 +357,11 @@ def inner_func(): return future.get(actor.tensor_parallel_rank.remote()) return utils.get_or_cache(self._actor2tp, actor, inner_func) + def get_actor_ep_rank(self, actor): + def inner_func(): + return future.get(actor.expert_parallel_rank.remote()) + return utils.get_or_cache(self._actor2ep, actor, inner_func) + def get_actor_dp_rank(self, actor): def inner_func(): return future.get(actor.get_data_parallel_rank.remote()) diff --git a/chatlearn/tools/convert.py b/chatlearn/tools/convert.py new file mode 100644 index 00000000..0974a0c6 --- /dev/null +++ b/chatlearn/tools/convert.py @@ -0,0 +1,79 @@ +# Code below is modified from NVIDIA Megatron-LM +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +"""Convertion script""" + +import argparse +import importlib +import sys +import torch.multiprocessing as mp + + +# Code below is copied from Megatron-LM core v0.8.0 +def load_plugin(plugin_type, name): + module_name = f"{plugin_type}_{name}" + try: + plugin = importlib.import_module(module_name) + except ModuleNotFoundError as e1: + print(e1) + module_name = name + try: + plugin = importlib.import_module(module_name) + except ModuleNotFoundError as e2: + print(e2) + sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.") + + if not hasattr(plugin, 'add_arguments'): + sys.exit(f"{module_name} module is not a plugin. Exiting.") + + print(f"Loaded {module_name} as the {plugin_type}.") + return plugin + +def main(): + parser = argparse.ArgumentParser(description="Megatron Checkpoint Converter Arguments", + allow_abbrev=False, conflict_handler='resolve') + + parser.add_argument('--model-type', type=str, required=True, + choices=['GPT', 'BERT'], + help='Type of the model') + parser.add_argument('--loader', type=str, default='megatron', + help='Module name to load checkpoint, should be on python path') + parser.add_argument('--loader-prefix', type=str, default='loader', + help='Prefix import path for loader') + parser.add_argument('--saver', type=str, default='megatron', + help='Module name to save checkpoint, should be on python path') + parser.add_argument('--saver-prefix', type=str, default='saver', + help='Prefix import path for saver') + parser.add_argument('--load-dir', type=str, required=True, + help='Directory to load model checkpoint from') + parser.add_argument('--save-dir', type=str, required=True, + help='Directory to save model checkpoint to') + parser.add_argument('--max-queue-size', type=int, default=50, + help='Maximum number of tensors in the queue') + parser.add_argument('--no-checking', action='store_false', + help='Do not perform checking on the name and ordering of weights', + dest='checking') + + known_args, _ = parser.parse_known_args() + loader = load_plugin(known_args.loader_prefix, known_args.loader) + saver = load_plugin(known_args.saver_prefix, known_args.saver) + + loader.add_arguments(parser) + saver.add_arguments(parser) + + args = parser.parse_args() + + queue = mp.Queue(maxsize=args.max_queue_size) + + print("Starting saver...") + saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args)) + saver_proc.start() + + print("Starting loader...") + loader.load_checkpoint(queue, args) + + print("Waiting for saver to complete...") + saver_proc.join() + + +if __name__ == '__main__': + main() diff --git a/chatlearn/tools/loader_mcore_mixtral.py b/chatlearn/tools/loader_mcore_mixtral.py new file mode 100644 index 00000000..1544271d --- /dev/null +++ b/chatlearn/tools/loader_mcore_mixtral.py @@ -0,0 +1,517 @@ +# Copyright 
(c) 2024, NVIDIA CORPORATION. All rights reserved. +"""load mcore mixtral model""" + +import json +import os +import sys +import types +import torch + +from utils import get_mcore_transformer_block_key, print_memory_usage + + +def add_arguments(parser): + group = parser.add_argument_group(title='Megatron loader') + + group.add_argument('--true-vocab-size', type=int, default=None, + help='original size of vocab, if specified will trim padding from embedding table.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file. If specified will use this to get vocab size and ' + 'trim padding from the embedding table.') + group.add_argument('--megatron-path', type=str, default=None, + help='Base directory of Megatron repository') + group.add_argument('--position-embedding-type', + type=str, + default='learned_absolute', + choices=['learned_absolute', 'rope'], + help='Position embedding type.') + group.add_argument('--loader-transformer-impl', default='transformer_engine', + choices=['local', 'transformer_engine'], + help='Which Transformer implementation to use.') + + +def _load_checkpoint(queue, args): + + # Search in directory above this + sys.path.append(os.path.abspath( + os.path.join(os.path.dirname(__file__), + os.path.pardir))) + if args.megatron_path is not None: + sys.path.insert(0, args.megatron_path) + + # pylint: disable=import-outside-toplevel + try: + from megatron.training.arguments import parse_args, validate_args + from megatron.training.global_vars import set_global_variables + from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint # pylint: disable=redefined-outer-name + from megatron.legacy.model import module + from megatron.core import mpu + from megatron.core.enums import ModelType + from megatron.legacy import fused_kernels + except ModuleNotFoundError: + print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.") + queue.put("exit") + sys.exit(1) + + # We want all arguments to come from us + sys.argv = ['script.py', + '--no-masked-softmax-fusion', + '--no-bias-gelu-fusion', + '--no-bias-dropout-fusion', + '--no-async-tensor-model-parallel-allreduce', + '--use-cpu-initialization', + '--micro-batch-size', '1', + '--no-load-optim', + '--no-load-rng', + '--no-save-optim', + '--no-save-rng', + '--no-initialization', + '--mock-data', # To pass the "blend data checks" in arguments.py + '--load', args.load_dir, + '--position-embedding-type', args.position_embedding_type, + ] + + margs = parse_args() + margs, checkpoint_args = load_args_from_checkpoint(margs, exit_on_missing_checkpoint=True) + + # Arguments do sanity checks on the world size, but we don't care, + # so trick it into thinking we are plenty of processes + margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size * margs.expert_model_parallel_size + + # Explicitly copy data types from checkpoint. + margs.fp16 = checkpoint_args.fp16 + margs.bf16 = checkpoint_args.bf16 + + margs.use_legacy_models = False + margs.transformer_impl = args.loader_transformer_impl + if checkpoint_args.expert_model_parallel_size > 1: + margs.expert_model_parallel_size = checkpoint_args.expert_model_parallel_size + margs.num_experts = checkpoint_args.num_experts + + # Validate margs. 
+ margs = validate_args(margs) + + def check_for_arg(arg_name, default=None): + if getattr(margs, arg_name, None) is None: + if default is not None: + setattr(margs, arg_name, default) + else: + print(f"Checkpoint does not specify the argument {arg_name}. Exiting.") + print(f"Arguments: {margs}") + queue.put("exit") + sys.exit(1) + + check_for_arg('tensor_model_parallel_size') + check_for_arg('pipeline_model_parallel_size') + check_for_arg('expert_model_parallel_size') + check_for_arg('num_layers') + check_for_arg('hidden_size') + check_for_arg('seq_length') + check_for_arg('num_attention_heads') + check_for_arg('max_position_embeddings') + check_for_arg('position_embedding_type') + check_for_arg('tokenizer_type') + check_for_arg('iteration') + check_for_arg('bert_binary_head') + check_for_arg('disable_bias_linear', False) + check_for_arg('params_dtype') + check_for_arg('swiglu', False) + if checkpoint_args.expert_model_parallel_size > 1: + check_for_arg('num_experts') + + # Determine how to make our models + if args.model_type == 'GPT': + from pretrain_gpt import model_provider + margs.model_type = ModelType.encoder_or_decoder + elif args.model_type == 'BERT': + from pretrain_bert import model_provider + margs.model_type = ModelType.encoder_or_decoder + else: + raise Exception(f'unrecognized model type: {args.model_type}') + + # supress warning about torch.distributed not being initialized + module.MegatronModule.embedding_warning_printed = True + + consumed_train_samples = None + consumed_valid_samples = None + def get_models(tp_size, ep_size, dtype): + nonlocal consumed_train_samples + nonlocal consumed_valid_samples + model_array_len = margs.virtual_pipeline_model_parallel_size + if model_array_len is None: + model_array_len = 1 + models = [[[] for _ in range(ep_size)] for _ in range(model_array_len)] + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + for ep_rank in range(ep_size): + mpu.set_expert_model_parallel_rank(ep_rank) + for tp_rank in range(tp_size): + mpu.set_tensor_model_parallel_rank(tp_rank) + if margs.virtual_pipeline_model_parallel_size is not None: + model_ = [] + for i in range(margs.virtual_pipeline_model_parallel_size): + mpu.set_virtual_pipeline_model_parallel_rank(i) + # Set pre_process and post_process only after virtual rank is set. + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + this_model = model_provider( + pre_process=pre_process, + post_process=post_process + ).to(dtype) + model_.append(this_model) + else: + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + model_ = [model_provider(pre_process, post_process).to(dtype)] + margs.consumed_train_samples = 0 + margs.consumed_valid_samples = 0 + margs.exit_on_missing_checkpoint = True + load_checkpoint(model_, None, None, strict=False) + + if consumed_train_samples is not None: + assert(margs.consumed_train_samples == consumed_train_samples) + else: + consumed_train_samples = margs.consumed_train_samples + if consumed_valid_samples is not None: + assert(margs.consumed_valid_samples == consumed_valid_samples) + else: + consumed_valid_samples = margs.consumed_valid_samples + for vp_rank in range(model_array_len): + models[vp_rank][ep_rank].append(model_[vp_rank]) + + # Print memory usage. 
+ print_memory_usage("loader", tp_rank, tp_size) + + return models + + set_global_variables(margs, build_tokenizer=False) + mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size) + mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size) + mpu.set_expert_model_parallel_world_size(margs.expert_model_parallel_size) + mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size) + fused_kernels.load(margs) + + # Get true (non-padded) vocab size + if args.true_vocab_size is not None: + true_vocab_size = args.true_vocab_size + elif args.vocab_file is not None: + with open(args.vocab_file) as vocab_file_handler: # pylint: disable=unspecified-encoding + vocab = json.load(vocab_file_handler) + true_vocab_size = len(vocab) + if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size: + print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.") + queue.put("exit") + sys.exit(1) + else: + true_vocab_size = None + + # short aliases + tp_size = margs.tensor_model_parallel_size + pp_size = margs.pipeline_model_parallel_size + ep_size = margs.expert_model_parallel_size + vp_size = margs.virtual_pipeline_model_parallel_size + if vp_size is None: + vp_size = 1 + + # Layernorm has bias; RMSNorm does not. + if hasattr(checkpoint_args, 'normalization'): + norm_has_bias = checkpoint_args.normalization == "LayerNorm" + else: + # older models only supported LayerNorm + norm_has_bias = True + + # metadata + md = types.SimpleNamespace() + md.model_type = args.model_type + md.num_layers = margs.num_layers + md.hidden_size = margs.hidden_size + md.seq_length = margs.seq_length + md.num_attention_heads = margs.num_attention_heads + md.max_position_embeddings = margs.max_position_embeddings + md.tokenizer_type = margs.tokenizer_type + md.iteration = margs.iteration + md.params_dtype = margs.params_dtype + md.bert_binary_head = margs.bert_binary_head + md.output_layer = margs.untie_embeddings_and_output_weights + md.position_embedding_type = margs.position_embedding_type + md.linear_bias = margs.add_bias_linear + md.norm_has_bias = norm_has_bias + md.swiglu = margs.swiglu + md.previous_tensor_parallel_size = margs.tensor_model_parallel_size + md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size + md.previous_expert_parallel_size = margs.expert_model_parallel_size + md.true_vocab_size = true_vocab_size + md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by + md.checkpoint_args = checkpoint_args + md.use_legacy_models = margs.use_legacy_models + md.num_experts = margs.num_experts + + # Get transformer block (named either 'encoder' or 'decoder'). 
+ transformer_block_key = get_mcore_transformer_block_key(md.model_type) + def get_transformer_block(_model): + return getattr(_model, transformer_block_key) + + # Get first pipe stage + mpu.set_pipeline_model_parallel_rank(0) + # all_models: pp_rank, vp_rank, ep_rank, tp_rank + all_models = [get_models(tp_size, ep_size, md.params_dtype)] + models = all_models[0][0] + if ep_size == 1: + assert len(models) == 1 + + md.consumed_train_samples = consumed_train_samples + md.consumed_valid_samples = consumed_valid_samples + queue.put(md) + + def queue_put(name, msg): + print(f"sending {name}") + msg["name"] = name + queue.put(msg) + + # Send embeddings + message = { + "word embeddings": torch.cat( + [models[0][tp_rank].embedding.word_embeddings.weight.data for tp_rank in range(tp_size)], + dim = 0) + } + if md.position_embedding_type == 'learned_absolute': + message["position embeddings"] = models[0].embedding.position_embeddings.weight.data + else: + assert not hasattr(models[0][0].embedding, 'position_embeddings') + + queue_put("embeddings", message) + + def get_message_for_dense_model(message): + # Get non-parallel tensors from tp_rank 0 + layer = get_transformer_block(models[0][0]).layers[layer_num] + message["input norm weight"] = layer.self_attention.linear_qkv.layer_norm_weight.data + if norm_has_bias: + message["input norm bias"] = layer.self_attention.linear_qkv.layer_norm_bias.data + message["post norm weight"] = layer.mlp.linear_fc1.layer_norm_weight.data + if norm_has_bias: + message["post norm bias"] = layer.mlp.linear_fc1.layer_norm_bias.data + if md.linear_bias: + message["dense bias"] = layer.self_attention.linear_proj.bias.data + + # Grab all parallel tensors for this layer + qkv_weight = [] + qkv_bias = [] + dense_weight = [] + mlp_l0_weight = [] + mlp_l0_bias = [] + mlp_l1_weight = [] + for tp_rank, model in enumerate(models[0]): + layer = get_transformer_block(model).layers[layer_num] + qkv_weight.append(layer.self_attention.linear_qkv.weight.data) + dense_weight.append(layer.self_attention.linear_proj.weight.data) + mlp_l0_weight.append(layer.mlp.linear_fc1.weight.data) + mlp_l1_weight.append(layer.mlp.linear_fc2.weight.data) + if md.linear_bias: + qkv_bias.append(layer.self_attention.linear_qkv.bias.data) + mlp_l0_bias.append(layer.mlp.linear_fc1.bias.data) + if md.linear_bias: + # Get non-parallel tensors from tp_rank 0 + layer = get_transformer_block(models[0][0]).layers[layer_num] + mlp_l1_bias = layer.mlp.linear_fc2.bias.data + + # Handle gated linear units + if md.swiglu: + # concat all the first halves ('W's) and all the second halves ('V's) + for tp_rank in range(tp_size): + mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0) + message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0) + message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0) + else: + message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0) + + # simple concat of the rest + message["qkv weight"] = torch.cat(qkv_weight, dim=0) + message["dense weight"] = torch.cat(dense_weight, dim=1) + message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1) + if md.linear_bias: + message["qkv bias"] = torch.cat(qkv_bias, dim=0) + if md.swiglu: + for tp_rank in range(tp_size): + mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0) + message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias],dim=0) + message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias],dim=0) + else: + message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0) + 
message["mlp l1 bias"] = mlp_l1_bias + + def get_message_for_moe_model(message): + # Get non-parallel tensors from tp_rank 0 + layer = get_transformer_block(models[0][0]).layers[layer_num] + message["input norm weight"] = layer.self_attention.linear_qkv.layer_norm_weight.data + if norm_has_bias: + message["input norm bias"] = layer.self_attention.linear_qkv.layer_norm_bias.data + message["post norm weight"] = layer.pre_mlp_layernorm.weight.data + if norm_has_bias: + message["post norm bias"] = layer.pre_mlp_layernorm.bias.data + if md.linear_bias: + message["dense bias"] = layer.self_attention.linear_proj.bias.data + + # Grab all parallel tensors for this layer + qkv_weight = [] + qkv_bias = [] + dense_weight = [] + mlp_l0_weight_list = [[] for _ in range(margs.num_experts)] + mlp_l0_bias_list = [[] for _ in range(margs.num_experts)] + mlp_l1_weight_list = [[] for _ in range(margs.num_experts)] + mlp_l1_bias_list = [[] for _ in range(margs.num_experts)] + router_weight = [] + + # Dense modules + for tp_rank, model in enumerate(models[0]): + layer = get_transformer_block(model).layers[layer_num] + qkv_weight.append(layer.self_attention.linear_qkv.weight.data) + dense_weight.append(layer.self_attention.linear_proj.weight.data) + if md.linear_bias: + qkv_bias.append(layer.self_attention.linear_qkv.bias.data) + layer = get_transformer_block(models[0][0]).layers[layer_num] + router_weight = layer.mlp.router.weight.data + + # MoE modules + num_experts_per_rank = margs.num_experts // ep_size + for ep_rank, tp_models in enumerate(models): + for tp_rank, model in enumerate(tp_models): + layer = get_transformer_block(model).layers[layer_num] + for local_expert_idx in range(num_experts_per_rank): + expert_idx = int(ep_rank * num_experts_per_rank + local_expert_idx) + mlp_l0_weight_list[expert_idx].append(layer.mlp.experts.local_experts[local_expert_idx].linear_fc1.weight.data) + mlp_l1_weight_list[expert_idx].append(layer.mlp.experts.local_experts[local_expert_idx].linear_fc2.weight.data) + if md.linear_bias: + mlp_l0_bias_list[expert_idx].append(layer.mlp.experts.local_experts[local_expert_idx].linear_fc1.bias.data) + + if md.linear_bias: + # Get non-parallel tensors from tp_rank 0 + layer = get_transformer_block(tp_models[0]) + for local_expert_idx in range(num_experts_per_rank): + expert_idx = int(ep_rank * num_experts_per_rank + local_expert_idx) + mlp_l1_bias_list[expert_idx].append(layer.mlp.experts.local_experts[local_expert_idx].linear_fc2.bias.data) + + mlp_l0_weight_w_list = [[] for _ in range(margs.num_experts)] + mlp_l0_weight_v_list = [[] for _ in range(margs.num_experts)] + # Concat along the tensor parallel dimension + for expert_idx in range(margs.num_experts): + mlp_l0_weight = mlp_l0_weight_list[expert_idx] + if md.swiglu: + for tp_rank in range(tp_size): + mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0) + mlp_l0_weight_w_list[expert_idx] = torch.cat([w[0] for w in mlp_l0_weight], dim=0) + mlp_l0_weight_v_list[expert_idx] = torch.cat([w[1] for w in mlp_l0_weight], dim=0) + else: + mlp_l0_weight_list[expert_idx] = torch.cat(mlp_l0_weight, dim=0) + mlp_l1_weight_list[expert_idx] = torch.cat(mlp_l1_weight_list[expert_idx], dim=1) + + # Stack along the expert parallel dimension + if md.swiglu: + message["mlp l0 weight W"] = torch.stack(mlp_l0_weight_w_list) + message["mlp l0 weight V"] = torch.stack(mlp_l0_weight_v_list) + else: + message["mlp l0 weight"] = torch.stack(mlp_l0_weight_list) + message["mlp l1 weight"] = torch.stack(mlp_l1_weight_list) + + # Concat 
along TP and stack along EP to biases + if md.linear_bias: + mlp_l0_bias_w_list = [[] for _ in range(margs.num_experts)] + mlp_l0_bias_v_list = [[] for _ in range(margs.num_experts)] + # Concat along the tensor parallel dimension + for expert_idx in range(margs.num_experts): + mlp_l0_bias = mlp_l0_bias_list[expert_idx] + if md.swiglu: + for tp_rank in range(tp_size): + mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0) + mlp_l0_bias_w_list[expert_idx] = torch.cat([w[0] for w in mlp_l0_bias], dim=0) + mlp_l0_bias_v_list[expert_idx] = torch.cat([w[1] for w in mlp_l0_bias], dim=0) + else: + mlp_l0_bias_list[expert_idx] = torch.cat(mlp_l0_bias, dim=0) + assert len(mlp_l1_bias_list[expert_idx]) == 1 + mlp_l1_bias_list[expert_idx] = mlp_l1_bias_list[expert_idx][0] + + # Stack along the expert parallel dimension + if md.swiglu: + message["mlp l0 bias W"] = torch.stack(mlp_l0_bias_w_list) + message["mlp l0 bias V"] = torch.stack(mlp_l0_bias_v_list) + else: + message["mlp l0 bias"] = torch.stack(mlp_l0_bias_list) + message["mlp l1 bias"] = torch.stack(mlp_l1_bias_list) + + # Simple concat of the rest + message["qkv weight"] = torch.cat(qkv_weight, dim=0) + message["dense weight"] = torch.cat(dense_weight, dim=1) + if md.linear_bias: + message["qkv bias"] = torch.cat(qkv_bias, dim=0) + + # Do nothing to router + message["router weight"] = router_weight + + total_layer_num = 0 + for vp_rank in range(vp_size): + mpu.set_virtual_pipeline_model_parallel_rank(vp_rank) + for pp_rank in range(pp_size): + if pp_rank > 0: + mpu.set_pipeline_model_parallel_rank(pp_rank) + if vp_rank == 0: + all_models.append(get_models(tp_size, ep_size, md.params_dtype)) + models = all_models[pp_rank][vp_rank] + for layer_num in range(len(get_transformer_block(models[0][0]).layers)): + message = {} + + if margs.num_experts: + get_message_for_moe_model(message) + else: + get_message_for_dense_model(message) + + queue_put(f"transformer layer {total_layer_num}", message) + + total_layer_num = total_layer_num + 1 + + # Send final norm from tp_rank 0 + message = { + "weight": get_transformer_block(models[0][0]).final_layernorm.weight.data, + } + if norm_has_bias: + message["bias"] = get_transformer_block(models[0][0]).final_layernorm.bias.data + queue_put("final norm", message) + + if md.output_layer: + message = { + "weight": torch.cat( + [models[0][tp_rank].output_layer.weight.data for tp_rank in range(tp_size)], + dim = 0) + } + queue_put("output layer", message) + + + # Send BERT lm head and binary head if it exists + if md.model_type == 'BERT': + message = { + "weight": models[0][0].pooler.dense.weight.data, + "bias": models[0][0].pooler.dense.bias.data + } + queue_put("pooler", message) + + message = { + "dense weight": models[0][0].lm_head.dense.weight.data, + "dense bias": models[0][0].lm_head.dense.bias.data, + "norm weight": models[0][0].lm_head.layer_norm.weight.data, + } + if norm_has_bias: + message["norm bias"] = models[0][0].lm_head.layer_norm.bias.data + queue_put("lm head", message) + + if md.bert_binary_head: + message = { + "weight": models[0][0].binary_head.weight.data, + "bias": models[0][0].binary_head.bias.data + } + queue_put("binary head", message) + queue.put("done") + +def load_checkpoint(queue, args): + try: + _load_checkpoint(queue, args) + except: + queue.put("exit") + raise diff --git a/chatlearn/tools/megatron_to_hf.py b/chatlearn/tools/megatron_to_hf.py index 7a990744..59576ccd 100644 --- a/chatlearn/tools/megatron_to_hf.py +++ b/chatlearn/tools/megatron_to_hf.py @@ -76,6 +76,20 @@ 
def add_checkpointing_args(parser): "Path to Megatron-LM" ), ) + parser.add_argument( + "--use_legacy_models", + action="store_true", + help=( + "Whether using legacy models. Default: False." + ) + ) + parser.add_argument( + "--validate_checkpoint", + action="store_true", + help=( + "Whether validating converted checkpoint. Default: False." + ) + ) return parser @@ -97,6 +111,16 @@ def add_checkpointing_args(parser): "mlp.dense_4h_to_h.weight" ] + +mcore_to_transformers = { + "self_attention.linear_proj":".self_attn.o_proj.", + "linear_fc1_1":".w1.", + "linear_fc1_2":".w3.", + "linear_fc2":".w2.", + "mlp.router":".block_sparse_moe.gate.", + "self_attention.rotary_emb":".self_attn.rotary_emb.inv_freq" # unneeded for MoE +} + def recursive_print(name, val, spaces=0): """ Recursively print the structure of a checkpoint. This function is taken from `convert_megatron_gpt2_checkpoint.py` @@ -383,13 +407,259 @@ def convert_checkpoint_from_megatron_to_transformers(args): os.system(f"cp {fn} {args.save_path}") +def convert_checkpoint_from_mcore_to_transformers(args): + """ + Convert NVIDIA MCore checkpoint to HuggingFace Transformers checkpoint. It saves the converted checkpoint into shards + using HuggingFace Transformers checkpoint sharding functionality. + + Args: + args (argparse.Namespace): the arguments to the script + """ + # Load Megatron-Core checkpoint arguments from the state dict + possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000", "mp_rank_00_000_000"] + for root, dirnames, _ in os.walk(args.load_path): + for dirname in dirnames: + if dirname in possible_sub_dirs: + rank0_checkpoint_name = glob.glob(os.path.join(root, dirname) + "/*.pt") + args.load_path = root + rank0_checkpoint_path = rank0_checkpoint_name[0] + + print(f"Loading Megatron-Core checkpoint arguments from: {rank0_checkpoint_path}") + state_dict = torch.load(rank0_checkpoint_path, map_location="cpu") + megatron_args = state_dict.get("args", None) + if megatron_args is None: + raise ValueError( + "Megatron-Core checkpoint does not contain arguments. This utility only supports Megatron-Core checkpoints" + " containing all the megatron arguments. This is because it loads all config related to model" + " architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to" + " manually specify all the details. Please save Megatron-Core checkpoint along with all the megatron" + " arguments to use this utility." 
+ ) + + # params dtype + if args.target_params_dtype == "fp16": + dtype = torch.float16 + elif args.target_params_dtype == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + output_state_dict = {} + + checkpoint_version = state_dict.get("checkpoint_version", 0.0) + assert checkpoint_version >= 3.0 + tp_size = megatron_args.tensor_model_parallel_size + pp_size = megatron_args.pipeline_model_parallel_size + ep_size = megatron_args.expert_model_parallel_size + assert tp_size == 1 and pp_size == 1 and ep_size == 1 + + # Possible keys for MoE models: + # 'embedding.word_embeddings.weight', + # 'decoder.layers.0.self_attention.linear_proj.weight', + # 'decoder.layers.0.self_attention.linear_proj._extra_state', + # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight', + # 'decoder.layers.0.self_attention.linear_qkv.weight', + # 'decoder.layers.0.self_attention.linear_qkv._extra_state', + # 'decoder.layers.0.pre_mlp_layernorm.weight', + # 'decoder.layers.0.mlp.router.weight', + # 'decoder.layers.0.mlp.experts.local_experts.0.linear_fc1.weight', + # 'decoder.layers.0.mlp.experts.local_experts.0.linear_fc1._extra_state', + # 'decoder.layers.0.mlp.experts.local_experts.0.linear_fc2.weight', + # ..., + # 'decoder.final_layernorm.weight', + # 'output_layer.weight', + # 'output_layer._extra_state', + # 'decoder' + # The regex to extract layer names. + layer_re = re.compile(r"decoder.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z_]+)") + expert_re = re.compile(r"decoder.layers\.(\d+)\.([a-z0-9_.]+)\.(\d+)\.([a-z0-9_.]+)\.(weight|bias|_extra_state)") + + # Convert. + print("Converting") + + # Embeddings + print("Converting embeddings") + tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0) + state_dict = tp_state_dicts[0]['model'] + + # Convert and store the position embeddings. + position_embeddings = state_dict.get("embedding.position_embeddings.weight", None) + if position_embeddings: + output_state_dict["transformer.position_embeddings.weight"] = position_embeddings.to(dtype) + + # Convert and store the word embeddings. + word_embedding = state_dict.get("embedding.word_embeddings.weight", None) + output_state_dict["model.embed_tokens.weight"] = word_embedding.to(dtype) + + # Transformer Layers + print("Converting transformer layers") + + def process_dense(layer_match_res, output_state_dict): + # The name of the operation. + op_name = layer_match_res.group(2) + # Is it a weight or a bias? + weight_or_bias = layer_match_res.group(3) + + # Ignore them + if weight_or_bias in ('bias', '_extra_state'): + return + params = val.to(dtype) + + # For norm(s), simply store the norm. + if weight_or_bias.endswith("norm_weight"): # e.g. self_attention.linear_qkv.layer_norm_weight + ln_name = "input_layernorm" + output_state_dict[layer_name + "." + ln_name + ".weight"] = params + elif op_name.endswith("norm") and weight_or_bias == 'weight': # e.g. pre_mlp_layernorm.weight + ln_name = "post_attention_layernorm" + output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params + + # Transpose the QKV matrix. 
+ elif op_name == "self_attention.linear_qkv" \ + and weight_or_bias == "weight": + q_proj, k_proj, v_proj = split_attn_state(params, megatron_args) + if args.model_type == "llama": + output_state_dict[layer_name + ".self_attn.q_proj.weight"] = q_proj + output_state_dict[layer_name + ".self_attn.k_proj.weight"] = k_proj + output_state_dict[layer_name + ".self_attn.v_proj.weight"] = v_proj + + # Store other weights such as router + elif weight_or_bias == "weight": + out_name = mcore_to_transformers[op_name] + output_state_dict[layer_name + out_name + "weight"] = params + + # Copy the Rotary Embedding + else: + out_name = mcore_to_transformers[op_name] + output_state_dict[layer_name + out_name] = params + + def process_moe(expert_match_res, output_state_dict): + # The prefix of the expert + expert_prefix = expert_match_res.group(2) + # The index of the expert + expert_idx = expert_match_res.group(3) + # the name of the operation + op_name = expert_match_res.group(4) + # Is it a weight or a bias? + weight_or_bias = expert_match_res.group(5) + + # Ignore them + if weight_or_bias in ('bias', '_extra_state'): + return + params = val.to(dtype) + + expert_name = f".block_sparse_moe.experts.{expert_idx}" + if 'linear_fc1' in op_name: + linear_fc1_1, linear_fc1_2 = torch.split(params, params.size(0)//2, 0) + out_name = mcore_to_transformers[op_name+'_1'] + output_state_dict[layer_name + expert_name + out_name + "weight"] = linear_fc1_1 + out_name = mcore_to_transformers[op_name+'_2'] + output_state_dict[layer_name + expert_name + out_name + "weight"] = linear_fc1_2 + elif 'linear_fc2' in op_name: + out_name = mcore_to_transformers[op_name] + output_state_dict[layer_name + expert_name + out_name + "weight"] = params + else: + assert False, f"Unrecognized MoE module {expert_prefix}.{expert_idx}.{op_name}" + + # Extract the layers. + for key, val in state_dict.items(): + # Match the name. + layer_match_res = layer_re.match(key) + expert_match_res = expert_re.match(key) + # Continue if that's not a layer + if layer_match_res is None: + continue + if val is None: + continue + + # The index of the layer. + layer_idx = int(layer_match_res.group(1)) + # The name of the layer. + layer_name = f"model.layers.{layer_idx}" + + if expert_match_res: # Deal with sparse layers + process_moe(expert_match_res, output_state_dict) + else: # Deal with dense layers + process_dense(layer_match_res, output_state_dict) + + if megatron_args.num_layers != (layer_idx + 1): + raise ValueError(f"Expected {megatron_args.num_layers} layers but found {layer_idx + 1}") + + # The final norm. + print("Converting final norm") + params = state_dict.get("decoder.final_layernorm.weight", None) + output_state_dict["model.norm.weight"] = params.to(dtype) + + # For LM head, transformers' wants the matrix to weight embeddings. + print("Converting LM head") + params = state_dict.get('output_layer.weight', None) + output_state_dict["lm_head.weight"] = params.to(dtype) + + print("Saving checkpoint...") + + # Print the structure of converted state dict. + if args.print_checkpoint_structure: + recursive_print(None, output_state_dict) + + # Store the state_dict to file. 
+    max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size
+    shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size)
+
+    # Save the model
+    if not os.path.exists(args.save_path):
+        os.system(f'mkdir -p {args.save_path}')
+    for shard_file, shard in shards.items():
+        torch.save(shard, os.path.join(args.save_path, shard_file))
+
+    if index is None:
+        print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}")
+    else:
+        save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME)
+        # Save the index as well
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
+        print(
+            f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be "
+            f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
+            f"index located at {save_index_file}."
+        )
+
+    # Saving config and tokenizer files
+    for fn in glob.glob(args.vocab_dir + "/*"):
+        if (fn.endswith(".json") or fn.endswith("tokenizer.model") or fn.endswith(".py")) and not fn.endswith(".index.json"):
+            os.system(f"cp {fn} {args.save_path}")
+
+    # It should be done!
+    print("Conversion from Megatron-Core to Transformers is done!")
+
+# pylint: disable=import-outside-toplevel
+def validate_loading_checkpoints(args):
+    from transformers import AutoModelForCausalLM
+    _, model_loading_info = AutoModelForCausalLM.from_pretrained(args.save_path, output_loading_info=True)
+    if len(model_loading_info["missing_keys"]) > 0:
+        assert False, f"Invalid model checkpoint on missing_keys: {model_loading_info['missing_keys']}"
+    if len(model_loading_info["unexpected_keys"]) > 0:
+        assert False, f"Invalid model checkpoint on unexpected_keys: {model_loading_info['unexpected_keys']}"
+    if len(model_loading_info["mismatched_keys"]) > 0:
+        assert False, f"Invalid model checkpoint on mismatched_keys: {model_loading_info['mismatched_keys']}"
+    if len(model_loading_info["error_msgs"]) > 0:
+        assert False, f"Invalid model checkpoint on error_msgs: {model_loading_info['error_msgs']}"
+
 def main():
     parser = argparse.ArgumentParser()
     parser = add_checkpointing_args(parser)
     args = parser.parse_args()
     if args.megatron_path:
         sys.path.append(args.megatron_path)
-    convert_checkpoint_from_megatron_to_transformers(args)
+
+    if args.use_legacy_models:
+        convert_checkpoint_from_megatron_to_transformers(args)
+    else:
+        convert_checkpoint_from_mcore_to_transformers(args)
+
+    if args.validate_checkpoint:
+        print("Validating converted checkpoints...")
+        validate_loading_checkpoints(args)
+        print("Validation success!")
 
 
 if __name__ == "__main__":
diff --git a/chatlearn/utils/arguments.py b/chatlearn/utils/arguments.py
index 145ce9d7..b56dedd2 100644
--- a/chatlearn/utils/arguments.py
+++ b/chatlearn/utils/arguments.py
@@ -198,7 +198,7 @@ class ModelConfig(BaseConfig):
     #: [optional] cpu per process
     cpu_per_process: int = None
     #: [optional] number of module replica,
-    #: for gpu model, num_replica = num_gpu // (TP * PP * DP),
+    #: for gpu model, num_replica = num_gpu // (TP * PP * DP * EP),
     #: for cpu model, num_replica = num_cpu // cpu_per_process
     num_replica: int = 1
     #: [required] whether model is trainable
@@ -207,6 +207,8 @@ class ModelConfig(BaseConfig):
     tensor_model_parallel_size: int = None
     #: [optional] pipeline model parallel size
     pipeline_model_parallel_size: int = None
+    #: [optional] expert model parallel size
+    
expert_model_parallel_size: int = None #: [optional] zero size zero_size: int = None #: [optional] config file for model @@ -517,20 +519,21 @@ def _validate_params(self): if model_args.generation_batch_size is None or model_args.generation_batch_size <= 0: if self.runtime_args.generation_batch_size: model_args.generation_batch_size = self.runtime_args.generation_batch_size - for key in ["pipeline_model_parallel_size", "tensor_model_parallel_size", "zero_size"]: + for key in ["pipeline_model_parallel_size", "tensor_model_parallel_size", "expert_model_parallel_size", "zero_size"]: if model_args.args_dict.get(key) is not None: setattr(model_args, key, model_args.args_dict.get(key)) assert getattr(model_args, key) >= 1 elif getattr(model_args, key) is None: setattr(model_args, key, 1) - if model_args.tensor_model_parallel_size > 1 or model_args.pipeline_model_parallel_size > 1: + if model_args.tensor_model_parallel_size > 1 or model_args.pipeline_model_parallel_size > 1 or model_args.expert_model_parallel_size > 1: assert model_args.zero_size == 1 or model_args.zero_size is None assert model_args.num_gpu % ( - model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size) == 0, \ + model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size * model_args.expert_model_parallel_size) == 0, \ "num_gpu must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size " \ f"for {model_name} model, but got num_gpu = {model_args.num_gpu}" \ - f"tensor_model_parallel_size = {model_args.tensor_model_parallel_size}, and " \ - f"pipeline_model_parallel_size = {model_args.pipeline_model_parallel_size}." + f"tensor_model_parallel_size = {model_args.tensor_model_parallel_size}, " \ + f"pipeline_model_parallel_size = {model_args.pipeline_model_parallel_size}, and "\ + f"expert_model_parallel_size = {model_args.expert_model_parallel_size}." 
assert model_args.num_gpu > 0 or model_args.num_cpu > 0, \
                f"{model_name} num_gpu: {model_args.num_gpu}, num_cpu: {model_args.num_cpu}, at least one of them should be set"
 
@@ -540,7 +543,7 @@ def _validate_params(self):
                     model_args.num_replica = model_args.num_gpu // model_args.zero_size
                 else:
                     model_args.num_replica = model_args.num_gpu // (
-                        model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size)
+                        model_args.tensor_model_parallel_size * model_args.pipeline_model_parallel_size * model_args.expert_model_parallel_size)
             elif model_args.num_cpu >= 1:
                 model_args.num_replica = model_args.num_cpu // model_args.cpu_per_process
             assert model_args.num_replica * model_args.generation_batch_size <= self.runtime_args.sample_per_episode, \
diff --git a/chatlearn/utils/megatron_import_helper.py b/chatlearn/utils/megatron_import_helper.py
index 50a74bfe..92e782bf 100644
--- a/chatlearn/utils/megatron_import_helper.py
+++ b/chatlearn/utils/megatron_import_helper.py
@@ -160,12 +160,14 @@
     from megatron.optimizer import DistributedOptimizer
     from megatron.optimizer.optimizer import MegatronOptimizer
     from megatron.optimizer.optimizer import MixedPrecisionOptimizer
+    from megatron.optimizer.optimizer import ChainedOptimizer
     from megatron.optimizer.optimizer import Float16OptimizerWithFloat16Params
 except ImportError:
     from megatron.core.optimizer import get_megatron_optimizer
     from megatron.core.optimizer import DistributedOptimizer
     from megatron.core.optimizer.optimizer import MegatronOptimizer
     from megatron.core.optimizer.optimizer import MixedPrecisionOptimizer
+    from megatron.core.optimizer.optimizer import ChainedOptimizer
     from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params
 
 # DistributedDataParallel
diff --git a/chatlearn/utils/megatron_utils.py b/chatlearn/utils/megatron_utils.py
index 0d5c4296..967c51a2 100644
--- a/chatlearn/utils/megatron_utils.py
+++ b/chatlearn/utils/megatron_utils.py
@@ -168,15 +168,20 @@ def load_checkpoint(*_args, **kwargs):
     args = get_args()
     target_tp = args.tensor_model_parallel_size
     target_pp = args.pipeline_model_parallel_size
+    target_ep = args.expert_model_parallel_size
     state_dict, _, _ = _load_base_checkpoint(args.load, rank0=True)
     args.iteration = state_dict['iteration']
     checkpoint_args = state_dict['args']
     checkpoint_tp = checkpoint_args.tensor_model_parallel_size
     checkpoint_pp = checkpoint_args.pipeline_model_parallel_size
-    if target_tp != checkpoint_tp or target_pp != checkpoint_pp:
+    checkpoint_ep = checkpoint_args.expert_model_parallel_size
+    if target_tp != checkpoint_tp or target_pp != checkpoint_pp or target_ep != checkpoint_ep:
         script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tools/megatron_checkpoint_utils.py")
         save_dir = args.load[:-1] if args.load.endswith("/") else args.load
-        save_dir = save_dir + f"-transform-tp{target_tp}-pp{target_pp}"
+        if target_ep is None or target_ep == 1:
+            save_dir = save_dir + f"-transform-tp{target_tp}-pp{target_pp}"
+        else:
+            save_dir = save_dir + f"-transform_tp{target_tp}-pp{target_pp}-ep{target_ep}"
         if not os.path.exists(save_dir):
             # use the last rank so we can determine model_type by whether the last pipeline stage contains pooler_head
             if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1):
diff --git a/examples/megatron/configs/mixtral/base.yaml b/examples/megatron/configs/mixtral/base.yaml
new file mode 100644
index 00000000..c1f0ae8f
--- /dev/null
+++ b/examples/megatron/configs/mixtral/base.yaml
@@ -0,0 +1,83 @@
+# mixtral-8x7b config
+add_position_embedding: False +use_rotary_position_embeddings: True +untie_embeddings_and_output_weights: True +tokenizer_type: Llama2Tokenizer +exit_on_missing_checkpoint: True +normalization: RMSNorm +masked_softmax_fusion: False +apply_query_key_layer_scaling: False +use_checkpoint_args: False +add_bias_linear: False +swiglu: True +attention_softmax_in_fp32: True +transformer_impl: transformer_engine +bf16: True + + +trainer_engine: ${trainer_engine:rlhf} +init_shuffle_prompts: ${init_shuffle_prompts:0} +# dpo loss +use_ipo: ${use_ipo:False} +dpo_weight: ${dpo_weight:0.1} + +train_to_compare_num_responses: ${train_to_compare_num_responses:1} +num_inference_per_prompt: ${num_inference_per_prompt:1} +tokenizer_model: ${tokenizer_model} +max_position_embeddings: ${max_position_embedding:4096} +seq_length: ${seq_length:1024} +fix_kl_coef: ${fix_kl_coef:True} +log_dir: ${log_dir} +exp_name: ${exp_name:test} +tensorboard_dir: ${tensorboard_dir} +loss_on_prompts: ${loss_on_prompts:False} +numerical_stable: True + +build_path: ${build_path:build} + + +init_kl_coef: ${init_kl_coef:0.02} +target: 6 +horizon: 10000 +gamma: 1 +lam: 0.95 +cliprange: 0.2 +cliprange_value: ${cliprange_value:0.1} +scale_reward: "None" + +cliprange_reward: 100 + +max_new_tokens: ${max_new_tokens:512} + + +ngram_coef: ${ngram_coef:1} +lm_coef: ${lm_coef:0} +math_coef: ${math_coef:0} +raw_reward_coeff: ${raw_reward_coeff:1} + +clipped_value_only: ${clipped_value_only:1} +finetune: True + + +save: ${save_dir} +save_interval: 1000 +gradient_accumulation_fusion: 0 +max_tokens_to_oom: 99999999 + + +hysteresis: 2 +use_flash_attn: 1 +do_math_eval: 0 +log_entropy: False +adaptive_parallel_strategy_on_checkpoint: True +log_interval: 1 +distributed_timeout_minutes: 30 +make_vocab_size_divisible_by: 32 +use_legacy_models: ${use_legacy_models:False} # disable legacy mode for MoE models +use_dist_ckpt: ${use_dist_ckpt:True} # use dist_ckpt for MoE models + +# MoE +moe_router_load_balancing_type: ${moe_router_load_balancing_type:"aux_loss"} +moe_aux_loss_coeff: ${moe_aux_loss_coeff:1e-2} +moe_grouped_gemm: ${moe_grouped_gemm:True} +moe_token_dispatcher_type: ${moe_token_dispatcher_type:"alltoall"} diff --git a/examples/megatron/configs/mixtral/base_inference.yaml b/examples/megatron/configs/mixtral/base_inference.yaml new file mode 100644 index 00000000..a266dd7e --- /dev/null +++ b/examples/megatron/configs/mixtral/base_inference.yaml @@ -0,0 +1,16 @@ +includes: + - base.yaml + + +temperature: 1.0 +seed: 42 +no_load_optim: True +no_load_rng: True +no_load_args: True +no_load_scheduler: True +log_num_zeros_in_grad: True +attention_dropout: 0.0 +hidden_dropout: 0.0 +retro_encoder_attention_dropout: 0.0 +retro_encoder_hidden_dropout: 0.0 +inference_batch_times_seqlen_threshold: ${inference_batch_times_seqlen_threshold:4096} \ No newline at end of file diff --git a/examples/megatron/configs/mixtral/base_train.yaml b/examples/megatron/configs/mixtral/base_train.yaml new file mode 100644 index 00000000..3c17b5ad --- /dev/null +++ b/examples/megatron/configs/mixtral/base_train.yaml @@ -0,0 +1,10 @@ +includes: + - base.yaml + + +distributed_backend: nccl +train_iters: 12000 + +clip_grad: ${clip_grad:0.5} +log_interval: 1 +log_num_zeros_in_grad: True diff --git a/examples/megatron/configs/mixtral/dpo.yaml b/examples/megatron/configs/mixtral/dpo.yaml new file mode 100644 index 00000000..f6c751fe --- /dev/null +++ b/examples/megatron/configs/mixtral/dpo.yaml @@ -0,0 +1,46 @@ +runtime_env: + platform: DLC + excludes: + - "*pt" + - "logs" + - 
"tensorboards" + - ".nfs*" + + +models: + reference: + model_config_file: reference.yaml + num_gpu: ${num_gpu_ref:16} + trainable: False + generation_batch_size: ${ref_generation_batch_size:4} + free_memory: ${free_memory_reference:False} + + ppo_policy: + model_config_file: ppo_policy.yaml + num_gpu: ${num_gpu_ppo_policy:16} + trainable: True + lora: + enable_lora: ${enable_lora_policy:False} + lora_dim: 64 + lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear + column_only_qkv: False + lora_dropout: 0.05 + free_memory: ${free_memory_ppo_policy:False} + +runtime: + colocation: + - ppo_policy,reference + train_micro_batch_size: ${train_micro_batch_size:2} + train_global_batch_size: ${train_global_batch_size:512} + num_episode: ${num_episode:100} + sample_per_episode: ${sample_per_episode:1024} + num_training_epoch: 1 + save_episode_interval: ${save_episode_interval:100} + data_path: ${data_path} + training_data_num_limit: ${training_data_num_limit:-1} + eval_data_num_limit: ${eval_data_num_limit:128} + eval_episode_interval: ${eval_episode_interval:100} + data_checkpoint_path: ${data_checkpoint_path} + output_dir: ${output_dir} + free_sync_collective_group: ${free_sync_collective_group:False} + exp_name: ${exp_name:chatlearn} diff --git a/examples/megatron/configs/mixtral/old_policy_inference.yaml b/examples/megatron/configs/mixtral/old_policy_inference.yaml new file mode 100644 index 00000000..9faa42ac --- /dev/null +++ b/examples/megatron/configs/mixtral/old_policy_inference.yaml @@ -0,0 +1,15 @@ +includes: + - base_inference.yaml + - policy_shared.yaml + + +top_p: ${policy_top_p:0.9} +top_k: ${policy_top_k:0} +temperature: ${policy_temperature:1.0} + +eval_temperature: 0.01 +use_attn_acc: ${use_attn_acc:False} +eval_top_k: 1 +eval_top_p: 0 + +pipeline_model_parallel_size: ${policy_pp:1} diff --git a/examples/megatron/configs/mixtral/old_value_inference.yaml b/examples/megatron/configs/mixtral/old_value_inference.yaml new file mode 100644 index 00000000..54fc4dfc --- /dev/null +++ b/examples/megatron/configs/mixtral/old_value_inference.yaml @@ -0,0 +1,5 @@ +includes: + - base_inference.yaml + - reward_shared.yaml + +pipeline_model_parallel_size: ${value_pp:1} diff --git a/examples/megatron/configs/mixtral/online_dpo.yaml b/examples/megatron/configs/mixtral/online_dpo.yaml new file mode 100644 index 00000000..d818c10f --- /dev/null +++ b/examples/megatron/configs/mixtral/online_dpo.yaml @@ -0,0 +1,62 @@ +runtime_env: + platform: DLC + excludes: + - "*pt" + - "logs" + - "tensorboards" + - ".nfs*" + + +models: + policy: + model_config_file: old_policy_inference.yaml + num_gpu: ${num_gpu_policy:16} + trainable: False + batch_generation: + ranking: ${batch_generation_ranking:False} + min_prompt_length: ${batch_generation_min_prompt_length:0} + free_memory: ${free_memory_policy:False} + + reference: + model_config_file: reference.yaml + num_gpu: ${num_gpu_ref:16} + trainable: False + generation_batch_size: ${ref_generation_batch_size:4} + free_memory: ${free_memory_reference:False} + + reward: + model_config_file: reward_inference.yaml + num_gpu: ${num_gpu_reward:16} + trainable: False + free_memory: ${free_memory_reward:False} + + ppo_policy: + model_config_file: ppo_policy.yaml + num_gpu: ${num_gpu_ppo_policy:16} + trainable: True + lora: + enable_lora: ${enable_lora_policy:False} + lora_dim: 64 + lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear + column_only_qkv: False + lora_dropout: 0.05 + free_memory: ${free_memory_ppo_policy:False} + +runtime: + colocation: + - 
policy,ppo_policy,reward,reference + generation_batch_size: ${generation_batch_size:4} + train_micro_batch_size: ${train_micro_batch_size:2} + train_global_batch_size: ${train_global_batch_size:512} + num_episode: ${num_episode:100} + sample_per_episode: ${sample_per_episode:1024} + num_training_epoch: 1 + save_episode_interval: ${save_episode_interval:100} + data_path: ${data_path} + training_data_num_limit: ${training_data_num_limit:-1} + eval_data_num_limit: ${eval_data_num_limit:128} + eval_episode_interval: ${eval_episode_interval:100} + data_checkpoint_path: ${data_checkpoint_path} + output_dir: ${output_dir} + free_sync_collective_group: ${free_sync_collective_group:False} + exp_name: ${exp_name:chatlearn} diff --git a/examples/megatron/configs/mixtral/policy_shared.yaml b/examples/megatron/configs/mixtral/policy_shared.yaml new file mode 100644 index 00000000..09436f59 --- /dev/null +++ b/examples/megatron/configs/mixtral/policy_shared.yaml @@ -0,0 +1,14 @@ +load: ${policy_inference_load} +load_iteration: ${policy_load_iteration} +num_layers: ${policy_num_layers} +hidden_size: ${policy_hidden_size} +num_attention_heads: ${policy_num_attention_heads} +ffn_hidden_size: ${policy_ffn_hidden_size} +num_experts: ${policy_num_experts} +moe_router_topk: ${policy_moe_router_topk} +tensor_model_parallel_size: ${policy_tp:1} +expert_model_parallel_size: ${policy_ep:8} +group_query_attention: ${group_query_attention:True} +num_query_groups: ${policy_num_query_groups} +use_distributed_optimizer: True + diff --git a/examples/megatron/configs/mixtral/ppo_policy.yaml b/examples/megatron/configs/mixtral/ppo_policy.yaml new file mode 100644 index 00000000..ebb30b01 --- /dev/null +++ b/examples/megatron/configs/mixtral/ppo_policy.yaml @@ -0,0 +1,39 @@ +includes: + - base_train.yaml + - policy_shared.yaml + + +bf16: True +use_checkpoint_opt_param_scheduler: False +adam_beta1: 0.9 +adam_beta2: 0.95 +num_workers: 8 +init_method_std: 0.006 + +# dropout +attention_dropout: ${attention_dropout:0.1} +hidden_dropout: ${hidden_dropout:0.1} +retro_encoder_hidden_dropout: ${retro_encoder_hidden_dropout:0.1} +retro_encoder_attention_dropout: ${retro_encoder_attention_dropout:0.1} + +recompute_granularity: selective + +log_num_zeros_in_grad: True +no_load_optim: True +no_load_rng: True +no_load_args: True +no_load_scheduler: True + + +lr_decay_iters: 12000 +lr_warmup_iters: 100 +lr: ${policy_lr:2.4e-7} +min_lr: ${policy_min_lr:1e-9} +lr_decay_style: ${policy_lr_decay_style:linear} +weight_decay: 0.01 +pipeline_model_parallel_size: ${ppo_policy_pp:1} +sequence_parallel: ${sequence_parallel:True} + +recompute_activations: ${policy_recompute_activations:False} +recompute_granularity: ${policy_recompute_granularity:None} +moe_layer_recompute: ${policy_moe_layer_recompute:False} \ No newline at end of file diff --git a/examples/megatron/configs/mixtral/ppo_value.yaml b/examples/megatron/configs/mixtral/ppo_value.yaml new file mode 100644 index 00000000..d524d30f --- /dev/null +++ b/examples/megatron/configs/mixtral/ppo_value.yaml @@ -0,0 +1,30 @@ +includes: + - base_train.yaml + - reward_shared.yaml + +pipeline_model_parallel_size: ${ppo_value_pp:1} +lr_decay_iters: 12000 +lr_warmup_iters: 100 +lr: ${value_lr:5e-6} +min_lr: ${value_min_lr:5e-7} +lr_decay_style: ${value_lr_decay_style:linear} +weight_decay: 0.01 +log_interval: 1 + +use_checkpoint_opt_param_scheduler: False +adam_beta1: 0.9 +adam_beta2: 0.95 +num_workers: 8 +init_method_std: 0.006 + +recompute_granularity: selective + +no_load_optim: True +no_load_rng: 
True
+no_load_args: True
+no_load_scheduler: True
+sequence_parallel: True
+
+recompute_activations: ${value_recompute_activations:False}
+recompute_granularity: ${value_recompute_granularity:None}
+moe_layer_recompute: ${value_moe_layer_recompute:False}
\ No newline at end of file
diff --git a/examples/megatron/configs/mixtral/reference.yaml b/examples/megatron/configs/mixtral/reference.yaml
new file mode 100644
index 00000000..96cb77a2
--- /dev/null
+++ b/examples/megatron/configs/mixtral/reference.yaml
@@ -0,0 +1,6 @@
+includes:
+  - base_inference.yaml
+  - policy_shared.yaml
+
+parallel_output: True
+pipeline_model_parallel_size: ${ref_pp:1}
diff --git a/examples/megatron/configs/mixtral/reward_inference.yaml b/examples/megatron/configs/mixtral/reward_inference.yaml
new file mode 100644
index 00000000..31e15f8b
--- /dev/null
+++ b/examples/megatron/configs/mixtral/reward_inference.yaml
@@ -0,0 +1,6 @@
+includes:
+  - base_inference.yaml
+  - reward_shared.yaml
+
+reward_bias: 0
+pipeline_model_parallel_size: ${reward_pp:1}
diff --git a/examples/megatron/configs/mixtral/reward_shared.yaml b/examples/megatron/configs/mixtral/reward_shared.yaml
new file mode 100644
index 00000000..325e45eb
--- /dev/null
+++ b/examples/megatron/configs/mixtral/reward_shared.yaml
@@ -0,0 +1,18 @@
+load: ${reward_load}
+load_iteration: ${reward_load_iteration}
+# enabling use_distributed_optimizer raises an error for reward/value, so disable it temporarily
+use_distributed_optimizer: False
+
+num_layers: ${reward_num_layers}
+hidden_size: ${reward_hidden_size}
+num_attention_heads: ${reward_num_attention_heads}
+ffn_hidden_size: ${reward_ffn_hidden_size}
+num_experts: ${reward_num_experts}
+moe_router_topk: ${reward_moe_router_topk}
+tensor_model_parallel_size: ${reward_tp:1}
+expert_model_parallel_size: ${reward_ep:8}
+group_query_attention: ${group_query_attention:True}
+num_query_groups: ${reward_num_query_groups}
+
+save_inference: True
+save_inference_interval: 10
\ No newline at end of file
diff --git a/examples/megatron/configs/mixtral/rlhf.yaml b/examples/megatron/configs/mixtral/rlhf.yaml
new file mode 100644
index 00000000..5a693881
--- /dev/null
+++ b/examples/megatron/configs/mixtral/rlhf.yaml
@@ -0,0 +1,81 @@
+runtime_env:
+  platform: DLC
+  excludes:
+    - "*pt"
+    - "logs"
+    - "tensorboards"
+    - ".nfs*"
+
+
+models:
+  policy:
+    model_config_file: old_policy_inference.yaml
+    num_gpu: ${num_gpu_policy:16}
+    trainable: False
+    batch_generation:
+      ranking: ${batch_generation_ranking:False}
+      min_prompt_length: ${batch_generation_min_prompt_length:0}
+    free_memory: ${free_memory_policy:False}
+
+  reference:
+    model_config_file: reference.yaml
+    num_gpu: ${num_gpu_ref:16}
+    trainable: False
+    generation_batch_size: ${ref_generation_batch_size:4}
+    free_memory: ${free_memory_reference:False}
+
+  reward:
+    model_config_file: reward_inference.yaml
+    num_gpu: ${num_gpu_reward:16}
+    trainable: False
+    free_memory: ${free_memory_reward:False}
+
+  value:
+    model_config_file: old_value_inference.yaml
+    num_gpu: ${num_gpu_value:16}
+    trainable: False
+    free_memory: ${free_memory_value:False}
+
+  ppo_policy:
+    model_config_file: ppo_policy.yaml
+    num_gpu: ${num_gpu_ppo_policy:16}
+    trainable: True
+    lora:
+      enable_lora: ${enable_lora_policy:False}
+      lora_dim: 64
+      lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
+      column_only_qkv: False
+      lora_dropout: 0.05
+    free_memory: ${free_memory_ppo_policy:False}
+
+  ppo_value:
+    model_config_file: ppo_value.yaml
+    num_gpu: ${num_gpu_ppo_value:16}
+    trainable: True
+    lora:
+      enable_lora: ${enable_lora_value:False}
+      lora_dim: 64
+      lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear
+      column_only_qkv: False
+      lora_dropout: 0.05
+    free_memory: ${free_memory_ppo_value:False}
+
+runtime:
+  colocation:
+    - policy,ppo_policy,reward,reference,value,ppo_value
+  generation_batch_size: ${generation_batch_size:4}
+  train_micro_batch_size: ${train_micro_batch_size:2}
+  train_global_batch_size: ${train_global_batch_size:512}
+  num_episode: ${num_episode:100}
+  sample_per_episode: ${sample_per_episode:1024}
+  num_training_epoch: 1
+  save_episode_interval: ${save_episode_interval:100}
+  data_path: ${data_path}
+  eval_data_path: ${eval_data_path}
+  training_data_num_limit: ${training_data_num_limit:-1}
+  eval_data_num_limit: ${eval_data_num_limit:128}
+  eval_episode_interval: ${eval_episode_interval:100}
+  data_checkpoint_path: ${data_checkpoint_path}
+  free_sync_collective_group: ${free_sync_collective_group:False}
+  exp_name: ${exp_name:chatlearn}
+  output_dir: ${output_dir}
diff --git a/examples/megatron/data/sft_dataset.py b/examples/megatron/data/sft_dataset.py
index abc65c5d..df1c7ef1 100644
--- a/examples/megatron/data/sft_dataset.py
+++ b/examples/megatron/data/sft_dataset.py
@@ -57,6 +57,10 @@ def __init__(self, data_path, max_seq_length):
         self.pad_id = self.tokenizer.pad_token_id if hasattr(self.tokenizer, 'pad_token_id') else self.tokenizer.pad_id
         self.bos_id = self.tokenizer.bos_token_id if hasattr(self.tokenizer, 'bos_token_id') else self.tokenizer.bos_id
         self.eos_id = self.tokenizer.eod
+        # The pad_id for Llama2Tokenizer is -1. It will cause out-of-bound index for embedding calculation.
+        # Thus, we hardcode it as 0 here.
+        if self.pad_id == -1:
+            self.pad_id = 0

     def __len__(self):
         return len(self.dataset)
diff --git a/examples/megatron/models/policy_model.py b/examples/megatron/models/policy_model.py
index bebed6a3..5bdc622d 100644
--- a/examples/megatron/models/policy_model.py
+++ b/examples/megatron/models/policy_model.py
@@ -18,13 +18,12 @@
 from megatron.training import get_args
 from megatron.core import tensor_parallel
-from megatron.training.arguments import core_transformer_config_from_args
 from megatron.training.global_vars import get_tokenizer
 from megatron.legacy.model.gpt_model import GPTModel
 from megatron.legacy.model.language_model import parallel_lm_logits

 from chatlearn.models.megatron.ops.policy_gradient import tensor_decomp_pg_loss
-from .utils import get_advantages_and_returns, get_eos_id
+from .utils import get_advantages_and_returns, has_config_in_args, get_eos_id
 from .constants import TrainerEngine
 from .constants import select_actions_from_right_padded
@@ -40,9 +39,13 @@ def __init__(self,
                  post_process=True,
                  stats=None):
         self.args = get_args()
-        config = core_transformer_config_from_args(self.args)
-        super().__init__(config, num_tokentypes, parallel_output, pre_process, post_process)
-
+        if has_config_in_args(GPTModel):
+            # new API
+            from megatron.training.arguments import core_transformer_config_from_args # pylint: disable=import-outside-toplevel
+            config = core_transformer_config_from_args(self.args)
+            super().__init__(config, num_tokentypes, parallel_output, pre_process, post_process)
+        else:
+            super().__init__(num_tokentypes, parallel_output, pre_process, post_process)
         self.tokenizer = get_tokenizer()
         self.stats = stats
diff --git a/examples/megatron/models/reward_inference.py b/examples/megatron/models/reward_inference.py
index 3d3385b9..b50dcfb6 100644
--- a/examples/megatron/models/reward_inference.py
+++ 
b/examples/megatron/models/reward_inference.py @@ -34,7 +34,6 @@ from chatlearn import MegatronModule from chatlearn.utils import to_device from chatlearn.utils.megatron_utils import load_checkpoint -from examples.megatron.data.reward_dataset import preprocess from .reward_model import RewardModel as LegacyRewardModel from .mcore_reward_model import MCoreRewardModel from .utils import tensorboard_scalar_dict, get_eos_id @@ -120,11 +119,7 @@ def setup(self): model = get_model(model_provider, wrap_with_ddp=False) if args.load: - torch.distributed.barrier() load_checkpoint(model, None, None, adaptive_parallel_strategy=args.adaptive_parallel_strategy_on_checkpoint) - torch.distributed.barrier() - else: - print_rank_0(f"Warning: Using random parameter for {self.name} model.") assert len(model) == 1, "Above condition should have caught this" self.model = model[0] diff --git a/examples/megatron/scripts/base_env.sh b/examples/megatron/scripts/base_env.sh index aa1bccce..7cb2079d 100644 --- a/examples/megatron/scripts/base_env.sh +++ b/examples/megatron/scripts/base_env.sh @@ -130,6 +130,23 @@ elif [[ "$model_size" == "llama2-70B" ]]; then export reward_ffn_hidden_size=28672 export reward_num_query_groups=8 export group_query_attention=True +elif [[ "$model_size" == "mixtral-8x7B" ]]; then + export policy_num_layers=32 + export policy_hidden_size=4096 + export policy_num_attention_heads=32 + export policy_num_query_groups=8 + export policy_ffn_hidden_size=14336 + export policy_num_experts=8 + export policy_moe_router_topk=2 + export reward_num_layers=32 + export reward_hidden_size=4096 + export reward_num_attention_heads=32 + export reward_num_query_groups=8 + export reward_ffn_hidden_size=14336 + export reward_num_experts=8 + export reward_moe_router_topk=2 + export max_position_embedding=32768 + export seq_length=2048 else echo "unsupported model_size ${model_size}, please set your own model config" exit 1 diff --git a/examples/megatron/scripts/convert_hf_to_megatron.sh b/examples/megatron/scripts/convert_hf_to_megatron.sh index 4127c6f2..26c53e9e 100644 --- a/examples/megatron/scripts/convert_hf_to_megatron.sh +++ b/examples/megatron/scripts/convert_hf_to_megatron.sh @@ -19,7 +19,11 @@ megatron=${MEGATRON} load_dir=${LOAD_PATH} save_dir=${SAVE_PATH} tokenizer_model=${TOKENIZER_MODEL} -model_size=${model_size:-llama2-7B} +if [[ $model == 'gpt_llama' ]]; then + model_size=${MODEL_SIZE:-llama2-7B} +elif [[ $model == 'mixtral' ]]; then + model_size=${MODEL_SIZE:-mixtral-8x7B} +fi export CUDA_DEVICE_MAX_CONNECTIONS=1 diff --git a/examples/megatron/scripts/convert_megatron_to_hf.sh b/examples/megatron/scripts/convert_megatron_to_hf.sh index e18c8cd4..63ad7f65 100644 --- a/examples/megatron/scripts/convert_megatron_to_hf.sh +++ b/examples/megatron/scripts/convert_megatron_to_hf.sh @@ -1,8 +1,9 @@ #!/bin/bash # Convert LLaMA model from megatron format to huggingface format. 
set -ex +set pipefail -# config +# path config chatlearn=${CHATLEARN} megatron=${MEGATRON} load_path=${LOAD_PATH} @@ -11,13 +12,29 @@ vocab_path=${VOCAB_PATH} target_params_dtype=${target_params_dtype:-bf16} temp_path=${save_path}/temp +# model config +# can be `gpt_llama' for GPT or Llama, or `mixtral' for Mixtral +model=${MODEL:-'gpt_llama'} + # Whether to use legacy models, default: True use_legacy_models=${USE_LEGACY_MODELS:-"True"} - -if [[ ${use_legacy_models} = "False" ]]; then - ckpt_format="mcore" +if [[ ${use_legacy_models} == "False" ]]; then + if [[ ${model} == 'gpt_llama' ]]; then + # TODO: migrate to mcore + loader_ckpt_format="mcore" + saver_ckpt_format="megatron" + elif [[ ${model} == 'mixtral' ]]; then + loader_ckpt_format="mcore_mixtral" + saver_ckpt_format="mcore" + else + echo -e "\033[31m Unrecognized model ${model} \033[0m" + exit -1 + fi + MCORE_ARGS="" else - ckpt_format="megatron" + loader_ckpt_format="megatron" + saver_ckpt_format="megatron" + MCORE_ARGS="--use_legacy_models" fi set +x @@ -25,18 +42,40 @@ set +x # convert parallel strategy START_TIME=$SECONDS -cd ${megatron} - -if [[ ! -d "${temp_path}" ]]; then - python tools/checkpoint/convert.py \ - --model-type GPT \ - --loader ${ckpt_format} \ - --saver "megatron" \ - --target-tensor-parallel-size 1 \ - --target-pipeline-parallel-size 1 \ - --load-dir ${load_path} \ - --save-dir ${temp_path} \ - --megatron-path ${megatron} +if [[ ! -d ${temp_path} ]]; then + if [[ ${model} == 'gpt_llama' ]]; then + cd ${megatron} + python tools/checkpoint/convert.py \ + --model-type GPT \ + --loader ${loader_ckpt_format} \ + --saver ${saver_ckpt_format} \ + --target-tensor-parallel-size 1 \ + --target-pipeline-parallel-size 1 \ + --load-dir ${load_path} \ + --save-dir ${temp_path} \ + --megatron-path ${megatron} + elif [[ ${model} == 'mixtral' ]]; then + cd ${chatlearn} + export PYTHONPATH=${chatlearn}:${megatron}:${megatron}/tools/checkpoint:${PYTHONPATH} + python chatlearn/tools/convert.py \ + --model-type GPT \ + --loader ${loader_ckpt_format} \ + --saver-prefix tools.checkpoint.saver \ + --saver ${saver_ckpt_format} \ + --target-tensor-parallel-size 1 \ + --target-pipeline-parallel-size 1 \ + --target-expert-parallel-size 1 \ + --load-dir ${load_path} \ + --save-dir ${temp_path} \ + --megatron-path ${megatron} + else + echo -e "\033[31m Unrecognized model ${model} \033[0m" + exit -1 + fi +fi + +if [[ $? != 0 ]]; then + exit $? 
fi # convert to hf format @@ -46,7 +85,8 @@ python chatlearn/tools/megatron_to_hf.py \ --save_path ${save_path} \ --target_params_dtype ${target_params_dtype} \ --vocab_dir ${vocab_path} \ - --megatron_path ${megatron} + --megatron_path ${megatron} \ + ${MCORE_ARGS} # clear temp path rm -r $temp_path diff --git a/examples/megatron/scripts/train_dpo_mixtral.sh b/examples/megatron/scripts/train_dpo_mixtral.sh new file mode 100644 index 00000000..2d81a268 --- /dev/null +++ b/examples/megatron/scripts/train_dpo_mixtral.sh @@ -0,0 +1,53 @@ +#!/bin/bash +set -x + +[ -z "$model_size" ] && export model_size=mixtral-8x7B + +# Get the directory of the current script +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +source ${DIR}/base_env.sh + +export trainer_engine=dpo + +# clip +export clip_grad=5.0 + +# desable dropout +export attention_dropout=0.0 +export hidden_dropout=0.0 +export retro_encoder_hidden_dropout=0.0 +export retro_encoder_attention_dropout=0.0 + + +if [[ "$model_size" == "mixtral-8x7B" ]]; then + export policy_tp=1 + export policy_ep=8 + export ppo_policy_pp=4 + export ref_pp=4 + export train_global_batch_size=128 + export ref_generation_batch_size=2 + export train_micro_batch_size=1 + export policy_recompute_activations=True + export policy_moe_layer_recompute=True +fi + +configs=$CHATLEARN/examples/megatron/configs/mixtral/dpo.yaml + +[ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} +[ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ +[ -z "$sample_per_episode" ] && sample_per_episode=1024 + +output_dir=${output_dir}/${exp_name} +export data_checkpoint_path=${output_dir}/data_checkpoint +mkdir -p $output_dir +log_file=${output_dir}/log_${RANK}.log + +policy_inference_load=${POLICY_LOAD} \ +tokenizer_model=${TOKENIZER_MODEL} \ +num_gpu=${num_gpu} \ +data_path=${DATASET_PATH} \ +sample_per_episode=${sample_per_episode} \ +python entry/train_dpo.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} + + diff --git a/examples/megatron/scripts/train_online_dpo_mixtral.sh b/examples/megatron/scripts/train_online_dpo_mixtral.sh new file mode 100644 index 00000000..a8322190 --- /dev/null +++ b/examples/megatron/scripts/train_online_dpo_mixtral.sh @@ -0,0 +1,78 @@ +#!/bin/bash +set -x + +[ -z "$model_size" ] && export model_size=mixtral-8x7B + +# Get the directory of the current script +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +source ${DIR}/base_env.sh + +# megatron +# TODO: support vllm +backend=${1:-megatron} + +if [[ "$backend" != "megatron" ]]; then + echo "ERROR: expect megatron backend, while "$backend + exit 1 +fi + + +export trainer_engine=online_dpo + +export train_to_compare_num_responses=8 +export num_inference_per_prompt=8 + + +if [[ "$model_size" == "mixtral-8x7B" ]]; then + export policy_tp=1 + export policy_ep=8 + export ppo_policy_pp=4 + export reward_tp=1 + export reward_ep=8 + export ppo_value_pp=4 + export ref_pp=4 + export policy_pp=4 + export reward_pp=4 + export value_pp=4 + export train_global_batch_size=128 + export generation_batch_size=128 + export ref_generation_batch_size=2 + export train_micro_batch_size=1 + export policy_recompute_activations=True + export policy_moe_layer_recompute=True + export value_recompute_activations=True + export value_moe_layer_recompute=True +fi + +if [[ "$backend" == "megatron" ]]; then + configs=$CHATLEARN/examples/megatron/configs/mixtral/online_dpo.yaml +else + export ENABLE_VLLM=True + if [ -z "$tokenizer_load" ];then + echo "please set path to hf 
tokenizer for vllm backend, download from huggingface source." + exit 1 + fi + configs=$CHATLEARN/examples/megatron/configs/mixtral/online_dpo_vllm.yaml +fi + +[ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} +[ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ +[ -z "$sample_per_episode" ] && sample_per_episode=1024 + +output_dir=$output_dir/$exp_name +export data_checkpoint_path=${output_dir}/data_checkpoint + +mkdir -p ${output_dir} +log_file=${output_dir}/log_${RANK}.log + +policy_inference_load=${POLICY_LOAD} \ +reward_load_iteration=${REWARD_LOAD_ITERATION} \ +reward_load=${REWARD_LOAD} \ +tokenizer_model=${TOKENIZER_MODEL} \ +num_gpu=${num_gpu} \ +data_path=${DATASET_PATH} \ +eval_data_path=${EVAL_DATASET_PATH} \ +sample_per_episode=${sample_per_episode} \ +python entry/train_online_dpo.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} + diff --git a/examples/megatron/scripts/train_reward_mixtral.sh b/examples/megatron/scripts/train_reward_mixtral.sh new file mode 100644 index 00000000..d311f6de --- /dev/null +++ b/examples/megatron/scripts/train_reward_mixtral.sh @@ -0,0 +1,163 @@ +#!/bin/bash +set -x + +[ -z "$MASTER_ADDR" ] && export MASTER_ADDR=localhost +[ -z "$WORLD_SIZE" ] && export WORLD_SIZE=1 +[ -z "$GPUS_PER_NODE" ] && export GPUS_PER_NODE=8 +[ -z "$RANK" ] && export RANK=0 +[ -z "$MASTER_PORT" ] && export MASTER_PORT=12456 + + +# check the path +[[ -z "${MEGATRON}" ]] && { echo "MEGATRON path is not set"; exit 1; } +[[ -z "${CHATLEARN}" ]] && { echo "CHATLEARN path is not set"; exit 1; } +[[ -z "${LOAD_PATH}" ]] && { echo "LOAD_PATH is not set"; exit 1; } +[[ -z "${TOKENIZER_MODEL}" ]] && { echo "TOKENIZER_MODEL is not set"; exit 1; } +[[ -z "${DATASET_PATH}" ]] && { echo "DATASET_PATH is not set"; exit 1; } + + +export PYTHONPATH=${PYTHONPATH}:${MEGATRON}:${CHATLEARN}:${CHATLEARN}/examples/megatron + +DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE} \ + --nnodes ${WORLD_SIZE} \ + --node_rank ${RANK} \ + --master_addr ${MASTER_ADDR} \ + --master_port ${MASTER_PORT}" + +[ -z "$MODEL_SIZE" ] && export MODEL_SIZE="8x7B" + +if [ $MODEL_SIZE == "8x7B" ]; then + NUM_LAYERS=32 + HIDDEN_SIZE=4096 + NUM_ATTN_HEADS=32 + FFN_HIDDEN_SIZE=14336 + MAX_POSITION_EMBEDDINGS=32768 + NUM_QUERY_GROUPS=8 + NUM_EXPERTS=8 + MOE_ROUTER_TOPK=2 + seq_len=2048 + tp=1 + pp=4 + ep=8 + mb=1 + gbs=64 +else + echo "Unrecognized MODEL_SIZE ${MODEL_SIZE}, choose from '8x7B'." 
+ exit -1 +fi + +DIR=$(pwd) +DATETIME=$(date +'date_%y-%m-%d_time_%H-%M-%S') +mkdir -p $DIR/logs + +NODE_RANK=$RANK +NNODES=$WORLD_SIZE + + +dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp / $ep)) +gbs=$(($gbs * $dp)) + + +[ -z "$CHECKPOINT_PATH" ] && CHECKPOINT_PATH=${CHATLEARN}/output/reward/mixtral_hh_reward_$(date +%F)_gpt_${MODEL_SIZE}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_ep${ep}_mb${mb}_seqlen${seq_len} + +mkdir -p $CHECKPOINT_PATH + +MODEL_ARGS=" +--disable-bias-linear \ +--seq-length $seq_len \ +--max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \ +--num-layers ${NUM_LAYERS} \ +--hidden-size ${HIDDEN_SIZE} \ +--ffn-hidden-size ${FFN_HIDDEN_SIZE} \ +--num-attention-heads ${NUM_ATTN_HEADS} \ +--init-method-std 0.006 \ +--attention-dropout 0.0 \ +--hidden-dropout 0.0 \ +--normalization RMSNorm \ +--position-embedding-type rope \ +--swiglu \ +--untie-embeddings-and-output-weights \ +--group-query-attention \ +--num-query-groups ${NUM_QUERY_GROUPS} \ +--no-masked-softmax-fusion \ +--no-position-embedding \ +--transformer-impl transformer_engine \ +--attention-softmax-in-fp32 " + +MOE_ARGS=" +--num-experts ${NUM_EXPERTS} \ +--moe-router-topk ${MOE_ROUTER_TOPK} \ +--moe-router-load-balancing-type aux_loss \ +--moe-aux-loss-coeff 1e-2 \ +--moe-token-dispatcher-type alltoall \ +--overlap-param-gather \ +--overlap-grad-reduce \ +--moe-layer-recompute" + +DATA_ARGS=" +--tokenizer-type Llama2Tokenizer \ +--tokenizer-model ${TOKENIZER_MODEL} \ +--data-path $DATASET_PATH/train.jsonl $DATASET_PATH/train.jsonl $DATASET_PATH/train.jsonl \ +--split 98,2,0 \ +--dataloader-type cyclic " + +TRAINING_ARGS=" +--micro-batch-size $mb \ +--global-batch-size $gbs \ +--lr 2.0e-5 \ +--train-iters 1000 \ +--lr-decay-iters 1000 \ +--lr-decay-style cosine \ +--min-lr 6.0e-12 \ +--weight-decay 0. 
\ +--lr-warmup-iters 40 \ +--clip-grad 1.0 \ +--bf16 \ +--exit-on-missing-checkpoint \ +--use-checkpoint-args \ +--adam-beta1 0.9 \ +--adam-beta2 0.999 \ +--use-flash-attn \ +--finetune \ +--recompute-activations" + +MODEL_PARALLEL_ARGS=" +--tensor-model-parallel-size $tp \ +--pipeline-model-parallel-size $pp \ +--expert-model-parallel-size $ep \ +--use-distributed-optimizer \ +--sequence-parallel \ +--distributed-timeout-minutes 60 \ +" + +LOGGING_ARGS=" +--log-interval 1 \ +--eval-iters 10 \ +--eval-interval 1000 \ +--save-interval 1000 \ +--save $CHECKPOINT_PATH \ +--load $LOAD_PATH \ +--tensorboard-dir $CHECKPOINT_PATH \ +--tensorboard-log-interval 100 \ +--num-workers 8 \ +--no-load-rng \ +--no-load-optim \ +--log-timers-to-tensorboard \ +--log-batch-size-to-tensorboard \ +--log-validation-ppl-to-tensorboard \ +" + +log_file=$CHECKPOINT_PATH/stderr_$NODE_RANK.log + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +cd ${CHATLEARN}/examples/megatron/alignment/reward + +torchrun $DISTRIBUTED_ARGS \ + finetune_reward.py \ + ${MODEL_ARGS[@]} \ + ${MOE_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${LOGGING_ARGS[@]} 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} diff --git a/examples/megatron/scripts/train_rlhf_mixtral.sh b/examples/megatron/scripts/train_rlhf_mixtral.sh new file mode 100644 index 00000000..2abf5595 --- /dev/null +++ b/examples/megatron/scripts/train_rlhf_mixtral.sh @@ -0,0 +1,85 @@ +#!/bin/bash +set -x + +[ -z "$model_size" ] && export model_size=mixtral-8x7B + +# Get the directory of the current script +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +source ${DIR}/base_env.sh + +# megatron +# TODO: support vllm +backend=${1:-megatron} +if [[ "$backend" != "megatron" ]]; then + echo "ERROR: expect megatron backend, while "$backend + exit 1 +fi + + +config_dir=${CHATLEARN}/examples/megatron/configs/ + +if [[ "$backend" == "megatron" ]]; then + configs=${config_dir}/mixtral/rlhf.yaml +else + export ENABLE_VLLM=True + if [ -z "$tokenizer_load" ];then + echo "please set path to hf tokenizer for vllm backend, download from huggingface source." 
+ exit 1 + fi + configs=${config_dir}/mixtral/vllm_rlhf.yaml +fi + +export trainer_engine=rlhf + +[ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} +[ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ +[ -z "$sample_per_episode" ] && sample_per_episode=1024 +[ -z "$tokenizer_load" ] && export tokenizer_load=path-to-hf-tokenizer-for-vllm-backend + +output_dir=${output_dir}/${exp_name} +export data_checkpoint_path=${output_dir}/data_checkpoint + + +if [[ "$model_size" == "mixtral-8x7B" ]]; then + export policy_tp=1 + export policy_ep=8 + export ppo_policy_pp=4 + export reward_tp=1 + export reward_ep=8 + export ppo_value_pp=4 + export ref_pp=4 + export policy_pp=4 + export reward_pp=4 + export value_pp=4 + export train_global_batch_size=128 + export generation_batch_size=1 + export ref_generation_batch_size=1 + export train_micro_batch_size=1 + export policy_recompute_activations=True + export policy_moe_layer_recompute=True + export value_recompute_activations=True + export value_moe_layer_recompute=True + export free_memory_policy=True + export free_memory_reference=True + export free_memory_reward=True + export free_memory_value=True + export free_memory_ppo_policy=True + export free_memory_ppo_value=True +fi + +mkdir -p ${output_dir} +log_file=${output_dir}/log_${RANK}.log +echo $log_file + +policy_inference_load=${POLICY_LOAD} \ +reward_load_iteration=${REWARD_LOAD_ITERATION} \ +reward_load=${REWARD_LOAD} \ +tokenizer_model=${TOKENIZER_MODEL} \ +num_gpu=${num_gpu} \ +data_path=${DATASET_PATH} \ +eval_data_path=${EVAL_DATASET_PATH} \ +sample_per_episode=${sample_per_episode} \ +python entry/train_rlhf.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} + + diff --git a/examples/megatron/scripts/train_sft_mixtral.sh b/examples/megatron/scripts/train_sft_mixtral.sh new file mode 100644 index 00000000..c5418bf9 --- /dev/null +++ b/examples/megatron/scripts/train_sft_mixtral.sh @@ -0,0 +1,159 @@ +#!/bin/bash +set -x + +[ -z "$MASTER_ADDR" ] && export MASTER_ADDR=localhost +[ -z "$WORLD_SIZE" ] && export WORLD_SIZE=1 +[ -z "$GPUS_PER_NODE" ] && export GPUS_PER_NODE=8 +[ -z "$RANK" ] && export RANK=0 +[ -z "$MASTER_PORT" ] && export MASTER_PORT=12456 + +DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE} \ + --nnodes ${WORLD_SIZE} \ + --node_rank ${RANK} \ + --master_addr ${MASTER_ADDR} \ + --master_port ${MASTER_PORT}" + +# check the path +[[ -z "${MEGATRON}" ]] && { echo "MEGATRON path is not set"; exit 1; } +[[ -z "${CHATLEARN}" ]] && { echo "CHATLEARN path is not set"; exit 1; } +[[ -z "${LOAD_PATH}" ]] && { echo "LOAD_PATH is not set"; exit 1; } +[[ -z "${TOKENIZER_MODEL}" ]] && { echo "TOKENIZER_MODEL is not set"; exit 1; } +[[ -z "${DATASET_PATH}" ]] && { echo "DATASET_PATH is not set"; exit 1; } + + +export PYTHONPATH=${PYTHONPATH}:${MEGATRON}:${CHATLEARN}/examples/megatron:${CHATLEARN} + +[ -z "$MODEL_SIZE" ] && export MODEL_SIZE="8x7B" + +if [ $MODEL_SIZE == "8x7B" ]; then + NUM_LAYERS=32 + HIDDEN_SIZE=4096 + NUM_ATTN_HEADS=32 + FFN_HIDDEN_SIZE=14336 + MAX_POSITION_EMBEDDINGS=32768 + NUM_QUERY_GROUPS=8 + NUM_EXPERTS=8 + MOE_ROUTER_TOPK=2 + seq_len=2048 + tp=1 + pp=4 + ep=8 + mb=1 + gbs=64 + echo "Unrecognized MODEL_SIZE ${MODEL_SIZE}, choose from '8x7B'." 
+ exit -1 +fi + +DIR=$(pwd) +DATETIME=$(date +'date_%y-%m-%d_time_%H-%M-%S') +mkdir -p $DIR/logs + +NODE_RANK=$RANK +NNODES=$WORLD_SIZE + + +dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp / $ep)) +gbs=$(($gbs * $dp)) + + +[ -z "$CHECKPOINT_PATH" ] && CHECKPOINT_PATH=${CHATLEARN}/output/sft/mixtral_hh_sft_$(date +%F)_gpt_${MODEL_SIZE}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_ep${ep}_mb${mb}_seqlen${seq_len} + +mkdir -p $CHECKPOINT_PATH + +MODEL_ARGS=" +--disable-bias-linear \ +--seq-length $seq_len \ +--max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \ +--num-layers ${NUM_LAYERS} \ +--hidden-size ${HIDDEN_SIZE} \ +--ffn-hidden-size ${FFN_HIDDEN_SIZE} \ +--num-attention-heads ${NUM_ATTN_HEADS} \ +--init-method-std 0.006 \ +--attention-dropout 0.0 \ +--hidden-dropout 0.0 \ +--normalization RMSNorm \ +--position-embedding-type rope \ +--swiglu \ +--untie-embeddings-and-output-weights \ +--group-query-attention \ +--num-query-groups ${NUM_QUERY_GROUPS} \ +--no-masked-softmax-fusion \ +--no-position-embedding \ +--transformer-impl transformer_engine \ +--attention-softmax-in-fp32 " + +MOE_ARGS=" +--num-experts ${NUM_EXPERTS} \ +--moe-router-topk ${MOE_ROUTER_TOPK} \ +--moe-router-load-balancing-type aux_loss \ +--moe-aux-loss-coeff 1e-2 \ +--moe-token-dispatcher-type alltoall \ +--overlap-param-gather \ +--overlap-grad-reduce " + +DATA_ARGS=" +--tokenizer-type Llama2Tokenizer \ +--tokenizer-model ${TOKENIZER_MODEL} \ +--data-path $DATASET_PATH/train.jsonl $DATASET_PATH/train.jsonl $DATASET_PATH/train.jsonl \ +--split 98,2,0 \ +--dataloader-type cyclic " + +TRAINING_ARGS=" +--micro-batch-size $mb \ +--global-batch-size $gbs \ +--lr 2.0e-5 \ +--train-iters 1000 \ +--lr-decay-iters 1000 \ +--lr-decay-style cosine \ +--min-lr 6.0e-12 \ +--weight-decay 0. 
\ +--lr-warmup-iters 40 \ +--clip-grad 1.0 \ +--bf16 \ +--exit-on-missing-checkpoint \ +--use-checkpoint-args \ +--adam-beta1 0.9 \ +--adam-beta2 0.999 \ +--use-flash-attn \ +--finetune " + +MODEL_PARALLEL_ARGS=" +--tensor-model-parallel-size $tp \ +--pipeline-model-parallel-size $pp \ +--expert-model-parallel-size $ep \ +--use-distributed-optimizer \ +--sequence-parallel \ +--distributed-timeout-minutes 60 \ +" + +LOGGING_ARGS=" +--log-interval 1 \ +--eval-iters 10 \ +--eval-interval 1000 \ +--save-interval 1000 \ +--save $CHECKPOINT_PATH \ +--load $LOAD_PATH \ +--tensorboard-dir $CHECKPOINT_PATH \ +--tensorboard-log-interval 100 \ +--num-workers 8 \ +--no-load-rng \ +--no-load-optim \ +--log-timers-to-tensorboard \ +--log-batch-size-to-tensorboard \ +--log-validation-ppl-to-tensorboard \ +" + +log_file=$CHECKPOINT_PATH/stderr_$NODE_RANK.log + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +cd ${CHATLEARN}/examples/megatron/sft + +torchrun $DISTRIBUTED_ARGS \ + finetune_sft.py \ + ${MODEL_ARGS[@]} \ + ${MOE_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${LOGGING_ARGS[@]} 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} diff --git a/examples/megatron/tests/run_policy_generation.sh b/examples/megatron/tests/run_policy_generation.sh index ae55720d..f88ed450 100644 --- a/examples/megatron/tests/run_policy_generation.sh +++ b/examples/megatron/tests/run_policy_generation.sh @@ -3,12 +3,16 @@ set -x [ -z "$MEGATRON" ] && export MEGATRON=path-to-megatron [ -z "$CHATLEARN" ] && export CHATLEARN=path-to-chatlearn -[ -z "$TP" ] && export TP=4 [ -z "$VOCAB_FILE" ] && export VOCAB_FILE=path-to-tokenizer [ -z "$LOAD" ] && export LOAD=path-to-ckpt [ -z "$DATASET_PATH" ] && export DATASET_PATH=path-to-dataset-json [ -z "$model_size" ] && export model_size=llama2-13B -[ -z "$tokenizer_load" ] && export tokenizer_load=path-to-hf-tokenizer-for-vllm-backend +[ -z "$tokenizer_load" ] && export tokenizer_load=path-to-dir-of-tokenizer + +[ -z "$num_gpu" ] && export num_gpu=4 +[ -z "$TP" ] && export TP=4 +[ -z "$EP" ] && export EP=1 +[ -z "$PP" ] && export PP=1 cd $CHATLEARN/examples/megatron @@ -38,6 +42,14 @@ elif [[ $model_size == "llama2"* ]]; then configs=configs/llama2/test_vllm_policy.yaml fi export tokenizer_model=$VOCAB_FILE +elif [[ $model_size == "mixtral"* ]]; then + if [[ "$backend" == "megatron" ]]; then + configs=configs/mixtral/test_policy.yaml + else + echo "ERROR: mixtral model support megatron backend currently." + exit 1 + fi + export tokenizer_model=$VOCAB_FILE else echo "unexpected model_type $model_size." 
exit 1 @@ -54,9 +66,12 @@ log_file=${output_dir}/log_${RANK}.log export batch_generation_min_prompt_length=32 -generation_batch_size=64 \ -num_gpu=$TP \ +vllm_micro_batch_size=-1 \ +generation_batch_size=${generation_batch_size:-64} \ +num_gpu=$num_gpu \ policy_tp=$TP \ +policy_ep=$EP \ +policy_pp=$PP \ eval_data_path=$DATASET_PATH \ policy_inference_load=$LOAD \ python tests/test_policy_generation.py -c $configs 2>&1 | tee ${log_file}.log ; exit ${PIPESTATUS[0]} diff --git a/tests/run_tests.sh b/tests/run_tests.sh index 97f67d19..3b8d80b0 100644 --- a/tests/run_tests.sh +++ b/tests/run_tests.sh @@ -65,6 +65,7 @@ function run_all_tests { run_test python test_rlhf_data_input.py -c "configs/exp.yaml" run_test python test_data_dp.py -c "configs/rlhf.yaml" run_test python test_data_dp_zero.py -c "configs/rlhf.yaml" + run_test python test_data_dp_ep.py -c "configs/rlhf.yaml" run_test python test_rlhf_colocate_forward_train.py -c "configs/rlhf2.yaml" run_test python test_evaluator_multi.py -c "configs/test_eval2.yaml" run_test python test_rlhf_cpu.py -c "configs/rlhf_cpu.yaml" @@ -106,6 +107,10 @@ if [ "$1" == "" ]; then else run_all_tests fi +elif [ "$1" == "test_data" ]; then + run_test python test_data_dp.py -c "configs/rlhf.yaml" + run_test python test_data_dp_zero.py -c "configs/rlhf.yaml" + run_test python test_data_dp_ep.py -c "configs/rlhf.yaml" elif [ "$1" == "test_fixed_data" ]; then run_test python test_fixed_data.py -c "configs/rlhf.yaml" elif [ "$1" == "test_dynamic_data" ]; then @@ -150,5 +155,3 @@ else echo -e "\033[31m$(date "+%Y-%m-%d %T.%N") [ERROR]: Unrecognized test name '$1'!\033[0m" exit -1 fi - -ray stop --force diff --git a/tests/test_data_dp_ep.py b/tests/test_data_dp_ep.py new file mode 100644 index 00000000..b6d1c8fd --- /dev/null +++ b/tests/test_data_dp_ep.py @@ -0,0 +1,219 @@ +import os +import time + +import torch +from torch.utils.data import DataLoader +from torch.utils.data import Dataset + +import chatlearn +from chatlearn import RLHFEngine +from chatlearn import TorchModule +from chatlearn.utils import future + + +class CustomDataset(Dataset): + def __init__(self, data): + self.data = data + self.collate_fn = None + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return {"query": self.data[idx]} + + +chatlearn.init() + +class PolicyModel(TorchModule): + + @property + def data_parallel_size(self): + return 8 + + @property + def data_parallel_rank(self): + return int(os.environ["RANK"]) + + def forward_step(self, data, iteration): + print("policy forward =========", flush=True) + query = data["query"] + bs = query.size(0) + data["policy_out"] = torch.ones([bs, 1024]).cuda() + return data + + def build_dataset(self, prompts, is_eval=False): + dataset = CustomDataset(prompts) + return dataset + + + +class ReferenceModel(TorchModule): + + @property + def data_parallel_size(self): + return 8 + + @property + def data_parallel_rank(self): + return int(os.environ["RANK"]) + + def forward_step(self, data, iteration): + print("reference forward =========", flush=True) + query = data["policy_out"].cuda() + data["ref_out"] = query * 2 + return data + + +class RewardModel(TorchModule): + + @property + def data_parallel_size(self): + return 8 + + @property + def data_parallel_rank(self): + return int(os.environ["RANK"]) + + def forward_step(self, data, iteration): + print("reward forward =========", flush=True) + data["reward_out"] = data["ref_out"].cuda() + data["policy_out"].cuda() + return data + +class ValueModel(TorchModule): + + @property 
+ def data_parallel_size(self): + return 8 + + @property + def data_parallel_rank(self): + return int(os.environ["RANK"]) + + def forward_step(self, data, iteration): + print("value forward =========", flush=True) + data["value_out"] = data["policy_out"].cuda() * 3 + return data + + +class PPOPolicy(TorchModule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.data = [] + + @property + def data_parallel_size(self): + return 8 + + @property + def data_parallel_rank(self): + return int(os.environ["RANK"]) + + def train_step(self, data, iteration): + print(f"ppo policy train_step ========= {self.data_parallel_rank}", flush=True) + self.data.append(data) + num_mb = len(data) + return num_mb + + def get_data(self): + return self.data + +class PPOValue(TorchModule): + + @property + def data_parallel_size(self): + return 8 + + @property + def data_parallel_rank(self): + return int(os.environ["RANK"]) + + def train_step(self, data, iteration): + print("ppo value train_step =========", flush=True) + num_mb = len(data) + return num_mb + +for _, model_config in chatlearn.get_args().models.items(): + model_config.num_gpu = 8 + +chatlearn.get_args().models['policy'].expert_model_parallel_size = 4 +chatlearn.get_args().models['reference'].expert_model_parallel_size = 4 +chatlearn.get_args().models['reward'].expert_model_parallel_size = 4 +chatlearn.get_args().models['value'].expert_model_parallel_size = 4 + +chatlearn.get_args().models['ppo_policy'].expert_model_parallel_size = 4 +chatlearn.get_args().models['ppo_value'].expert_model_parallel_size = 4 + +chatlearn.get_args().runtime_args.colocation = [["policy", "reference", "reward", "value", "ppo_policy", "ppo_value"]] +chatlearn.get_args().runtime_args.train_micro_batch_size = 4 +chatlearn.get_args().runtime_args.train_global_batch_size = 32 +chatlearn.get_args().runtime_args.generation_batch_size = 8 +chatlearn.get_args().runtime_args.max_relay_episode = 1 +chatlearn.get_args().runtime_args.sample_per_episode = 1024 +policy = PolicyModel("policy") +reference = ReferenceModel("reference") +reward = RewardModel("reward") +value = ValueModel("value") +ppo_policy = PPOPolicy("ppo_policy") +ppo_value = PPOValue("ppo_value") + +engine = RLHFEngine(policy, reference, reward, value, ppo_policy, ppo_value) + +def relay_sample_fn(episode_relay_buffers): + buffer = episode_relay_buffers[-1].buffer + episode_id = episode_relay_buffers[-1]._episode_id + assert len(buffer) == 1024 + for i in range(len(buffer)): + assert int(buffer[i]['query'][0].item()) == i + episode_id * 1024 + return buffer + +engine.set_relay_sample_fn(relay_sample_fn) +# for inference models, they have 2 dp replicas +assert policy.num_replica == 2 +assert reference.num_replica == 2 +assert reward.num_replica == 2 +assert value.num_replica == 2 +# for training models, ep is combined into dp, leading to only 1 replica +assert ppo_policy.num_replica == 1 +assert ppo_value.num_replica == 1 +data = [torch.ones([1024]) * i for i in range(2048)] +engine.set_dataset(data) +engine.learn() +assert engine.named_models['policy'].replicas[0].data_parallel_size == 8 +assert engine.named_models['reference'].replicas[0].data_parallel_size == 8 +assert engine.named_models['reward'].replicas[0].data_parallel_size == 8 +assert engine.named_models['value'].replicas[0].data_parallel_size == 8 +assert engine.named_models['ppo_policy'].replicas[0].data_parallel_size == 8 +assert engine.named_models['ppo_value'].replicas[0].data_parallel_size == 8 + +dp_rank_to_actors = 
engine.named_models['ppo_policy'].replicas[0].dp_rank_to_actors +assert len(dp_rank_to_actors) == 8 +assert len(dp_rank_to_actors[0]) == 1 +assert len(dp_rank_to_actors[1]) == 1 + +all_data = [] +for i in range(8): + data = future.get(dp_rank_to_actors[i][0].get_data.remote()) + for item in data: + for batch in item: + all_data.extend([i for i in batch['query'][:, 0].numpy()]) + +assert len(all_data) == 2048 +distinct_data = set(all_data) +assert len(distinct_data) == 2048 +assert min(distinct_data) == 0.0 +assert max(distinct_data) == 2047.0 + +dp_rank_to_actors = engine.named_models['ppo_value'].replicas[0].dp_rank_to_actors +assert len(dp_rank_to_actors) == 8 +assert len(dp_rank_to_actors[0]) == 1 +assert len(dp_rank_to_actors[1]) == 1 + +assert engine.env.batch_per_episode == 256 +assert engine.env.num_iteration == 64 +assert engine.trainer.batch_per_episode == 32 +assert engine.trainer.num_iteration == 32 +assert engine.trainer.num_micro_batch_per_dp == 1 + +assert len(engine.env._dataset) == 2048, len(engine.env._dataset) From b2e544ee7aef538aed8550d629143415505a13e2 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 20 Sep 2024 02:03:29 +0000 Subject: [PATCH 09/26] fix mixtral --- .../megatron/memory_manager/base_trainer.py | 6 ++-- .../megatron/configs/mixtral/test_policy.yaml | 25 ++++++++++++++ examples/megatron/scripts/base_env.sh | 19 +++++++++++ .../megatron/scripts/train_rlhf_mixtral.sh | 33 +++++++++++++++++-- 4 files changed, 79 insertions(+), 4 deletions(-) create mode 100644 examples/megatron/configs/mixtral/test_policy.yaml diff --git a/chatlearn/models/megatron/memory_manager/base_trainer.py b/chatlearn/models/megatron/memory_manager/base_trainer.py index 7ac921d7..9767b492 100644 --- a/chatlearn/models/megatron/memory_manager/base_trainer.py +++ b/chatlearn/models/megatron/memory_manager/base_trainer.py @@ -19,7 +19,6 @@ import torch -from chatlearn.models.megatron.memory_manager.base import BaseMemoryManager from chatlearn.utils.flat_tensors import BucketizedFlatTensors from chatlearn.utils.logger import log_rank_0 from chatlearn.utils.megatron_import_memory_helper import MegatronVersion, get_megatron_version @@ -126,7 +125,10 @@ def get_optimizer_list(self): def _optimizer_load_state_bucket_into_device(self, device, optimizer=None): """put the state bucket onto a device""" if optimizer is not None: - optimizer_list = [optimizer] + if isinstance(optimizer, ChainedOptimizer): + optimizer_list = optimizer.chained_optimizers + else: + optimizer_list = [optimizer] else: optimizer_list = self.get_optimizer_list() diff --git a/examples/megatron/configs/mixtral/test_policy.yaml b/examples/megatron/configs/mixtral/test_policy.yaml new file mode 100644 index 00000000..aa828c62 --- /dev/null +++ b/examples/megatron/configs/mixtral/test_policy.yaml @@ -0,0 +1,25 @@ +runtime_env: + platform: DLC + excludes: + - "*pt" + - "logs" + - "tensorboards" + - ".nfs*" + + +models: + policy: + model_config_file: old_policy_inference.yaml + num_gpu: ${num_gpu:1} + gpu_per_process: 1 + trainable: False + batch_generation: + ranking: ${batch_generation_ranking:False} + min_prompt_length: ${batch_generation_min_prompt_length:0} + +runtime: + generation_batch_size: ${generation_batch_size:4} + data_path: ${data_path} + eval_data_path: ${eval_data_path} + output_dir: ${output_dir} + exp_name: ${exp_name:chatlearn} diff --git a/examples/megatron/scripts/base_env.sh b/examples/megatron/scripts/base_env.sh index 7cb2079d..b1c5a35e 100644 --- a/examples/megatron/scripts/base_env.sh +++ 
b/examples/megatron/scripts/base_env.sh @@ -147,6 +147,25 @@ elif [[ "$model_size" == "mixtral-8x7B" ]]; then export reward_moe_router_topk=2 export max_position_embedding=32768 export seq_length=2048 + export USE_LEGACY_MODELS=False +elif [[ "$model_size" == "mixtral-tiny" ]]; then + export policy_num_layers=4 + export policy_hidden_size=4096 + export policy_num_attention_heads=32 + export policy_num_query_groups=8 + export policy_ffn_hidden_size=14336 + export policy_num_experts=8 + export policy_moe_router_topk=2 + export reward_num_layers=4 + export reward_hidden_size=4096 + export reward_num_attention_heads=32 + export reward_num_query_groups=8 + export reward_ffn_hidden_size=14336 + export reward_num_experts=8 + export reward_moe_router_topk=2 + export max_position_embedding=32768 + export seq_length=2048 + export USE_LEGACY_MODELS=False else echo "unsupported model_size ${model_size}, please set your own model config" exit 1 diff --git a/examples/megatron/scripts/train_rlhf_mixtral.sh b/examples/megatron/scripts/train_rlhf_mixtral.sh index 2abf5595..27b3279d 100644 --- a/examples/megatron/scripts/train_rlhf_mixtral.sh +++ b/examples/megatron/scripts/train_rlhf_mixtral.sh @@ -52,8 +52,8 @@ if [[ "$model_size" == "mixtral-8x7B" ]]; then export policy_pp=4 export reward_pp=4 export value_pp=4 - export train_global_batch_size=128 - export generation_batch_size=1 + export train_global_batch_size=8 + export generation_batch_size=2 export ref_generation_batch_size=1 export train_micro_batch_size=1 export policy_recompute_activations=True @@ -66,6 +66,35 @@ if [[ "$model_size" == "mixtral-8x7B" ]]; then export free_memory_value=True export free_memory_ppo_policy=True export free_memory_ppo_value=True + export seq_length=2048 + export max_new_tokens=1024 +elif [[ "$model_size" == "mixtral-tiny" ]]; then + export policy_tp=1 + export policy_ep=4 + export ppo_policy_pp=2 + export reward_tp=1 + export reward_ep=4 + export ppo_value_pp=2 + export ref_pp=2 + export policy_pp=2 + export reward_pp=2 + export value_pp=2 + export train_global_batch_size=8 + export generation_batch_size=2 + export ref_generation_batch_size=2 + export train_micro_batch_size=1 + export policy_recompute_activations=True + export policy_moe_layer_recompute=True + export value_recompute_activations=True + export value_moe_layer_recompute=True + # export free_memory_policy=True + # export free_memory_reference=True + # export free_memory_reward=True + # export free_memory_value=True + # export free_memory_ppo_policy=True + # export free_memory_ppo_value=True + export seq_length=2048 + export max_new_tokens=1024 fi mkdir -p ${output_dir} From 1aa2301c2f0cb9753af8b012434d02357c8d8a9a Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 20 Sep 2024 05:45:56 +0000 Subject: [PATCH 10/26] fix model mixtral --- examples/megatron/scripts/base_env.sh | 18 --------- .../scripts/convert_hf_to_megatron.sh | 4 +- .../megatron/scripts/train_reward_mixtral.sh | 12 +++--- .../megatron/scripts/train_rlhf_mixtral.sh | 37 +++---------------- examples/megatron/scripts/train_sft_llama.sh | 4 +- .../megatron/scripts/train_sft_mixtral.sh | 12 +++--- 6 files changed, 21 insertions(+), 66 deletions(-) diff --git a/examples/megatron/scripts/base_env.sh b/examples/megatron/scripts/base_env.sh index b1c5a35e..97d29d1c 100644 --- a/examples/megatron/scripts/base_env.sh +++ b/examples/megatron/scripts/base_env.sh @@ -148,24 +148,6 @@ elif [[ "$model_size" == "mixtral-8x7B" ]]; then export max_position_embedding=32768 export seq_length=2048 export 
USE_LEGACY_MODELS=False -elif [[ "$model_size" == "mixtral-tiny" ]]; then - export policy_num_layers=4 - export policy_hidden_size=4096 - export policy_num_attention_heads=32 - export policy_num_query_groups=8 - export policy_ffn_hidden_size=14336 - export policy_num_experts=8 - export policy_moe_router_topk=2 - export reward_num_layers=4 - export reward_hidden_size=4096 - export reward_num_attention_heads=32 - export reward_num_query_groups=8 - export reward_ffn_hidden_size=14336 - export reward_num_experts=8 - export reward_moe_router_topk=2 - export max_position_embedding=32768 - export seq_length=2048 - export USE_LEGACY_MODELS=False else echo "unsupported model_size ${model_size}, please set your own model config" exit 1 diff --git a/examples/megatron/scripts/convert_hf_to_megatron.sh b/examples/megatron/scripts/convert_hf_to_megatron.sh index 26c53e9e..50033997 100644 --- a/examples/megatron/scripts/convert_hf_to_megatron.sh +++ b/examples/megatron/scripts/convert_hf_to_megatron.sh @@ -20,9 +20,9 @@ load_dir=${LOAD_PATH} save_dir=${SAVE_PATH} tokenizer_model=${TOKENIZER_MODEL} if [[ $model == 'gpt_llama' ]]; then - model_size=${MODEL_SIZE:-llama2-7B} + model_size=${model_size:-llama2-7B} elif [[ $model == 'mixtral' ]]; then - model_size=${MODEL_SIZE:-mixtral-8x7B} + model_size=${model_size:-mixtral-8x7B} fi export CUDA_DEVICE_MAX_CONNECTIONS=1 diff --git a/examples/megatron/scripts/train_reward_mixtral.sh b/examples/megatron/scripts/train_reward_mixtral.sh index d311f6de..a4e9564c 100644 --- a/examples/megatron/scripts/train_reward_mixtral.sh +++ b/examples/megatron/scripts/train_reward_mixtral.sh @@ -24,9 +24,9 @@ DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE} \ --master_addr ${MASTER_ADDR} \ --master_port ${MASTER_PORT}" -[ -z "$MODEL_SIZE" ] && export MODEL_SIZE="8x7B" +[ -z "$model_size" ] && export model_size="mixtral-8x7B" -if [ $MODEL_SIZE == "8x7B" ]; then +if [ $model_size == "mixtral-8x7B" ]; then NUM_LAYERS=32 HIDDEN_SIZE=4096 NUM_ATTN_HEADS=32 @@ -35,14 +35,14 @@ if [ $MODEL_SIZE == "8x7B" ]; then NUM_QUERY_GROUPS=8 NUM_EXPERTS=8 MOE_ROUTER_TOPK=2 - seq_len=2048 + seq_length=2048 tp=1 pp=4 ep=8 mb=1 gbs=64 else - echo "Unrecognized MODEL_SIZE ${MODEL_SIZE}, choose from '8x7B'." + echo "Unrecognized model_size ${model_size}, choose from 'mixtral-8x7B'." 
exit -1 fi @@ -58,13 +58,13 @@ dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp / $ep)) gbs=$(($gbs * $dp)) -[ -z "$CHECKPOINT_PATH" ] && CHECKPOINT_PATH=${CHATLEARN}/output/reward/mixtral_hh_reward_$(date +%F)_gpt_${MODEL_SIZE}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_ep${ep}_mb${mb}_seqlen${seq_len} +[ -z "$CHECKPOINT_PATH" ] && CHECKPOINT_PATH=${CHATLEARN}/output/reward/mixtral_hh_reward_$(date +%F)_gpt_${model_size}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_ep${ep}_mb${mb}_seqlen${seq_length} mkdir -p $CHECKPOINT_PATH MODEL_ARGS=" --disable-bias-linear \ ---seq-length $seq_len \ +--seq-length $seq_length \ --max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \ --num-layers ${NUM_LAYERS} \ --hidden-size ${HIDDEN_SIZE} \ diff --git a/examples/megatron/scripts/train_rlhf_mixtral.sh b/examples/megatron/scripts/train_rlhf_mixtral.sh index 27b3279d..362a65c0 100644 --- a/examples/megatron/scripts/train_rlhf_mixtral.sh +++ b/examples/megatron/scripts/train_rlhf_mixtral.sh @@ -12,7 +12,7 @@ source ${DIR}/base_env.sh # TODO: support vllm backend=${1:-megatron} if [[ "$backend" != "megatron" ]]; then - echo "ERROR: expect megatron backend, while "$backend + echo "ERROR: expect megatron backend for Mixtral models, while current backend is "$backend exit 1 fi @@ -35,6 +35,7 @@ export trainer_engine=rlhf [ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} [ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ [ -z "$sample_per_episode" ] && sample_per_episode=1024 +[ -z "$num_episode" ] && num_episode=200 [ -z "$tokenizer_load" ] && export tokenizer_load=path-to-hf-tokenizer-for-vllm-backend output_dir=${output_dir}/${exp_name} @@ -52,9 +53,9 @@ if [[ "$model_size" == "mixtral-8x7B" ]]; then export policy_pp=4 export reward_pp=4 export value_pp=4 - export train_global_batch_size=8 - export generation_batch_size=2 - export ref_generation_batch_size=1 + export train_global_batch_size=32 + export generation_batch_size=8 + export ref_generation_batch_size=8 export train_micro_batch_size=1 export policy_recompute_activations=True export policy_moe_layer_recompute=True @@ -68,33 +69,6 @@ if [[ "$model_size" == "mixtral-8x7B" ]]; then export free_memory_ppo_value=True export seq_length=2048 export max_new_tokens=1024 -elif [[ "$model_size" == "mixtral-tiny" ]]; then - export policy_tp=1 - export policy_ep=4 - export ppo_policy_pp=2 - export reward_tp=1 - export reward_ep=4 - export ppo_value_pp=2 - export ref_pp=2 - export policy_pp=2 - export reward_pp=2 - export value_pp=2 - export train_global_batch_size=8 - export generation_batch_size=2 - export ref_generation_batch_size=2 - export train_micro_batch_size=1 - export policy_recompute_activations=True - export policy_moe_layer_recompute=True - export value_recompute_activations=True - export value_moe_layer_recompute=True - # export free_memory_policy=True - # export free_memory_reference=True - # export free_memory_reward=True - # export free_memory_value=True - # export free_memory_ppo_policy=True - # export free_memory_ppo_value=True - export seq_length=2048 - export max_new_tokens=1024 fi mkdir -p ${output_dir} @@ -109,6 +83,7 @@ num_gpu=${num_gpu} \ data_path=${DATASET_PATH} \ eval_data_path=${EVAL_DATASET_PATH} \ sample_per_episode=${sample_per_episode} \ +num_episode=${num_episode} \ python entry/train_rlhf.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} diff --git a/examples/megatron/scripts/train_sft_llama.sh b/examples/megatron/scripts/train_sft_llama.sh index 1144e4ce..e6a6daa7 100644 
--- a/examples/megatron/scripts/train_sft_llama.sh +++ b/examples/megatron/scripts/train_sft_llama.sh @@ -86,7 +86,7 @@ dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp)) gbs=$(($gbs * $dp)) -[ -z "$CHECKPOINT_PATH" ] && CHECKPOINT_PATH=${CHATLEARN}/output/sft/hh_sft_$(date +%F)_gpt_${MODEL_SIZE}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_mb${mb}_seqlen${seq_len} +[ -z "$CHECKPOINT_PATH" ] && CHECKPOINT_PATH=${CHATLEARN}/output/sft/hh_sft_$(date +%F)_gpt_${model_size}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_mb${mb}_seqlen${seq_len} mkdir -p $CHECKPOINT_PATH @@ -110,7 +110,7 @@ use_legacy_models=${USE_LEGACY_MODELS:-"True"} if [[ ${use_legacy_models} = "False" ]]; then MCORE_ARGS="--transformer-impl transformer_engine " else - if [[ ${MODEL_SIZE} = "llama3-8B" || ${MODEL_SIZE} = "llama3-70B" ]]; then + if [[ ${model_size} = "llama3-8B" || ${model_size} = "llama3-70B" ]]; then echo "Expect USE_LEGACY_MODELS to be False for Llama3 models, but got True." exit 1 fi diff --git a/examples/megatron/scripts/train_sft_mixtral.sh b/examples/megatron/scripts/train_sft_mixtral.sh index c5418bf9..a1310d71 100644 --- a/examples/megatron/scripts/train_sft_mixtral.sh +++ b/examples/megatron/scripts/train_sft_mixtral.sh @@ -23,9 +23,9 @@ DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE} \ export PYTHONPATH=${PYTHONPATH}:${MEGATRON}:${CHATLEARN}/examples/megatron:${CHATLEARN} -[ -z "$MODEL_SIZE" ] && export MODEL_SIZE="8x7B" +[ -z "$model_size" ] && export model_size="mixtral-8x7B" -if [ $MODEL_SIZE == "8x7B" ]; then +if [ $model_size == "mixtral-8x7B" ]; then NUM_LAYERS=32 HIDDEN_SIZE=4096 NUM_ATTN_HEADS=32 @@ -34,14 +34,12 @@ if [ $MODEL_SIZE == "8x7B" ]; then NUM_QUERY_GROUPS=8 NUM_EXPERTS=8 MOE_ROUTER_TOPK=2 - seq_len=2048 + seq_length=2048 tp=1 pp=4 ep=8 mb=1 gbs=64 - echo "Unrecognized MODEL_SIZE ${MODEL_SIZE}, choose from '8x7B'." 
- exit -1 fi DIR=$(pwd) @@ -56,13 +54,13 @@ dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp / $ep)) gbs=$(($gbs * $dp)) -[ -z "$CHECKPOINT_PATH" ] && CHECKPOINT_PATH=${CHATLEARN}/output/sft/mixtral_hh_sft_$(date +%F)_gpt_${MODEL_SIZE}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_ep${ep}_mb${mb}_seqlen${seq_len} +[ -z "$CHECKPOINT_PATH" ] && CHECKPOINT_PATH=${CHATLEARN}/output/sft/mixtral_hh_sft_$(date +%F)_gpt_${model_size}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_ep${ep}_mb${mb}_seqlen${seq_length} mkdir -p $CHECKPOINT_PATH MODEL_ARGS=" --disable-bias-linear \ ---seq-length $seq_len \ +--seq-length $seq_length \ --max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \ --num-layers ${NUM_LAYERS} \ --hidden-size ${HIDDEN_SIZE} \ From 19143f112665f432ca6eaa5852d3ce2f958f90dd Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 20 Sep 2024 14:22:54 +0800 Subject: [PATCH 11/26] align diff to main --- examples/megatron/configs/mixtral/dpo.yaml | 2 +- examples/megatron/configs/mixtral/online_dpo.yaml | 2 +- examples/megatron/configs/mixtral/rlhf.yaml | 2 +- examples/megatron/models/policy_model.py | 14 +++++--------- examples/megatron/models/reward_inference.py | 5 +++++ examples/megatron/scripts/train_rlhf_llama.sh | 4 ---- examples/megatron/scripts/train_rlhf_mixtral.sh | 2 -- examples/megatron/tests/run_policy_generation.sh | 2 +- tests/run_tests.sh | 2 ++ 9 files changed, 16 insertions(+), 19 deletions(-) diff --git a/examples/megatron/configs/mixtral/dpo.yaml b/examples/megatron/configs/mixtral/dpo.yaml index f6c751fe..d66150f1 100644 --- a/examples/megatron/configs/mixtral/dpo.yaml +++ b/examples/megatron/configs/mixtral/dpo.yaml @@ -32,7 +32,7 @@ runtime: - ppo_policy,reference train_micro_batch_size: ${train_micro_batch_size:2} train_global_batch_size: ${train_global_batch_size:512} - num_episode: ${num_episode:100} + num_episode: ${num_episode:200} sample_per_episode: ${sample_per_episode:1024} num_training_epoch: 1 save_episode_interval: ${save_episode_interval:100} diff --git a/examples/megatron/configs/mixtral/online_dpo.yaml b/examples/megatron/configs/mixtral/online_dpo.yaml index d818c10f..3da0ed80 100644 --- a/examples/megatron/configs/mixtral/online_dpo.yaml +++ b/examples/megatron/configs/mixtral/online_dpo.yaml @@ -48,7 +48,7 @@ runtime: generation_batch_size: ${generation_batch_size:4} train_micro_batch_size: ${train_micro_batch_size:2} train_global_batch_size: ${train_global_batch_size:512} - num_episode: ${num_episode:100} + num_episode: ${num_episode:200} sample_per_episode: ${sample_per_episode:1024} num_training_epoch: 1 save_episode_interval: ${save_episode_interval:100} diff --git a/examples/megatron/configs/mixtral/rlhf.yaml b/examples/megatron/configs/mixtral/rlhf.yaml index 5a693881..03e41be0 100644 --- a/examples/megatron/configs/mixtral/rlhf.yaml +++ b/examples/megatron/configs/mixtral/rlhf.yaml @@ -66,7 +66,7 @@ runtime: generation_batch_size: ${generation_batch_size:4} train_micro_batch_size: ${train_micro_batch_size:2} train_global_batch_size: ${train_global_batch_size:512} - num_episode: ${num_episode:100} + num_episode: ${num_episode:200} sample_per_episode: ${sample_per_episode:1024} num_training_epoch: 1 save_episode_interval: ${save_episode_interval:100} diff --git a/examples/megatron/models/policy_model.py b/examples/megatron/models/policy_model.py index 5bdc622d..0891d075 100644 --- a/examples/megatron/models/policy_model.py +++ b/examples/megatron/models/policy_model.py @@ -18,12 +18,13 @@ from megatron.training import get_args from megatron.core 
import tensor_parallel +from megatron.training.arguments import core_transformer_config_from_args from megatron.training.global_vars import get_tokenizer from megatron.legacy.model.gpt_model import GPTModel from megatron.legacy.model.language_model import parallel_lm_logits from chatlearn.models.megatron.ops.policy_gradient import tensor_decomp_pg_loss -from .utils import get_advantages_and_returns, has_config_in_args, get_eos_id +from .utils import get_advantages_and_returns, get_eos_id from .constants import TrainerEngine from .constants import select_actions_from_right_padded @@ -31,7 +32,6 @@ class PolicyModel(GPTModel): """PolicyModel""" - def __init__(self, num_tokentypes=0, parallel_output=True, @@ -39,13 +39,9 @@ def __init__(self, post_process=True, stats=None): self.args = get_args() - if has_config_in_args(GPTModel): - # new API - from megatron.training.arguments import core_transformer_config_from_args # pylint: disable=import-outside-toplevel - config = core_transformer_config_from_args(self.args) - super().__init__(config, num_tokentypes, parallel_output, pre_process, post_process) - else: - super().__init__(num_tokentypes, parallel_output, pre_process, post_process) + config = core_transformer_config_from_args(self.args) + super().__init__(config, num_tokentypes, parallel_output, pre_process, post_process) + self.tokenizer = get_tokenizer() self.stats = stats diff --git a/examples/megatron/models/reward_inference.py b/examples/megatron/models/reward_inference.py index b50dcfb6..3d3385b9 100644 --- a/examples/megatron/models/reward_inference.py +++ b/examples/megatron/models/reward_inference.py @@ -34,6 +34,7 @@ from chatlearn import MegatronModule from chatlearn.utils import to_device from chatlearn.utils.megatron_utils import load_checkpoint +from examples.megatron.data.reward_dataset import preprocess from .reward_model import RewardModel as LegacyRewardModel from .mcore_reward_model import MCoreRewardModel from .utils import tensorboard_scalar_dict, get_eos_id @@ -119,7 +120,11 @@ def setup(self): model = get_model(model_provider, wrap_with_ddp=False) if args.load: + torch.distributed.barrier() load_checkpoint(model, None, None, adaptive_parallel_strategy=args.adaptive_parallel_strategy_on_checkpoint) + torch.distributed.barrier() + else: + print_rank_0(f"Warning: Using random parameter for {self.name} model.") assert len(model) == 1, "Above condition should have caught this" self.model = model[0] diff --git a/examples/megatron/scripts/train_rlhf_llama.sh b/examples/megatron/scripts/train_rlhf_llama.sh index 665189e3..a5f57327 100644 --- a/examples/megatron/scripts/train_rlhf_llama.sh +++ b/examples/megatron/scripts/train_rlhf_llama.sh @@ -34,7 +34,6 @@ export trainer_engine=rlhf [ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} [ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ [ -z "$sample_per_episode" ] && sample_per_episode=1024 -[ -z "$num_episode" ] && num_episode=100 [ -z "$tokenizer_load" ] && export tokenizer_load=path-to-hf-tokenizer-for-vllm-backend output_dir=${output_dir}/${exp_name} @@ -110,7 +109,4 @@ num_gpu=${num_gpu} \ data_path=${DATASET_PATH} \ eval_data_path=${EVAL_DATASET_PATH} \ sample_per_episode=${sample_per_episode} \ -num_episode=${num_episode} \ python entry/train_rlhf.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} - - diff --git a/examples/megatron/scripts/train_rlhf_mixtral.sh b/examples/megatron/scripts/train_rlhf_mixtral.sh index 362a65c0..abaef84a 100644 --- 
a/examples/megatron/scripts/train_rlhf_mixtral.sh +++ b/examples/megatron/scripts/train_rlhf_mixtral.sh @@ -35,7 +35,6 @@ export trainer_engine=rlhf [ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} [ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ [ -z "$sample_per_episode" ] && sample_per_episode=1024 -[ -z "$num_episode" ] && num_episode=200 [ -z "$tokenizer_load" ] && export tokenizer_load=path-to-hf-tokenizer-for-vllm-backend output_dir=${output_dir}/${exp_name} @@ -83,7 +82,6 @@ num_gpu=${num_gpu} \ data_path=${DATASET_PATH} \ eval_data_path=${EVAL_DATASET_PATH} \ sample_per_episode=${sample_per_episode} \ -num_episode=${num_episode} \ python entry/train_rlhf.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} diff --git a/examples/megatron/tests/run_policy_generation.sh b/examples/megatron/tests/run_policy_generation.sh index f88ed450..024631b3 100644 --- a/examples/megatron/tests/run_policy_generation.sh +++ b/examples/megatron/tests/run_policy_generation.sh @@ -7,7 +7,7 @@ set -x [ -z "$LOAD" ] && export LOAD=path-to-ckpt [ -z "$DATASET_PATH" ] && export DATASET_PATH=path-to-dataset-json [ -z "$model_size" ] && export model_size=llama2-13B -[ -z "$tokenizer_load" ] && export tokenizer_load=path-to-dir-of-tokenizer +[ -z "$tokenizer_load" ] && export tokenizer_load=path-to-hf-tokenizer-for-vllm-backend [ -z "$num_gpu" ] && export num_gpu=4 [ -z "$TP" ] && export TP=4 diff --git a/tests/run_tests.sh b/tests/run_tests.sh index 3b8d80b0..9bbbbc30 100644 --- a/tests/run_tests.sh +++ b/tests/run_tests.sh @@ -155,3 +155,5 @@ else echo -e "\033[31m$(date "+%Y-%m-%d %T.%N") [ERROR]: Unrecognized test name '$1'!\033[0m" exit -1 fi + +ray stop --force From 20e62728f2356c29034e5839900c6893cc2b90f4 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 20 Sep 2024 09:23:50 +0000 Subject: [PATCH 12/26] fix dpo and add test_checkpoint_conversion --- examples/megatron/models/forward_step.py | 2 +- .../scripts/train_online_dpo_mixtral.sh | 52 ++++++++++-------- .../tests/test_checkpoint_conversion.py | 54 +++++++++++++++++++ .../tests/test_checkpoint_conversion.sh | 49 +++++++++++++++++ 4 files changed, 134 insertions(+), 23 deletions(-) create mode 100644 examples/megatron/tests/test_checkpoint_conversion.py create mode 100644 examples/megatron/tests/test_checkpoint_conversion.sh diff --git a/examples/megatron/models/forward_step.py b/examples/megatron/models/forward_step.py index 0a772c05..8da905d6 100644 --- a/examples/megatron/models/forward_step.py +++ b/examples/megatron/models/forward_step.py @@ -151,7 +151,7 @@ def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask, recv_buffer = None if inference_config_master is not None and "DPO_labels" in inference_config_master: for key, value in inference_config_master.items(): - inference_config[key] = value[start:end, ...] + inference_config[key] = value[start:end, ...] 
if value is not None else None output = _forward_step_helper(model, tokens2use, position_ids2use, attention_mask, recv_buffer=recv_buffer, pooling_sequence_index=pooling_sequence_index2use, diff --git a/examples/megatron/scripts/train_online_dpo_mixtral.sh b/examples/megatron/scripts/train_online_dpo_mixtral.sh index a8322190..06e83008 100644 --- a/examples/megatron/scripts/train_online_dpo_mixtral.sh +++ b/examples/megatron/scripts/train_online_dpo_mixtral.sh @@ -13,13 +13,31 @@ source ${DIR}/base_env.sh backend=${1:-megatron} if [[ "$backend" != "megatron" ]]; then - echo "ERROR: expect megatron backend, while "$backend + echo "ERROR: expect megatron backend for Mixtral models, while current backend is "$backend exit 1 fi +if [[ "$backend" == "megatron" ]]; then + configs=$CHATLEARN/examples/megatron/configs/mixtral/online_dpo.yaml +else + export ENABLE_VLLM=True + if [ -z "$tokenizer_load" ];then + echo "please set path to hf tokenizer for vllm backend, download from huggingface source." + exit 1 + fi + configs=$CHATLEARN/examples/megatron/configs/mixtral/online_dpo_vllm.yaml +fi export trainer_engine=online_dpo +[ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} +[ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ +[ -z "$sample_per_episode" ] && sample_per_episode=1024 +[ -z "$tokenizer_load" ] && export tokenizer_load=path-to-hf-tokenizer-for-vllm-backend + +output_dir=$output_dir/$exp_name +export data_checkpoint_path=${output_dir}/data_checkpoint + export train_to_compare_num_responses=8 export num_inference_per_prompt=8 @@ -35,34 +53,24 @@ if [[ "$model_size" == "mixtral-8x7B" ]]; then export policy_pp=4 export reward_pp=4 export value_pp=4 - export train_global_batch_size=128 - export generation_batch_size=128 - export ref_generation_batch_size=2 + export train_global_batch_size=32 + export generation_batch_size=8 + export ref_generation_batch_size=8 export train_micro_batch_size=1 export policy_recompute_activations=True export policy_moe_layer_recompute=True export value_recompute_activations=True export value_moe_layer_recompute=True + export free_memory_policy=True + export free_memory_reference=True + export free_memory_reward=True + export free_memory_value=True + export free_memory_ppo_policy=True + export free_memory_ppo_value=True + export seq_length=2048 + export max_new_tokens=1024 fi -if [[ "$backend" == "megatron" ]]; then - configs=$CHATLEARN/examples/megatron/configs/mixtral/online_dpo.yaml -else - export ENABLE_VLLM=True - if [ -z "$tokenizer_load" ];then - echo "please set path to hf tokenizer for vllm backend, download from huggingface source." - exit 1 - fi - configs=$CHATLEARN/examples/megatron/configs/mixtral/online_dpo_vllm.yaml -fi - -[ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} -[ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ -[ -z "$sample_per_episode" ] && sample_per_episode=1024 - -output_dir=$output_dir/$exp_name -export data_checkpoint_path=${output_dir}/data_checkpoint - mkdir -p ${output_dir} log_file=${output_dir}/log_${RANK}.log diff --git a/examples/megatron/tests/test_checkpoint_conversion.py b/examples/megatron/tests/test_checkpoint_conversion.py new file mode 100644 index 00000000..1625bd39 --- /dev/null +++ b/examples/megatron/tests/test_checkpoint_conversion.py @@ -0,0 +1,54 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""test checkpoint conversion between megatron (legacy or mcore) and huggingface""" + +import argparse + +import torch +from transformers import AutoModel + +def extract_name_and_params(model): + name_list = [] + param_list = [] + model_named_parameters = model.named_parameters() + for name, param in model_named_parameters: + name_list.append(name) + param_list.append(param) + return name_list, param_list + +def compare_checkpoint(src_path, dst_path): + src_model = AutoModel.from_pretrained(src_path) + dst_model = AutoModel.from_pretrained(dst_path) + src_model_names, src_model_params = extract_name_and_params(src_model) + dst_model_names, dst_model_params = extract_name_and_params(dst_model) + assert src_model_names == dst_model_names + for i, (src_param, dst_param) in enumerate(zip(src_model_params, dst_model_params)): + print(f"Comparing {src_model_names[i]}") + assert torch.equal(src_param, dst_param), f"Parameter {src_model_names[i]} is not equal for two models." + return True + +def main(): + parser = argparse.ArgumentParser() + + parser.add_argument('--src-path', type=str, required=True, + help='source huggingface checkpoint path') + parser.add_argument('--dst-path', type=str, required=True, + help='destination huggingface checkpoint path') + args = parser.parse_args() + + return compare_checkpoint(args.src_path, args.dst_path) + +if __name__ == '__main__': + main() diff --git a/examples/megatron/tests/test_checkpoint_conversion.sh b/examples/megatron/tests/test_checkpoint_conversion.sh new file mode 100644 index 00000000..d45cf740 --- /dev/null +++ b/examples/megatron/tests/test_checkpoint_conversion.sh @@ -0,0 +1,49 @@ +#!/bin/bash +set -exo +set pipefail + +export CHATLEARN=${CHATLEARN:-"path-to-chatlearn"} +export MEGATRON=${MEGATRON:-"path-to-megatron-lm"} +export LOAD_PATH=${LOAD_PATH:-"path-to-hf-ckpt"} +export TEMP_PATH=${TEMP_PATH:-"path-to-converted-mg-ckpt"} +export SAVE_PATH=${SAVE_PATH:-"path-to-converted-back-hf-ckpt"} +export VOCAB_PATH=${VOCAB_PATH:-"path-to-vocabulary"} +export TOKENIZER_MODEL=${TOKENIZER_MODEL:-"path-to-tokenizer-model"} + +export MODEL=${MODEL:-"mixtral"} +export USE_LEGACY_MODELS=${USE_LEGACY_MODELS:-"False"} + +# Step 1: Convert to Megatron checkpoint + +cd $CHATLEARN/examples/megatron/ + +TP=1 \ +PP=4 \ +EP=8 \ +LOAD_PATH=${LOAD_PATH} \ +SAVE_PATH=${TEMP_PATH} \ +bash scripts/convert_hf_to_megatron.sh + +# Step 2: Convert to HuggingFace checkpoint + +LOAD_PATH=${TEMP_PATH} \ +SAVE_PATH=${SAVE_PATH} \ +VOCAB_PATH=${VOCAB_PATH} \ +target_params_dtype=bf16 \ +bash scripts/convert_megatron_to_hf.sh + +# Step 3: Compare converted hf ckpt against the original hf ckpt + +python3 tests/test_checkpoint_conversion.py \ + --src-path ${LOAD_PATH} \ + --dst-path ${SAVE_PATH} + +if [[ $? 
!= 0 ]]; then + echo -e "\033[31m Unrecognized model ${model} \033[0m" + exit -1 +fi + +rm -rf ${TEMP_PATH} +rm -rf ${SAVE_PATH} + +echo "Test success!" From f68afd196633fbf1927c3606767344de74d0b7f1 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 20 Sep 2024 17:34:00 +0800 Subject: [PATCH 13/26] fix error msg --- examples/megatron/tests/test_checkpoint_conversion.sh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/megatron/tests/test_checkpoint_conversion.sh b/examples/megatron/tests/test_checkpoint_conversion.sh index d45cf740..62935c09 100644 --- a/examples/megatron/tests/test_checkpoint_conversion.sh +++ b/examples/megatron/tests/test_checkpoint_conversion.sh @@ -1,6 +1,5 @@ #!/bin/bash -set -exo -set pipefail +set -exo pipefail export CHATLEARN=${CHATLEARN:-"path-to-chatlearn"} export MEGATRON=${MEGATRON:-"path-to-megatron-lm"} @@ -39,7 +38,7 @@ python3 tests/test_checkpoint_conversion.py \ --dst-path ${SAVE_PATH} if [[ $? != 0 ]]; then - echo -e "\033[31m Unrecognized model ${model} \033[0m" + echo -e "\033[31m Test failed! \033[0m" exit -1 fi From d7152ac35a0c0737f14a38863e863cb392b67ede Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Mon, 30 Sep 2024 10:54:27 +0800 Subject: [PATCH 14/26] fix src_gpu to get_or_cache --- chatlearn/runtime/parameter_sync.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chatlearn/runtime/parameter_sync.py b/chatlearn/runtime/parameter_sync.py index 75c6225b..908f93ac 100644 --- a/chatlearn/runtime/parameter_sync.py +++ b/chatlearn/runtime/parameter_sync.py @@ -180,8 +180,8 @@ def add_recv_actor(self, src_rank, dst_rank): dst_actor = self.dst_model.get_actor(dst_rank) self.actor2rank[dst_actor] = dst_rank - src_gpu = future.get(src_actor.get_visible_gpus.remote()) - dst_gpu = future.get(dst_actor.get_visible_gpus.remote()) + src_gpu = self.get_or_cache(src_actor, "get_visible_gpus") + dst_gpu = self.get_or_cache(dst_actor, "get_visible_gpus") src_tp_rank = self.get_actor_tp_rank(src_actor) dst_tp_rank = self.get_actor_tp_rank(dst_actor) src_pp_rank = self.get_actor_pipe_rank(src_actor) From a554cf69acc6f992826bf9b262fa4a4a113c9165 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Tue, 8 Oct 2024 16:08:40 +0800 Subject: [PATCH 15/26] revert "fix src_gpu to get_or_cache" --- chatlearn/runtime/parameter_sync.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chatlearn/runtime/parameter_sync.py b/chatlearn/runtime/parameter_sync.py index 908f93ac..75c6225b 100644 --- a/chatlearn/runtime/parameter_sync.py +++ b/chatlearn/runtime/parameter_sync.py @@ -180,8 +180,8 @@ def add_recv_actor(self, src_rank, dst_rank): dst_actor = self.dst_model.get_actor(dst_rank) self.actor2rank[dst_actor] = dst_rank - src_gpu = self.get_or_cache(src_actor, "get_visible_gpus") - dst_gpu = self.get_or_cache(dst_actor, "get_visible_gpus") + src_gpu = future.get(src_actor.get_visible_gpus.remote()) + dst_gpu = future.get(dst_actor.get_visible_gpus.remote()) src_tp_rank = self.get_actor_tp_rank(src_actor) dst_tp_rank = self.get_actor_tp_rank(dst_actor) src_pp_rank = self.get_actor_pipe_rank(src_actor) From caebecf397c9f923b4418c655227b2becca2f030 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Tue, 5 Nov 2024 08:28:11 +0000 Subject: [PATCH 16/26] fix comments --- examples/megatron/scripts/train_reward_mixtral.sh | 6 +++--- examples/megatron/scripts/train_sft_mixtral.sh | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/megatron/scripts/train_reward_mixtral.sh 
b/examples/megatron/scripts/train_reward_mixtral.sh index d311f6de..1265f6f6 100644 --- a/examples/megatron/scripts/train_reward_mixtral.sh +++ b/examples/megatron/scripts/train_reward_mixtral.sh @@ -54,7 +54,7 @@ NODE_RANK=$RANK NNODES=$WORLD_SIZE -dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp / $ep)) +dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp)) gbs=$(($gbs * $dp)) @@ -151,10 +151,10 @@ log_file=$CHECKPOINT_PATH/stderr_$NODE_RANK.log export CUDA_DEVICE_MAX_CONNECTIONS=1 -cd ${CHATLEARN}/examples/megatron/alignment/reward +cd ${CHATLEARN}/examples/megatron/ torchrun $DISTRIBUTED_ARGS \ - finetune_reward.py \ + entry/train_reward.py \ ${MODEL_ARGS[@]} \ ${MOE_ARGS[@]} \ ${DATA_ARGS[@]} \ diff --git a/examples/megatron/scripts/train_sft_mixtral.sh b/examples/megatron/scripts/train_sft_mixtral.sh index c5418bf9..dacc982b 100644 --- a/examples/megatron/scripts/train_sft_mixtral.sh +++ b/examples/megatron/scripts/train_sft_mixtral.sh @@ -52,7 +52,7 @@ NODE_RANK=$RANK NNODES=$WORLD_SIZE -dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp / $ep)) +dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp)) gbs=$(($gbs * $dp)) @@ -147,10 +147,10 @@ log_file=$CHECKPOINT_PATH/stderr_$NODE_RANK.log export CUDA_DEVICE_MAX_CONNECTIONS=1 -cd ${CHATLEARN}/examples/megatron/sft +cd ${CHATLEARN}/examples/megatron/ torchrun $DISTRIBUTED_ARGS \ - finetune_sft.py \ + entry/train_sft.py \ ${MODEL_ARGS[@]} \ ${MOE_ARGS[@]} \ ${DATA_ARGS[@]} \ From ef3025322d820c61f9ef0a382ab3e556152ae9bb Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Wed, 6 Nov 2024 08:42:24 +0000 Subject: [PATCH 17/26] fix import error --- chatlearn/models/megatron/memory_manager/base_trainer.py | 1 - chatlearn/models/megatron_module.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/chatlearn/models/megatron/memory_manager/base_trainer.py b/chatlearn/models/megatron/memory_manager/base_trainer.py index f4c5f6cb..ce9ce9be 100644 --- a/chatlearn/models/megatron/memory_manager/base_trainer.py +++ b/chatlearn/models/megatron/memory_manager/base_trainer.py @@ -19,7 +19,6 @@ import torch -from chatlearn.models.megatron.memory_manager.base import BaseMemoryManager from chatlearn.utils.flat_tensors import BucketizedFlatTensors from chatlearn.utils.logger import log_rank_0 from chatlearn.utils.megatron_import_memory_helper import MegatronVersion, get_megatron_version diff --git a/chatlearn/models/megatron_module.py b/chatlearn/models/megatron_module.py index e0b8dc2b..9ac14fc5 100644 --- a/chatlearn/models/megatron_module.py +++ b/chatlearn/models/megatron_module.py @@ -29,6 +29,7 @@ from chatlearn.models.megatron.memory_manager import create_trainer_memory_manager, InferenceMemoryManager except ImportError: mpu = None + print("Megatron is not imported, setting mpu to None.") from .torch_module import TorchModule From 7cc7aa93c6643d0d9c62c8662e9cc57724f3e336 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Wed, 6 Nov 2024 17:22:37 +0800 Subject: [PATCH 18/26] fix diff introduced in merge --- chatlearn/models/megatron_module.py | 2 +- chatlearn/runtime/parameter_sync.py | 17 +++++++---------- chatlearn/utils/arguments.py | 2 +- .../configs/mixtral/base_inference.yaml | 2 +- .../megatron/configs/mixtral/policy_shared.yaml | 1 - .../megatron/configs/mixtral/ppo_policy.yaml | 2 +- .../megatron/configs/mixtral/ppo_value.yaml | 2 +- .../megatron/configs/mixtral/reward_shared.yaml | 2 +- examples/megatron/models/policy_model.py | 1 + 9 files changed, 14 insertions(+), 17 deletions(-) diff --git a/chatlearn/models/megatron_module.py 
b/chatlearn/models/megatron_module.py index 9ac14fc5..8e349f36 100644 --- a/chatlearn/models/megatron_module.py +++ b/chatlearn/models/megatron_module.py @@ -28,8 +28,8 @@ from chatlearn.utils.megatron_utils import build_pipeline_layer_name_mapping from chatlearn.models.megatron.memory_manager import create_trainer_memory_manager, InferenceMemoryManager except ImportError: - mpu = None print("Megatron is not imported, setting mpu to None.") + mpu = None from .torch_module import TorchModule diff --git a/chatlearn/runtime/parameter_sync.py b/chatlearn/runtime/parameter_sync.py index c1154cb7..62613c0a 100644 --- a/chatlearn/runtime/parameter_sync.py +++ b/chatlearn/runtime/parameter_sync.py @@ -191,8 +191,8 @@ def add_recv_actor(self, src_rank, dst_rank): dst_actor = self.dst_model.get_actor(dst_rank) self.actor2rank[dst_actor] = dst_rank - src_gpu = future.get(src_actor.get_visible_gpus.remote()) - dst_gpu = future.get(dst_actor.get_visible_gpus.remote()) + src_gpu = self.get_or_cache(src_actor, "get_visible_gpus") + dst_gpu = self.get_or_cache(dst_actor, "get_visible_gpus") src_tp_rank = self.get_actor_tp_rank(src_actor) dst_tp_rank = self.get_actor_tp_rank(dst_actor) src_pp_rank = self.get_actor_pipe_rank(src_actor) @@ -228,6 +228,7 @@ def add_recv_actor_stage2(self, src_rank, dst_rank): def build_rank_mapping(self, add_recv_actor_fn=None): # setup rank mapping for src parameter and dst parameter # get rank for one src_model, without model replicas + if add_recv_actor_fn is None: add_recv_actor_fn = self.add_recv_actor @@ -295,18 +296,18 @@ def build_rank_mapping_two_stage(self, add_recv_actor_fn=None): dst_ranks = self.dst_model.all_ranks local_src_ranks = future.get(self.src_model.replicas[0].get_local_param_ranks()) - if local_src_ranks[0] is None or dst_dp_ranks is None: + if local_src_ranks[0] is None or dst_ranks is None: if self._debug: logger.warning( - f"DEBUG MODE! src_dp_ranks {local_src_ranks} or dst_dp_ranks: {dst_dp_ranks} is None, " + f"DEBUG MODE! 
src_ranks {local_src_ranks} or dst_ranks: {dst_ranks} is None, " "make sure they have values in real application.") return else: - raise Exception(f"src_dp_ranks {local_src_ranks} or dst_dp_ranks {dst_dp_ranks} should not be None") + raise Exception(f"src_ranks {local_src_ranks} or dst_ranks {dst_ranks} should not be None") dp_rank_to_ranks = defaultdict(list) for local_ranks, dp_rank in local_src_ranks: dp_rank_to_ranks[dp_rank].append(local_ranks[dp_rank]) - src_dp_ranks = [i[1] for i in sorted(dp_rank_to_ranks.items())] + src_ranks = [i[1] for i in sorted(dp_rank_to_ranks.items())] replica_rank_iter = cycle(iter(src_ranks)) @@ -316,10 +317,6 @@ def build_rank_mapping_two_stage(self, add_recv_actor_fn=None): "currently we require mod value equals to zero for tensor_model_parallel_size of dst_model and that of src_model while " + \ f"src model {self.src_model.name}(TP={self.num_src_tensor_parallel}) and " + \ f"dst model {self.dst_model.name}(TP={self.num_dst_tensor_parallel})" - assert self.num_src_expert_parallel == self.num_dst_expert_parallel, \ - "currently we require the expert_model_parallel_size to be the same between " + \ - f"src model {self.src_model.name}(EP={self.num_src_expert_parallel}) and " + \ - f"dst model {self.dst_model.name}(EP={self.num_dst_expert_parallel})" assert self.num_src_pipeline_stage % self.num_dst_pipeline_stage == 0 def split_ranks_by_tp_and_ep_size(ranks, tp_size, ep_size): diff --git a/chatlearn/utils/arguments.py b/chatlearn/utils/arguments.py index 777cbf60..044dccd9 100644 --- a/chatlearn/utils/arguments.py +++ b/chatlearn/utils/arguments.py @@ -521,7 +521,7 @@ def _validate_params(self): if model_args.generation_batch_size is None or model_args.generation_batch_size <= 0: if self.runtime_args.generation_batch_size: model_args.generation_batch_size = self.runtime_args.generation_batch_size - for key in ["pipeline_model_parallel_size", "tensor_model_parallel_size", "expert_model_parallel_size", "zero_size"]: + for key in ["pipeline_model_parallel_size", "tensor_model_parallel_size", "zero_size"]: if model_args.args_dict.get(key) is not None: setattr(model_args, key, model_args.args_dict.get(key)) assert getattr(model_args, key) >= 1 diff --git a/examples/megatron/configs/mixtral/base_inference.yaml b/examples/megatron/configs/mixtral/base_inference.yaml index a266dd7e..e8693e33 100644 --- a/examples/megatron/configs/mixtral/base_inference.yaml +++ b/examples/megatron/configs/mixtral/base_inference.yaml @@ -13,4 +13,4 @@ attention_dropout: 0.0 hidden_dropout: 0.0 retro_encoder_attention_dropout: 0.0 retro_encoder_hidden_dropout: 0.0 -inference_batch_times_seqlen_threshold: ${inference_batch_times_seqlen_threshold:4096} \ No newline at end of file +inference_batch_times_seqlen_threshold: ${inference_batch_times_seqlen_threshold:4096} diff --git a/examples/megatron/configs/mixtral/policy_shared.yaml b/examples/megatron/configs/mixtral/policy_shared.yaml index 09436f59..7b314f7b 100644 --- a/examples/megatron/configs/mixtral/policy_shared.yaml +++ b/examples/megatron/configs/mixtral/policy_shared.yaml @@ -11,4 +11,3 @@ expert_model_parallel_size: ${policy_ep:8} group_query_attention: ${group_query_attention:True} num_query_groups: ${policy_num_query_groups} use_distributed_optimizer: True - diff --git a/examples/megatron/configs/mixtral/ppo_policy.yaml b/examples/megatron/configs/mixtral/ppo_policy.yaml index ebb30b01..66e145eb 100644 --- a/examples/megatron/configs/mixtral/ppo_policy.yaml +++ b/examples/megatron/configs/mixtral/ppo_policy.yaml @@ -36,4 
+36,4 @@ sequence_parallel: ${sequence_parallel:True} recompute_activations: ${policy_recompute_activations:False} recompute_granularity: ${policy_recompute_granularity:None} -moe_layer_recompute: ${policy_moe_layer_recompute:False} \ No newline at end of file +moe_layer_recompute: ${policy_moe_layer_recompute:False} diff --git a/examples/megatron/configs/mixtral/ppo_value.yaml b/examples/megatron/configs/mixtral/ppo_value.yaml index d524d30f..5424f897 100644 --- a/examples/megatron/configs/mixtral/ppo_value.yaml +++ b/examples/megatron/configs/mixtral/ppo_value.yaml @@ -27,4 +27,4 @@ sequence_parallel: True recompute_activations: ${value_recompute_activations:False} recompute_granularity: ${value_recompute_granularity:None} -moe_layer_recompute: ${value_moe_layer_recompute:False} \ No newline at end of file +moe_layer_recompute: ${value_moe_layer_recompute:False} diff --git a/examples/megatron/configs/mixtral/reward_shared.yaml b/examples/megatron/configs/mixtral/reward_shared.yaml index 325e45eb..a50ded56 100644 --- a/examples/megatron/configs/mixtral/reward_shared.yaml +++ b/examples/megatron/configs/mixtral/reward_shared.yaml @@ -15,4 +15,4 @@ group_query_attention: ${group_query_attention:True} num_query_groups: ${reward_num_query_groups} save_inference: True -save_inference_interval: 10 \ No newline at end of file +save_inference_interval: 10 diff --git a/examples/megatron/models/policy_model.py b/examples/megatron/models/policy_model.py index 0891d075..bebed6a3 100644 --- a/examples/megatron/models/policy_model.py +++ b/examples/megatron/models/policy_model.py @@ -32,6 +32,7 @@ class PolicyModel(GPTModel): """PolicyModel""" + def __init__(self, num_tokentypes=0, parallel_output=True, From c027414590c18135d60ab4fbfd7bf68a0565f33f Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Thu, 7 Nov 2024 02:05:49 +0000 Subject: [PATCH 19/26] expose import error earlier --- chatlearn/utils/megatron_import_helper.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/chatlearn/utils/megatron_import_helper.py b/chatlearn/utils/megatron_import_helper.py index 92e782bf..47f2fee0 100644 --- a/chatlearn/utils/megatron_import_helper.py +++ b/chatlearn/utils/megatron_import_helper.py @@ -216,13 +216,14 @@ reduce_scatter_to_sequence_parallel_region ) -# pylint: enable=unused-import +try: + from megatron.training import save_checkpoint_and_time +except ImportError: + from megatron.training.training import save_checkpoint_and_time def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler): try: - from megatron.training import save_checkpoint_and_time as save_checkpoint_and_time_v1 # pylint: disable=import-outside-toplevel - save_checkpoint_and_time_v1(iteration, model, optimizer, opt_param_scheduler) - except ImportError: - from megatron.training.training import save_checkpoint_and_time as save_checkpoint_and_time_v2# pylint: disable=import-outside-toplevel - save_checkpoint_and_time_v2(iteration, model, optimizer, opt_param_scheduler, 0, None) + save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler) + except TypeError: # missing required positional arguments + save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler, 0, None) From 77f423036b0ad5aadc37ac3de8c9f8ef4f9616f3 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Thu, 7 Nov 2024 02:06:22 +0000 Subject: [PATCH 20/26] fix merge error --- chatlearn/models/megatron/memory_manager/base_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/chatlearn/models/megatron/memory_manager/base_trainer.py b/chatlearn/models/megatron/memory_manager/base_trainer.py index 3568deb7..2e5ed3ea 100644 --- a/chatlearn/models/megatron/memory_manager/base_trainer.py +++ b/chatlearn/models/megatron/memory_manager/base_trainer.py @@ -145,7 +145,6 @@ def _optimizer_load_state_bucket_into_device(self, device, optimizer=None): # compatible with transformer_engine v1.10, state['master_param']=None if tensors[key] is not None: tensors[key] = tensors[key].to(device=device, non_blocking=True) - tensors[key] = tensors[key].to(device=device, non_blocking=True) # make sure the loading is finished before returning torch.cuda.synchronize() From 289628509d88f36f4405898eddb068857f200e74 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Thu, 7 Nov 2024 03:10:16 +0000 Subject: [PATCH 21/26] fix recursion error --- chatlearn/utils/megatron_import_helper.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/chatlearn/utils/megatron_import_helper.py b/chatlearn/utils/megatron_import_helper.py index 47f2fee0..c369d02d 100644 --- a/chatlearn/utils/megatron_import_helper.py +++ b/chatlearn/utils/megatron_import_helper.py @@ -217,13 +217,13 @@ ) try: - from megatron.training import save_checkpoint_and_time + from megatron.training import save_checkpoint_and_time as megatron_save_checkpoint_and_time except ImportError: - from megatron.training.training import save_checkpoint_and_time + from megatron.training.training import save_checkpoint_and_time as megatron_save_checkpoint_and_time def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler): try: - save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler) - except TypeError: # missing required positional arguments - save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler, 0, None) + megatron_save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler) + except TypeError: # missing required positional arguments for new Megatron version + megatron_save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler, 0, None) From 333e66195f88dbd674e433c964579a441afba1f8 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Thu, 7 Nov 2024 03:10:53 +0000 Subject: [PATCH 22/26] add validate_param_sync option to mixtral models --- examples/megatron/configs/mixtral/online_dpo.yaml | 1 + examples/megatron/configs/mixtral/rlhf.yaml | 2 ++ 2 files changed, 3 insertions(+) diff --git a/examples/megatron/configs/mixtral/online_dpo.yaml b/examples/megatron/configs/mixtral/online_dpo.yaml index 3da0ed80..1d8cb509 100644 --- a/examples/megatron/configs/mixtral/online_dpo.yaml +++ b/examples/megatron/configs/mixtral/online_dpo.yaml @@ -60,3 +60,4 @@ runtime: output_dir: ${output_dir} free_sync_collective_group: ${free_sync_collective_group:False} exp_name: ${exp_name:chatlearn} + validate_param_sync: ${validate_param_sync:False} diff --git a/examples/megatron/configs/mixtral/rlhf.yaml b/examples/megatron/configs/mixtral/rlhf.yaml index 03e41be0..dbbfb17d 100644 --- a/examples/megatron/configs/mixtral/rlhf.yaml +++ b/examples/megatron/configs/mixtral/rlhf.yaml @@ -79,3 +79,5 @@ runtime: free_sync_collective_group: ${free_sync_collective_group:False} exp_name: ${exp_name:chatlearn} output_dir: ${output_dir} + debug: ${debug:False} + validate_param_sync: ${validate_param_sync:False} From d86136f7b4afdef6c3d818867baedd9028be47e7 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Thu, 7 Nov 2024 03:26:52 +0000 Subject: [PATCH 23/26] fix 
redundant empty lines --- examples/megatron/scripts/train_rlhf_mixtral.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/megatron/scripts/train_rlhf_mixtral.sh b/examples/megatron/scripts/train_rlhf_mixtral.sh index abaef84a..722061c4 100644 --- a/examples/megatron/scripts/train_rlhf_mixtral.sh +++ b/examples/megatron/scripts/train_rlhf_mixtral.sh @@ -83,5 +83,3 @@ data_path=${DATASET_PATH} \ eval_data_path=${EVAL_DATASET_PATH} \ sample_per_episode=${sample_per_episode} \ python entry/train_rlhf.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} - - From cd00af076ed67b1c91363d581d2ccb5a3810f6af Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 8 Nov 2024 06:22:07 +0000 Subject: [PATCH 24/26] fix scripts --- .../scripts/convert_hf_to_megatron.sh | 2 +- .../megatron/scripts/train_reward_mixtral.sh | 48 ++++++++++++------- .../megatron/scripts/train_rlhf_mixtral.sh | 3 ++ .../megatron/scripts/train_sft_mixtral.sh | 25 +++++----- 4 files changed, 49 insertions(+), 29 deletions(-) diff --git a/examples/megatron/scripts/convert_hf_to_megatron.sh b/examples/megatron/scripts/convert_hf_to_megatron.sh index 43cad2fb..23354139 100644 --- a/examples/megatron/scripts/convert_hf_to_megatron.sh +++ b/examples/megatron/scripts/convert_hf_to_megatron.sh @@ -58,7 +58,7 @@ elif [[ ${model} == 'mixtral' ]]; then cd ${megatron} python tools/checkpoint/convert.py \ --model-type GPT \ - --loader loader_mixtral_hf \ + --loader mixtral_hf \ --saver mcore \ --target-tensor-parallel-size ${tp} \ --target-pipeline-parallel-size ${pp} \ diff --git a/examples/megatron/scripts/train_reward_mixtral.sh b/examples/megatron/scripts/train_reward_mixtral.sh index ba381a09..9c124c8c 100644 --- a/examples/megatron/scripts/train_reward_mixtral.sh +++ b/examples/megatron/scripts/train_reward_mixtral.sh @@ -7,6 +7,11 @@ set -x [ -z "$RANK" ] && export RANK=0 [ -z "$MASTER_PORT" ] && export MASTER_PORT=12456 +DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE} \ + --nnodes ${WORLD_SIZE} \ + --node_rank ${RANK} \ + --master_addr ${MASTER_ADDR} \ + --master_port ${MASTER_PORT}" # check the path [[ -z "${MEGATRON}" ]] && { echo "MEGATRON path is not set"; exit 1; } @@ -16,17 +21,11 @@ set -x [[ -z "${DATASET_PATH}" ]] && { echo "DATASET_PATH is not set"; exit 1; } -export PYTHONPATH=${PYTHONPATH}:${MEGATRON}:${CHATLEARN}:${CHATLEARN}/examples/megatron - -DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE} \ - --nnodes ${WORLD_SIZE} \ - --node_rank ${RANK} \ - --master_addr ${MASTER_ADDR} \ - --master_port ${MASTER_PORT}" +export PYTHONPATH=${PYTHONPATH}:${MEGATRON}:${CHATLEARN}/examples/megatron:${CHATLEARN} [ -z "$model_size" ] && export model_size="mixtral-8x7B" -if [ $model_size == "mixtral-8x7B" ]; then +if [[ $model_size == "mixtral-8x7B" ]]; then NUM_LAYERS=32 HIDDEN_SIZE=4096 NUM_ATTN_HEADS=32 @@ -40,6 +39,21 @@ if [ $model_size == "mixtral-8x7B" ]; then pp=4 ep=8 mb=1 + gbs=32 +elif [[ $model_size == "mixtral-tiny" ]]; then + NUM_LAYERS=2 + HIDDEN_SIZE=4096 + NUM_ATTN_HEADS=32 + FFN_HIDDEN_SIZE=14336 + MAX_POSITION_EMBEDDINGS=32768 + NUM_QUERY_GROUPS=8 + NUM_EXPERTS=8 + MOE_ROUTER_TOPK=2 + seq_length=2048 + tp=1 + pp=2 + ep=4 + mb=1 gbs=64 else echo "Unrecognized model_size ${model_size}, choose from 'mixtral-8x7B'." 
@@ -70,7 +84,7 @@ MODEL_ARGS=" --hidden-size ${HIDDEN_SIZE} \ --ffn-hidden-size ${FFN_HIDDEN_SIZE} \ --num-attention-heads ${NUM_ATTN_HEADS} \ ---init-method-std 0.006 \ +--init-method-std 0.01 \ --attention-dropout 0.0 \ --hidden-dropout 0.0 \ --normalization RMSNorm \ @@ -104,13 +118,13 @@ DATA_ARGS=" TRAINING_ARGS=" --micro-batch-size $mb \ --global-batch-size $gbs \ ---lr 2.0e-5 \ +--lr 1e-4 \ --train-iters 1000 \ ---lr-decay-iters 1000 \ +--lr-decay-iters 640 \ --lr-decay-style cosine \ ---min-lr 6.0e-12 \ ---weight-decay 0. \ ---lr-warmup-iters 40 \ +--min-lr 1.0e-5 \ +--weight-decay 0.1 \ +--lr-warmup-iters 50 \ --clip-grad 1.0 \ --bf16 \ --exit-on-missing-checkpoint \ @@ -137,13 +151,13 @@ LOGGING_ARGS=" --save-interval 1000 \ --save $CHECKPOINT_PATH \ --load $LOAD_PATH \ ---tensorboard-dir $CHECKPOINT_PATH \ ---tensorboard-log-interval 100 \ +--auto-detect-ckpt-format \ --num-workers 8 \ --no-load-rng \ --no-load-optim \ +--tensorboard-dir $CHECKPOINT_PATH \ +--tensorboard-log-interval 10 \ --log-timers-to-tensorboard \ ---log-batch-size-to-tensorboard \ --log-validation-ppl-to-tensorboard \ " diff --git a/examples/megatron/scripts/train_rlhf_mixtral.sh b/examples/megatron/scripts/train_rlhf_mixtral.sh index 722061c4..49f31f8b 100644 --- a/examples/megatron/scripts/train_rlhf_mixtral.sh +++ b/examples/megatron/scripts/train_rlhf_mixtral.sh @@ -68,6 +68,9 @@ if [[ "$model_size" == "mixtral-8x7B" ]]; then export free_memory_ppo_value=True export seq_length=2048 export max_new_tokens=1024 +else + echo "Unrecognized model_size ${model_size}, choose from 'mixtral-8x7B'." + exit -1 fi mkdir -p ${output_dir} diff --git a/examples/megatron/scripts/train_sft_mixtral.sh b/examples/megatron/scripts/train_sft_mixtral.sh index d7a96ce6..61e05dc9 100644 --- a/examples/megatron/scripts/train_sft_mixtral.sh +++ b/examples/megatron/scripts/train_sft_mixtral.sh @@ -25,7 +25,7 @@ export PYTHONPATH=${PYTHONPATH}:${MEGATRON}:${CHATLEARN}/examples/megatron:${CHA [ -z "$model_size" ] && export model_size="mixtral-8x7B" -if [ $model_size == "mixtral-8x7B" ]; then +if [[ $model_size == "mixtral-8x7B" ]]; then NUM_LAYERS=32 HIDDEN_SIZE=4096 NUM_ATTN_HEADS=32 @@ -39,7 +39,10 @@ if [ $model_size == "mixtral-8x7B" ]; then pp=4 ep=8 mb=1 - gbs=64 + gbs=32 +else + echo "Unrecognized model_size ${model_size}, choose from 'mixtral-8x7B'." + exit -1 fi DIR=$(pwd) @@ -66,7 +69,7 @@ MODEL_ARGS=" --hidden-size ${HIDDEN_SIZE} \ --ffn-hidden-size ${FFN_HIDDEN_SIZE} \ --num-attention-heads ${NUM_ATTN_HEADS} \ ---init-method-std 0.006 \ +--init-method-std 0.01 \ --attention-dropout 0.0 \ --hidden-dropout 0.0 \ --normalization RMSNorm \ @@ -99,13 +102,13 @@ DATA_ARGS=" TRAINING_ARGS=" --micro-batch-size $mb \ --global-batch-size $gbs \ ---lr 2.0e-5 \ +--lr 1e-4 \ --train-iters 1000 \ ---lr-decay-iters 1000 \ +--lr-decay-iters 640 \ --lr-decay-style cosine \ ---min-lr 6.0e-12 \ ---weight-decay 0. 
\ ---lr-warmup-iters 40 \ +--min-lr 1.0e-5 \ +--weight-decay 0.1 \ +--lr-warmup-iters 50 \ --clip-grad 1.0 \ --bf16 \ --exit-on-missing-checkpoint \ @@ -131,13 +134,13 @@ LOGGING_ARGS=" --save-interval 1000 \ --save $CHECKPOINT_PATH \ --load $LOAD_PATH \ ---tensorboard-dir $CHECKPOINT_PATH \ ---tensorboard-log-interval 100 \ +--auto-detect-ckpt-format \ --num-workers 8 \ --no-load-rng \ --no-load-optim \ +--tensorboard-dir $CHECKPOINT_PATH \ +--tensorboard-log-interval 10 \ --log-timers-to-tensorboard \ ---log-batch-size-to-tensorboard \ --log-validation-ppl-to-tensorboard \ " From 7ed2eff612bfbc02093239dfa69af36ca181ccbd Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Mon, 11 Nov 2024 06:37:13 +0000 Subject: [PATCH 25/26] mixtral sft ok --- examples/megatron/scripts/train_sft_mixtral.sh | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/megatron/scripts/train_sft_mixtral.sh b/examples/megatron/scripts/train_sft_mixtral.sh index 61e05dc9..f4348c23 100644 --- a/examples/megatron/scripts/train_sft_mixtral.sh +++ b/examples/megatron/scripts/train_sft_mixtral.sh @@ -39,7 +39,7 @@ if [[ $model_size == "mixtral-8x7B" ]]; then pp=4 ep=8 mb=1 - gbs=32 + gbs=8 else echo "Unrecognized model_size ${model_size}, choose from 'mixtral-8x7B'." exit -1 @@ -102,13 +102,13 @@ DATA_ARGS=" TRAINING_ARGS=" --micro-batch-size $mb \ --global-batch-size $gbs \ ---lr 1e-4 \ +--lr 5e-6 \ --train-iters 1000 \ ---lr-decay-iters 640 \ +--lr-decay-iters 1000 \ --lr-decay-style cosine \ ---min-lr 1.0e-5 \ ---weight-decay 0.1 \ ---lr-warmup-iters 50 \ +--min-lr 1.0e-12 \ +--weight-decay 0. \ +--lr-warmup-iters 40 \ --clip-grad 1.0 \ --bf16 \ --exit-on-missing-checkpoint \ @@ -141,6 +141,7 @@ LOGGING_ARGS=" --tensorboard-dir $CHECKPOINT_PATH \ --tensorboard-log-interval 10 \ --log-timers-to-tensorboard \ +--log-batch-size-to-tensorboard \ --log-validation-ppl-to-tensorboard \ " From f23ac372b8fbc7238ace7d714155c3395fa1855d Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Tue, 12 Nov 2024 06:05:25 +0000 Subject: [PATCH 26/26] fix sft --- .../megatron/scripts/train_reward_mixtral.sh | 36 +++++++------------ .../megatron/scripts/train_sft_mixtral.sh | 10 +++--- 2 files changed, 17 insertions(+), 29 deletions(-) diff --git a/examples/megatron/scripts/train_reward_mixtral.sh b/examples/megatron/scripts/train_reward_mixtral.sh index 9c124c8c..3694cd80 100644 --- a/examples/megatron/scripts/train_reward_mixtral.sh +++ b/examples/megatron/scripts/train_reward_mixtral.sh @@ -39,22 +39,7 @@ if [[ $model_size == "mixtral-8x7B" ]]; then pp=4 ep=8 mb=1 - gbs=32 -elif [[ $model_size == "mixtral-tiny" ]]; then - NUM_LAYERS=2 - HIDDEN_SIZE=4096 - NUM_ATTN_HEADS=32 - FFN_HIDDEN_SIZE=14336 - MAX_POSITION_EMBEDDINGS=32768 - NUM_QUERY_GROUPS=8 - NUM_EXPERTS=8 - MOE_ROUTER_TOPK=2 - seq_length=2048 - tp=1 - pp=2 - ep=4 - mb=1 - gbs=64 + gbs=8 else echo "Unrecognized model_size ${model_size}, choose from 'mixtral-8x7B'." 
exit -1 @@ -111,29 +96,31 @@ MOE_ARGS=" DATA_ARGS=" --tokenizer-type Llama2Tokenizer \ --tokenizer-model ${TOKENIZER_MODEL} \ ---data-path $DATASET_PATH/train.jsonl $DATASET_PATH/train.jsonl $DATASET_PATH/train.jsonl \ +--data-path $DATASET_PATH/train.jsonl $DATASET_PATH/dev.jsonl $DATASET_PATH/dev.jsonl \ --split 98,2,0 \ --dataloader-type cyclic " TRAINING_ARGS=" --micro-batch-size $mb \ --global-batch-size $gbs \ ---lr 1e-4 \ +--lr 3e-6 \ --train-iters 1000 \ ---lr-decay-iters 640 \ +--lr-decay-iters 1000 \ --lr-decay-style cosine \ ---min-lr 1.0e-5 \ +--min-lr 1.0e-12 \ --weight-decay 0.1 \ ---lr-warmup-iters 50 \ +--lr-warmup-iters 300 \ --clip-grad 1.0 \ --bf16 \ --exit-on-missing-checkpoint \ --use-checkpoint-args \ --adam-beta1 0.9 \ ---adam-beta2 0.999 \ +--adam-beta2 0.95 \ --use-flash-attn \ --finetune \ ---recompute-activations" +--recompute-activations \ +--max-response 2 \ +--select-max-response firstk " MODEL_PARALLEL_ARGS=" --tensor-model-parallel-size $tp \ @@ -146,7 +133,7 @@ MODEL_PARALLEL_ARGS=" LOGGING_ARGS=" --log-interval 1 \ ---eval-iters 10 \ +--eval-iters 20 \ --eval-interval 1000 \ --save-interval 1000 \ --save $CHECKPOINT_PATH \ @@ -158,6 +145,7 @@ LOGGING_ARGS=" --tensorboard-dir $CHECKPOINT_PATH \ --tensorboard-log-interval 10 \ --log-timers-to-tensorboard \ +--log-batch-size-to-tensorboard \ --log-validation-ppl-to-tensorboard \ " diff --git a/examples/megatron/scripts/train_sft_mixtral.sh b/examples/megatron/scripts/train_sft_mixtral.sh index f4348c23..e3eb7fcc 100644 --- a/examples/megatron/scripts/train_sft_mixtral.sh +++ b/examples/megatron/scripts/train_sft_mixtral.sh @@ -102,13 +102,13 @@ DATA_ARGS=" TRAINING_ARGS=" --micro-batch-size $mb \ --global-batch-size $gbs \ ---lr 5e-6 \ +--lr 1e-5 \ --train-iters 1000 \ --lr-decay-iters 1000 \ --lr-decay-style cosine \ ---min-lr 1.0e-12 \ ---weight-decay 0. \ ---lr-warmup-iters 40 \ +--min-lr 1.0e-7 \ +--weight-decay 0.01 \ +--lr-warmup-iters 50 \ --clip-grad 1.0 \ --bf16 \ --exit-on-missing-checkpoint \ @@ -130,7 +130,7 @@ MODEL_PARALLEL_ARGS=" LOGGING_ARGS=" --log-interval 1 \ --eval-iters 10 \ ---eval-interval 1000 \ +--eval-interval 100 \ --save-interval 1000 \ --save $CHECKPOINT_PATH \ --load $LOAD_PATH \