diff --git a/chatlearn/models/megatron/memory_manager/base_trainer.py b/chatlearn/models/megatron/memory_manager/base_trainer.py index c85db852..2e5ed3ea 100644 --- a/chatlearn/models/megatron/memory_manager/base_trainer.py +++ b/chatlearn/models/megatron/memory_manager/base_trainer.py @@ -27,6 +27,7 @@ MixedPrecisionOptimizer, DistributedOptimizer, Float16OptimizerWithFloat16Params, + ChainedOptimizer, ) @@ -92,34 +93,58 @@ def __init__( self._use_distributed_optimizer = use_distributed_optimizer self._bucket_size_mb = bucket_size_mb + def sanity_check(single_optimizer): + assert isinstance( + single_optimizer, (MixedPrecisionOptimizer,) + ), f'Only support optimizer type MixedPrecisionOptimizer and its subclasses, current type is {str(type(optimizer))}.' + + if self._use_distributed_optimizer: + assert isinstance(single_optimizer, DistributedOptimizer) + else: + log_rank_0('Current optimizer is Float16OptimizerWithFloat16Params') + assert isinstance(single_optimizer, Float16OptimizerWithFloat16Params) + assert isinstance( model, (DistributedDataParallel,) ), f'Only support model type DistributedDataParallel, current type is {str(type(model))}.' - assert isinstance( - optimizer, (MixedPrecisionOptimizer,) - ), f'Only support optimizer type MixedPrecisionOptimizer and its subclasses, current type is {str(type(optimizer))}.' - - # sanity check - if self._use_distributed_optimizer: - assert isinstance(optimizer, DistributedOptimizer) + if isinstance(optimizer, ChainedOptimizer): + for single_optimizer in optimizer.chained_optimizers: + sanity_check(single_optimizer) + self._is_chained_optimizer = True else: - log_rank_0('Current optimizer is Float16OptimizerWithFloat16Params') - assert isinstance(optimizer, Float16OptimizerWithFloat16Params) + sanity_check(optimizer) + self._is_chained_optimizer = False self._main_weights_offloaded = False self._group_flat_main_weights: Optional[List[BucketizedFlatTensors]] = None self._megatron_version = get_megatron_version() - def _optimizer_load_state_bucket_into_device(self, device): + def get_optimizer_list(self): + if self._is_chained_optimizer: + optimizer_list = self._optimizer.chained_optimizers + else: + optimizer_list = [self._optimizer] + return optimizer_list + + def _optimizer_load_state_bucket_into_device(self, device, optimizer=None): """put the state bucket onto a device""" - state_dict = self._optimizer.optimizer.state_dict() - for tensors in state_dict['state'].values(): - keys = list(tensors.keys()) - for key in keys: - # compatible with transformer_engine v1.10, state['master_param']=None - if tensors[key] is not None: - tensors[key] = tensors[key].to(device=device, non_blocking=True) + if optimizer is not None: + if isinstance(optimizer, ChainedOptimizer): + optimizer_list = optimizer.chained_optimizers + else: + optimizer_list = [optimizer] + else: + optimizer_list = self.get_optimizer_list() + + for single_optimizer in optimizer_list: + state_dict = single_optimizer.optimizer.state_dict() + for tensors in state_dict['state'].values(): + keys = list(tensors.keys()) + for key in keys: + # compatible with transformer_engine v1.10, state['master_param']=None + if tensors[key] is not None: + tensors[key] = tensors[key].to(device=device, non_blocking=True) # make sure the loading is finished before returning torch.cuda.synchronize() @@ -154,12 +179,16 @@ def offload_main_weights(self): return if self._group_flat_main_weights is None: - if self._use_distributed_optimizer: - self._group_flat_main_weights = self._flat_param_groups( - 
[self._optimizer.shard_fp32_from_float16_groups] - ) - else: - self._group_flat_main_weights = self._flat_param_groups([self._optimizer.fp32_from_float16_groups]) + self._group_flat_main_weights = [] + optimizer_list = self.get_optimizer_list() + + for optimizer in optimizer_list: + if self._use_distributed_optimizer: + self._group_flat_main_weights.extend(self._flat_param_groups( + [optimizer.shard_fp32_from_float16_groups] + )) + else: + self._group_flat_main_weights.extend(self._flat_param_groups([optimizer.fp32_from_float16_groups])) for flat_main_weights in self._group_flat_main_weights: flat_main_weights.copy_to_primary_store() diff --git a/chatlearn/models/megatron/memory_manager/trainer_v1v2.py b/chatlearn/models/megatron/memory_manager/trainer_v1v2.py index 83f86e7c..5041c6d3 100644 --- a/chatlearn/models/megatron/memory_manager/trainer_v1v2.py +++ b/chatlearn/models/megatron/memory_manager/trainer_v1v2.py @@ -87,27 +87,31 @@ def offload_weights(self): log_rank_0('Call offload_weights when already offloaded. Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() if self._use_distributed_optimizer: - optimizer.shard_float16_groups.clear() - optimizer.shard_fp32_groups.clear() + for optimizer in optimizer_list: + optimizer.shard_float16_groups.clear() + optimizer.shard_fp32_groups.clear() if self._group_flat_weights is None: - if self._use_distributed_optimizer: - self._group_flat_weights = self._flat_param_groups( - [ - optimizer.model_float16_groups, - optimizer.model_fp32_groups, - ], - ) - else: - self._group_flat_weights = self._flat_param_groups( - [ - optimizer.float16_groups, - optimizer.fp32_from_fp32_groups, - ], - ) + self._group_flat_weights = [] + + for optimizer in optimizer_list: + if self._use_distributed_optimizer: + self._group_flat_weights.extend(self._flat_param_groups( + [ + optimizer.model_float16_groups, + optimizer.model_fp32_groups, + ], + )) + else: + self._group_flat_weights.extend(self._flat_param_groups( + [ + optimizer.float16_groups, + optimizer.fp32_from_fp32_groups, + ], + )) for flat_weights in self._group_flat_weights: flat_weights.copy_to_primary_store() @@ -124,7 +128,7 @@ def onload_weights(self): log_rank_0('Call onload_weights when already onloaded. Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() for flat_weights in self._group_flat_weights: flat_weights.copy_to_gpu_buffer() @@ -148,55 +152,56 @@ def onload_weights(self): self._weights_offloaded = False return - shard_float16_groups = optimizer.shard_float16_groups - shard_fp32_groups = optimizer.shard_fp32_groups - param_gbuf_map = optimizer.model_param_gbuf_map - opt_group_ranges = optimizer.opt_group_ranges - model_gbuf_ranges = optimizer.model_gbuf_ranges - - # Rebuild shard_float16_groups and shard_fp32_groups, - # see Megatron DistributedOptimizer#build_model_and_main_param_groups. 
- for _, group_range in enumerate(opt_group_ranges): - shard_float16_params_this_group = [] - shard_fp32_params_this_group = [] - shard_float16_groups.append(shard_float16_params_this_group) - shard_fp32_groups.append(shard_fp32_params_this_group) - - for model_param in group_range["params"]: - assert model_param.requires_grad - if self._megatron_version == MegatronVersion.V2: - model_index, dtype, bucket_index = param_gbuf_map[model_param] - gbuf_range = model_gbuf_ranges[model_index][dtype][bucket_index] - param_range = gbuf_range["param_map"][model_param]["param"] - elif self._megatron_version == MegatronVersion.V1: - model_index, dtype = param_gbuf_map[model_param] - gbuf_range = model_gbuf_ranges[model_index][dtype] - param_range = gbuf_range["param_map"][model_param]["param"] - - # fp16, bf16 params. - if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']: - shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end] - tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) - if hasattr(model_param, 'shared'): - shard_model_param.shared = model_param.shared - - shard_float16_params_this_group.append(shard_model_param) - - # fp32 params. - elif model_param.type() == 'torch.cuda.FloatTensor': - shard_model_param = model_param.view(-1)[param_range.start : param_range.end] - shard_fp32_params_this_group.append(shard_model_param) - tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) - if hasattr(model_param, 'shared'): - shard_model_param.shared = model_param.shared - else: - raise TypeError( - 'Wrapped parameters must be one of ' - 'torch.cuda.FloatTensor, ' - 'torch.cuda.HalfTensor, or ' - 'torch.cuda.BFloat16Tensor. ' - 'Received {}'.format(model_param.type()) - ) + for optimizer in optimizer_list: + shard_float16_groups = optimizer.shard_float16_groups + shard_fp32_groups = optimizer.shard_fp32_groups + param_gbuf_map = optimizer.model_param_gbuf_map + opt_group_ranges = optimizer.opt_group_ranges + model_gbuf_ranges = optimizer.model_gbuf_ranges + + # Rebuild shard_float16_groups and shard_fp32_groups, + # see Megatron DistributedOptimizer#build_model_and_main_param_groups. + for _, group_range in enumerate(opt_group_ranges): + shard_float16_params_this_group = [] + shard_fp32_params_this_group = [] + shard_float16_groups.append(shard_float16_params_this_group) + shard_fp32_groups.append(shard_fp32_params_this_group) + + for model_param in group_range["params"]: + assert model_param.requires_grad + if self._megatron_version == MegatronVersion.V2: + model_index, dtype, bucket_index = param_gbuf_map[model_param] + gbuf_range = model_gbuf_ranges[model_index][dtype][bucket_index] + param_range = gbuf_range["param_map"][model_param]["param"] + elif self._megatron_version == MegatronVersion.V1: + model_index, dtype = param_gbuf_map[model_param] + gbuf_range = model_gbuf_ranges[model_index][dtype] + param_range = gbuf_range["param_map"][model_param]["param"] + + # fp16, bf16 params. + if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']: + shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end] + tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + + shard_float16_params_this_group.append(shard_model_param) + + # fp32 params. 
+ elif model_param.type() == 'torch.cuda.FloatTensor': + shard_model_param = model_param.view(-1)[param_range.start : param_range.end] + shard_fp32_params_this_group.append(shard_model_param) + tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + else: + raise TypeError( + 'Wrapped parameters must be one of ' + 'torch.cuda.FloatTensor, ' + 'torch.cuda.HalfTensor, or ' + 'torch.cuda.BFloat16Tensor. ' + 'Received {}'.format(model_param.type()) + ) self._weights_offloaded = False @@ -208,16 +213,17 @@ def free_grad_buffers(self): log_rank_0('Call free_grad_buffers when already freed. Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() grad_dtype_to_params = self._grad_dtype_to_params - # This is necessary, but don't know why. - optimizer.zero_grad(True) + for optimizer in optimizer_list: + # This is necessary, but don't know why. + optimizer.zero_grad(True) - if self._use_distributed_optimizer: - # Release param_buffers because they share storage with grad_buffers. - # Note: param_buffers are only available in DistributedOptimizer. - optimizer.param_buffers.clear() + if self._use_distributed_optimizer: + # Release param_buffers because they share storage with grad_buffers. + # Note: param_buffers are only available in DistributedOptimizer. + optimizer.param_buffers.clear() # Release grad_buffers, including buckets in GradBuffer for newer Megatron version. # Release `main_grad` of parameters. @@ -249,7 +255,7 @@ def build_grad_buffers(self): log_rank_0('Call build_grad_buffers when already built. Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() params_dtype = self._params_dtype grad_dtype_to_params = self._grad_dtype_to_params @@ -283,31 +289,33 @@ def build_grad_buffers(self): return # Re-allocate param_buffers, see Megatron DistributedOptimizer#__init__. 
- optimizer.param_buffers = [] - for _, _ in enumerate(optimizer.models): - current_param_buffers = {} - for dtype, grad_buffer in self.get_grad_buffers().items(): - current_param_buffers[dtype] = [] - if self._megatron_version == MegatronVersion.V2: - for bucket in grad_buffer.buckets: + # pylint: disable=too-many-nested-blocks + for optimizer in optimizer_list: + optimizer.param_buffers = [] + for _, _ in enumerate(optimizer.models): + current_param_buffers = {} + for dtype, grad_buffer in self.get_grad_buffers().items(): + current_param_buffers[dtype] = [] + if self._megatron_version == MegatronVersion.V2: + for bucket in grad_buffer.buckets: + try: + storage = bucket.data.storage()._untyped() + # pylint: disable-next=bare-except + except: + storage = bucket.data.storage().untyped() + + param_buffer = torch.tensor([], dtype=params_dtype, device=bucket.data.device).set_(storage) + param_buffer = param_buffer[bucket.offset : bucket.offset + bucket.data.numel()] + current_param_buffers[dtype].append(param_buffer) + elif self._megatron_version == MegatronVersion.V1: try: - storage = bucket.data.storage()._untyped() + storage = grad_buffer.data.storage()._untyped() # pylint: disable-next=bare-except except: - storage = bucket.data.storage().untyped() - - param_buffer = torch.tensor([], dtype=params_dtype, device=bucket.data.device).set_(storage) - param_buffer = param_buffer[bucket.offset : bucket.offset + bucket.data.numel()] - current_param_buffers[dtype].append(param_buffer) - elif self._megatron_version == MegatronVersion.V1: - try: - storage = grad_buffer.data.storage()._untyped() - # pylint: disable-next=bare-except - except: - storage = grad_buffer.data.storage().untyped() - param_buffer = torch.tensor([], dtype=params_dtype, device=grad_buffer.data.device).set_(storage) - param_buffer = param_buffer[: grad_buffer.numel_padded] - current_param_buffers[dtype] = param_buffer - optimizer.param_buffers.append(current_param_buffers) + storage = grad_buffer.data.storage().untyped() + param_buffer = torch.tensor([], dtype=params_dtype, device=grad_buffer.data.device).set_(storage) + param_buffer = param_buffer[: grad_buffer.numel_padded] + current_param_buffers[dtype] = param_buffer + optimizer.param_buffers.append(current_param_buffers) self._grad_buffers_freed = False diff --git a/chatlearn/models/megatron/memory_manager/trainer_v3.py b/chatlearn/models/megatron/memory_manager/trainer_v3.py index 335117e7..d96ab5a6 100644 --- a/chatlearn/models/megatron/memory_manager/trainer_v3.py +++ b/chatlearn/models/megatron/memory_manager/trainer_v3.py @@ -78,16 +78,15 @@ def offload_weights(self): log_rank_0('Call offload_weights when already offloaded. Ignore it.') return - optimizer = self._optimizer - - # TODO(jiqi): support expert parallel params + optimizer_list = self.get_optimizer_list() # In the V3 version, when distributed optimizer is used, parameter data are managed together with # gradients in buffers. 
if self._use_distributed_optimizer: - optimizer.shard_float16_groups.clear() - optimizer.shard_fp32_groups.clear() - optimizer.pbuf_view_items.clear() + for optimizer in optimizer_list: + optimizer.shard_float16_groups.clear() + optimizer.shard_fp32_groups.clear() + optimizer.pbuf_view_items.clear() if self._group_flat_weights is None: self._group_flat_weights = [] @@ -109,12 +108,14 @@ def offload_weights(self): bucket.param_data = None else: if self._group_flat_weights is None: - self._group_flat_weights = self._flat_param_groups( - [ - optimizer.float16_groups, - optimizer.fp32_from_fp32_groups, - ], - ) + self._group_flat_weights = [] + for optimizer in optimizer_list: + self._group_flat_weights.extend(self._flat_param_groups( + [ + optimizer.float16_groups, + optimizer.fp32_from_fp32_groups, + ], + )) # Offload param_data of buffers for flat_weights in self._group_flat_weights: @@ -132,7 +133,7 @@ def onload_weights(self): log_rank_0('Call onload_weights when already onloaded. Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() # Onload param_data of buffers for flat_weights in self._group_flat_weights: @@ -172,52 +173,53 @@ def onload_weights(self): self._weights_offloaded = False return - optimizer.pbuf_view_items = optimizer._get_model_param_buffer_dp_views() - - shard_float16_groups = optimizer.shard_float16_groups - shard_fp32_groups = optimizer.shard_fp32_groups - param_gbuf_map = optimizer.model_param_gbuf_map - opt_group_ranges = optimizer.opt_group_ranges - model_gbuf_ranges = optimizer.gbuf_ranges - - # Rebuild shard_float16_groups and shard_fp32_groups, - # see Megatron DistributedOptimizer#build_model_and_main_param_groups. - for _, group_range in enumerate(opt_group_ranges): - shard_float16_params_this_group = [] - shard_fp32_params_this_group = [] - shard_float16_groups.append(shard_float16_params_this_group) - shard_fp32_groups.append(shard_fp32_params_this_group) - - for model_param in group_range["params"]: - assert model_param.requires_grad - gbuf_index, dtype, bucket_index = param_gbuf_map[model_param] - gbuf_range = model_gbuf_ranges[gbuf_index][dtype][bucket_index] - param_range = gbuf_range["param_map"][model_param]["param"] - - # fp16, bf16 params. - if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']: - shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end] - tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) - if hasattr(model_param, 'shared'): - shard_model_param.shared = model_param.shared - - shard_float16_params_this_group.append(shard_model_param) - - # fp32 params. - elif model_param.type() == 'torch.cuda.FloatTensor': - shard_model_param = model_param.view(-1)[param_range.start : param_range.end] - shard_fp32_params_this_group.append(shard_model_param) - tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) - if hasattr(model_param, 'shared'): - shard_model_param.shared = model_param.shared - else: - raise TypeError( - 'Wrapped parameters must be one of ' - 'torch.cuda.FloatTensor, ' - 'torch.cuda.HalfTensor, or ' - 'torch.cuda.BFloat16Tensor. 
' - 'Received {}'.format(model_param.type()) - ) + for optimizer in optimizer_list: + optimizer.pbuf_view_items = optimizer._get_model_param_buffer_dp_views() + + shard_float16_groups = optimizer.shard_float16_groups + shard_fp32_groups = optimizer.shard_fp32_groups + param_gbuf_map = optimizer.model_param_gbuf_map + opt_group_ranges = optimizer.opt_group_ranges + model_gbuf_ranges = optimizer.gbuf_ranges + + # Rebuild shard_float16_groups and shard_fp32_groups, + # see Megatron DistributedOptimizer#build_model_and_main_param_groups. + for _, group_range in enumerate(opt_group_ranges): + shard_float16_params_this_group = [] + shard_fp32_params_this_group = [] + shard_float16_groups.append(shard_float16_params_this_group) + shard_fp32_groups.append(shard_fp32_params_this_group) + + for model_param in group_range["params"]: + assert model_param.requires_grad + gbuf_index, dtype, bucket_index = param_gbuf_map[model_param] + gbuf_range = model_gbuf_ranges[gbuf_index][dtype][bucket_index] + param_range = gbuf_range["param_map"][model_param]["param"] + + # fp16, bf16 params. + if model_param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']: + shard_model_param = model_param.detach().view(-1)[param_range.start : param_range.end] + tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + + shard_float16_params_this_group.append(shard_model_param) + + # fp32 params. + elif model_param.type() == 'torch.cuda.FloatTensor': + shard_model_param = model_param.view(-1)[param_range.start : param_range.end] + shard_fp32_params_this_group.append(shard_model_param) + tensor_parallel.copy_tensor_model_parallel_attributes(shard_model_param, model_param) + if hasattr(model_param, 'shared'): + shard_model_param.shared = model_param.shared + else: + raise TypeError( + 'Wrapped parameters must be one of ' + 'torch.cuda.FloatTensor, ' + 'torch.cuda.HalfTensor, or ' + 'torch.cuda.BFloat16Tensor. ' + 'Received {}'.format(model_param.type()) + ) self._weights_offloaded = False @@ -229,10 +231,11 @@ def free_grad_buffers(self): log_rank_0('Call free_grad_buffers when already freed. Ignore it.') return - optimizer = self._optimizer + optimizer_list = self.get_optimizer_list() - # This is necessary, but don't know why. - optimizer.zero_grad(True) + for optimizer in optimizer_list: + # This is necessary, but don't know why. + optimizer.zero_grad(True) # Remove references from params for p, buffer in self._model.param_to_buffer.items(): diff --git a/chatlearn/models/megatron_module.py b/chatlearn/models/megatron_module.py index 772156ef..6aeacd8c 100644 --- a/chatlearn/models/megatron_module.py +++ b/chatlearn/models/megatron_module.py @@ -28,6 +28,7 @@ from chatlearn.utils.megatron_utils import build_pipeline_layer_name_mapping from chatlearn.models.megatron.memory_manager import create_trainer_memory_manager, InferenceMemoryManager except ImportError: + print("Megatron is not imported, setting mpu to None.") mpu = None from .torch_module import TorchModule diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index dd2eee22..dea10a5d 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -293,7 +293,8 @@ def build_rank_mapping_two_stage(self, add_recv_actor_fn=None): if local_src_ranks[0] is None or dst_ranks is None: if self._debug: logger.warning( - f"DEBUG MODE! 
src_ranks {local_src_ranks} or dst_ranks: {dst_ranks} is None, make sure they have values in real application.") + f"DEBUG MODE! src_ranks {local_src_ranks} or dst_ranks: {dst_ranks} is None, " + "make sure they have values in real application.") return else: raise Exception(f"src_ranks {local_src_ranks} or dst_ranks {dst_ranks} should not be None") diff --git a/chatlearn/tools/convert.py b/chatlearn/tools/convert.py new file mode 100644 index 00000000..0974a0c6 --- /dev/null +++ b/chatlearn/tools/convert.py @@ -0,0 +1,79 @@ +# Code below is modified from NVIDIA Megatron-LM +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +"""Convertion script""" + +import argparse +import importlib +import sys +import torch.multiprocessing as mp + + +# Code below is copied from Megatron-LM core v0.8.0 +def load_plugin(plugin_type, name): + module_name = f"{plugin_type}_{name}" + try: + plugin = importlib.import_module(module_name) + except ModuleNotFoundError as e1: + print(e1) + module_name = name + try: + plugin = importlib.import_module(module_name) + except ModuleNotFoundError as e2: + print(e2) + sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.") + + if not hasattr(plugin, 'add_arguments'): + sys.exit(f"{module_name} module is not a plugin. Exiting.") + + print(f"Loaded {module_name} as the {plugin_type}.") + return plugin + +def main(): + parser = argparse.ArgumentParser(description="Megatron Checkpoint Converter Arguments", + allow_abbrev=False, conflict_handler='resolve') + + parser.add_argument('--model-type', type=str, required=True, + choices=['GPT', 'BERT'], + help='Type of the model') + parser.add_argument('--loader', type=str, default='megatron', + help='Module name to load checkpoint, should be on python path') + parser.add_argument('--loader-prefix', type=str, default='loader', + help='Prefix import path for loader') + parser.add_argument('--saver', type=str, default='megatron', + help='Module name to save checkpoint, should be on python path') + parser.add_argument('--saver-prefix', type=str, default='saver', + help='Prefix import path for saver') + parser.add_argument('--load-dir', type=str, required=True, + help='Directory to load model checkpoint from') + parser.add_argument('--save-dir', type=str, required=True, + help='Directory to save model checkpoint to') + parser.add_argument('--max-queue-size', type=int, default=50, + help='Maximum number of tensors in the queue') + parser.add_argument('--no-checking', action='store_false', + help='Do not perform checking on the name and ordering of weights', + dest='checking') + + known_args, _ = parser.parse_known_args() + loader = load_plugin(known_args.loader_prefix, known_args.loader) + saver = load_plugin(known_args.saver_prefix, known_args.saver) + + loader.add_arguments(parser) + saver.add_arguments(parser) + + args = parser.parse_args() + + queue = mp.Queue(maxsize=args.max_queue_size) + + print("Starting saver...") + saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args)) + saver_proc.start() + + print("Starting loader...") + loader.load_checkpoint(queue, args) + + print("Waiting for saver to complete...") + saver_proc.join() + + +if __name__ == '__main__': + main() diff --git a/chatlearn/tools/loader_mcore_mixtral.py b/chatlearn/tools/loader_mcore_mixtral.py new file mode 100644 index 00000000..1544271d --- /dev/null +++ b/chatlearn/tools/loader_mcore_mixtral.py @@ -0,0 +1,517 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+"""load mcore mixtral model""" + +import json +import os +import sys +import types +import torch + +from utils import get_mcore_transformer_block_key, print_memory_usage + + +def add_arguments(parser): + group = parser.add_argument_group(title='Megatron loader') + + group.add_argument('--true-vocab-size', type=int, default=None, + help='original size of vocab, if specified will trim padding from embedding table.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to the vocab file. If specified will use this to get vocab size and ' + 'trim padding from the embedding table.') + group.add_argument('--megatron-path', type=str, default=None, + help='Base directory of Megatron repository') + group.add_argument('--position-embedding-type', + type=str, + default='learned_absolute', + choices=['learned_absolute', 'rope'], + help='Position embedding type.') + group.add_argument('--loader-transformer-impl', default='transformer_engine', + choices=['local', 'transformer_engine'], + help='Which Transformer implementation to use.') + + +def _load_checkpoint(queue, args): + + # Search in directory above this + sys.path.append(os.path.abspath( + os.path.join(os.path.dirname(__file__), + os.path.pardir))) + if args.megatron_path is not None: + sys.path.insert(0, args.megatron_path) + + # pylint: disable=import-outside-toplevel + try: + from megatron.training.arguments import parse_args, validate_args + from megatron.training.global_vars import set_global_variables + from megatron.training.checkpointing import load_args_from_checkpoint, load_checkpoint # pylint: disable=redefined-outer-name + from megatron.legacy.model import module + from megatron.core import mpu + from megatron.core.enums import ModelType + from megatron.legacy import fused_kernels + except ModuleNotFoundError: + print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.") + queue.put("exit") + sys.exit(1) + + # We want all arguments to come from us + sys.argv = ['script.py', + '--no-masked-softmax-fusion', + '--no-bias-gelu-fusion', + '--no-bias-dropout-fusion', + '--no-async-tensor-model-parallel-allreduce', + '--use-cpu-initialization', + '--micro-batch-size', '1', + '--no-load-optim', + '--no-load-rng', + '--no-save-optim', + '--no-save-rng', + '--no-initialization', + '--mock-data', # To pass the "blend data checks" in arguments.py + '--load', args.load_dir, + '--position-embedding-type', args.position_embedding_type, + ] + + margs = parse_args() + margs, checkpoint_args = load_args_from_checkpoint(margs, exit_on_missing_checkpoint=True) + + # Arguments do sanity checks on the world size, but we don't care, + # so trick it into thinking we are plenty of processes + margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size * margs.expert_model_parallel_size + + # Explicitly copy data types from checkpoint. + margs.fp16 = checkpoint_args.fp16 + margs.bf16 = checkpoint_args.bf16 + + margs.use_legacy_models = False + margs.transformer_impl = args.loader_transformer_impl + if checkpoint_args.expert_model_parallel_size > 1: + margs.expert_model_parallel_size = checkpoint_args.expert_model_parallel_size + margs.num_experts = checkpoint_args.num_experts + + # Validate margs. + margs = validate_args(margs) + + def check_for_arg(arg_name, default=None): + if getattr(margs, arg_name, None) is None: + if default is not None: + setattr(margs, arg_name, default) + else: + print(f"Checkpoint does not specify the argument {arg_name}. 
Exiting.") + print(f"Arguments: {margs}") + queue.put("exit") + sys.exit(1) + + check_for_arg('tensor_model_parallel_size') + check_for_arg('pipeline_model_parallel_size') + check_for_arg('expert_model_parallel_size') + check_for_arg('num_layers') + check_for_arg('hidden_size') + check_for_arg('seq_length') + check_for_arg('num_attention_heads') + check_for_arg('max_position_embeddings') + check_for_arg('position_embedding_type') + check_for_arg('tokenizer_type') + check_for_arg('iteration') + check_for_arg('bert_binary_head') + check_for_arg('disable_bias_linear', False) + check_for_arg('params_dtype') + check_for_arg('swiglu', False) + if checkpoint_args.expert_model_parallel_size > 1: + check_for_arg('num_experts') + + # Determine how to make our models + if args.model_type == 'GPT': + from pretrain_gpt import model_provider + margs.model_type = ModelType.encoder_or_decoder + elif args.model_type == 'BERT': + from pretrain_bert import model_provider + margs.model_type = ModelType.encoder_or_decoder + else: + raise Exception(f'unrecognized model type: {args.model_type}') + + # supress warning about torch.distributed not being initialized + module.MegatronModule.embedding_warning_printed = True + + consumed_train_samples = None + consumed_valid_samples = None + def get_models(tp_size, ep_size, dtype): + nonlocal consumed_train_samples + nonlocal consumed_valid_samples + model_array_len = margs.virtual_pipeline_model_parallel_size + if model_array_len is None: + model_array_len = 1 + models = [[[] for _ in range(ep_size)] for _ in range(model_array_len)] + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + for ep_rank in range(ep_size): + mpu.set_expert_model_parallel_rank(ep_rank) + for tp_rank in range(tp_size): + mpu.set_tensor_model_parallel_rank(tp_rank) + if margs.virtual_pipeline_model_parallel_size is not None: + model_ = [] + for i in range(margs.virtual_pipeline_model_parallel_size): + mpu.set_virtual_pipeline_model_parallel_rank(i) + # Set pre_process and post_process only after virtual rank is set. + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + this_model = model_provider( + pre_process=pre_process, + post_process=post_process + ).to(dtype) + model_.append(this_model) + else: + pre_process = mpu.is_pipeline_first_stage() + post_process = mpu.is_pipeline_last_stage() + model_ = [model_provider(pre_process, post_process).to(dtype)] + margs.consumed_train_samples = 0 + margs.consumed_valid_samples = 0 + margs.exit_on_missing_checkpoint = True + load_checkpoint(model_, None, None, strict=False) + + if consumed_train_samples is not None: + assert(margs.consumed_train_samples == consumed_train_samples) + else: + consumed_train_samples = margs.consumed_train_samples + if consumed_valid_samples is not None: + assert(margs.consumed_valid_samples == consumed_valid_samples) + else: + consumed_valid_samples = margs.consumed_valid_samples + for vp_rank in range(model_array_len): + models[vp_rank][ep_rank].append(model_[vp_rank]) + + # Print memory usage. 
+ print_memory_usage("loader", tp_rank, tp_size) + + return models + + set_global_variables(margs, build_tokenizer=False) + mpu.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size) + mpu.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size) + mpu.set_expert_model_parallel_world_size(margs.expert_model_parallel_size) + mpu.set_virtual_pipeline_model_parallel_world_size(margs.virtual_pipeline_model_parallel_size) + fused_kernels.load(margs) + + # Get true (non-padded) vocab size + if args.true_vocab_size is not None: + true_vocab_size = args.true_vocab_size + elif args.vocab_file is not None: + with open(args.vocab_file) as vocab_file_handler: # pylint: disable=unspecified-encoding + vocab = json.load(vocab_file_handler) + true_vocab_size = len(vocab) + if args.true_vocab_size is not None and true_vocab_size != args.true_vocab_size: + print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.") + queue.put("exit") + sys.exit(1) + else: + true_vocab_size = None + + # short aliases + tp_size = margs.tensor_model_parallel_size + pp_size = margs.pipeline_model_parallel_size + ep_size = margs.expert_model_parallel_size + vp_size = margs.virtual_pipeline_model_parallel_size + if vp_size is None: + vp_size = 1 + + # Layernorm has bias; RMSNorm does not. + if hasattr(checkpoint_args, 'normalization'): + norm_has_bias = checkpoint_args.normalization == "LayerNorm" + else: + # older models only supported LayerNorm + norm_has_bias = True + + # metadata + md = types.SimpleNamespace() + md.model_type = args.model_type + md.num_layers = margs.num_layers + md.hidden_size = margs.hidden_size + md.seq_length = margs.seq_length + md.num_attention_heads = margs.num_attention_heads + md.max_position_embeddings = margs.max_position_embeddings + md.tokenizer_type = margs.tokenizer_type + md.iteration = margs.iteration + md.params_dtype = margs.params_dtype + md.bert_binary_head = margs.bert_binary_head + md.output_layer = margs.untie_embeddings_and_output_weights + md.position_embedding_type = margs.position_embedding_type + md.linear_bias = margs.add_bias_linear + md.norm_has_bias = norm_has_bias + md.swiglu = margs.swiglu + md.previous_tensor_parallel_size = margs.tensor_model_parallel_size + md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size + md.previous_expert_parallel_size = margs.expert_model_parallel_size + md.true_vocab_size = true_vocab_size + md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by + md.checkpoint_args = checkpoint_args + md.use_legacy_models = margs.use_legacy_models + md.num_experts = margs.num_experts + + # Get transformer block (named either 'encoder' or 'decoder'). 
+ transformer_block_key = get_mcore_transformer_block_key(md.model_type) + def get_transformer_block(_model): + return getattr(_model, transformer_block_key) + + # Get first pipe stage + mpu.set_pipeline_model_parallel_rank(0) + # all_models: pp_rank, vp_rank, ep_rank, tp_rank + all_models = [get_models(tp_size, ep_size, md.params_dtype)] + models = all_models[0][0] + if ep_size == 1: + assert len(models) == 1 + + md.consumed_train_samples = consumed_train_samples + md.consumed_valid_samples = consumed_valid_samples + queue.put(md) + + def queue_put(name, msg): + print(f"sending {name}") + msg["name"] = name + queue.put(msg) + + # Send embeddings + message = { + "word embeddings": torch.cat( + [models[0][tp_rank].embedding.word_embeddings.weight.data for tp_rank in range(tp_size)], + dim = 0) + } + if md.position_embedding_type == 'learned_absolute': + message["position embeddings"] = models[0].embedding.position_embeddings.weight.data + else: + assert not hasattr(models[0][0].embedding, 'position_embeddings') + + queue_put("embeddings", message) + + def get_message_for_dense_model(message): + # Get non-parallel tensors from tp_rank 0 + layer = get_transformer_block(models[0][0]).layers[layer_num] + message["input norm weight"] = layer.self_attention.linear_qkv.layer_norm_weight.data + if norm_has_bias: + message["input norm bias"] = layer.self_attention.linear_qkv.layer_norm_bias.data + message["post norm weight"] = layer.mlp.linear_fc1.layer_norm_weight.data + if norm_has_bias: + message["post norm bias"] = layer.mlp.linear_fc1.layer_norm_bias.data + if md.linear_bias: + message["dense bias"] = layer.self_attention.linear_proj.bias.data + + # Grab all parallel tensors for this layer + qkv_weight = [] + qkv_bias = [] + dense_weight = [] + mlp_l0_weight = [] + mlp_l0_bias = [] + mlp_l1_weight = [] + for tp_rank, model in enumerate(models[0]): + layer = get_transformer_block(model).layers[layer_num] + qkv_weight.append(layer.self_attention.linear_qkv.weight.data) + dense_weight.append(layer.self_attention.linear_proj.weight.data) + mlp_l0_weight.append(layer.mlp.linear_fc1.weight.data) + mlp_l1_weight.append(layer.mlp.linear_fc2.weight.data) + if md.linear_bias: + qkv_bias.append(layer.self_attention.linear_qkv.bias.data) + mlp_l0_bias.append(layer.mlp.linear_fc1.bias.data) + if md.linear_bias: + # Get non-parallel tensors from tp_rank 0 + layer = get_transformer_block(models[0][0]).layers[layer_num] + mlp_l1_bias = layer.mlp.linear_fc2.bias.data + + # Handle gated linear units + if md.swiglu: + # concat all the first halves ('W's) and all the second halves ('V's) + for tp_rank in range(tp_size): + mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0) + message["mlp l0 weight W"] = torch.cat([w[0] for w in mlp_l0_weight], dim=0) + message["mlp l0 weight V"] = torch.cat([w[1] for w in mlp_l0_weight], dim=0) + else: + message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0) + + # simple concat of the rest + message["qkv weight"] = torch.cat(qkv_weight, dim=0) + message["dense weight"] = torch.cat(dense_weight, dim=1) + message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1) + if md.linear_bias: + message["qkv bias"] = torch.cat(qkv_bias, dim=0) + if md.swiglu: + for tp_rank in range(tp_size): + mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0) + message["mlp l0 bias W"] = torch.cat([b[0] for b in mlp_l0_bias],dim=0) + message["mlp l0 bias V"] = torch.cat([b[1] for b in mlp_l0_bias],dim=0) + else: + message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0) + 
message["mlp l1 bias"] = mlp_l1_bias + + def get_message_for_moe_model(message): + # Get non-parallel tensors from tp_rank 0 + layer = get_transformer_block(models[0][0]).layers[layer_num] + message["input norm weight"] = layer.self_attention.linear_qkv.layer_norm_weight.data + if norm_has_bias: + message["input norm bias"] = layer.self_attention.linear_qkv.layer_norm_bias.data + message["post norm weight"] = layer.pre_mlp_layernorm.weight.data + if norm_has_bias: + message["post norm bias"] = layer.pre_mlp_layernorm.bias.data + if md.linear_bias: + message["dense bias"] = layer.self_attention.linear_proj.bias.data + + # Grab all parallel tensors for this layer + qkv_weight = [] + qkv_bias = [] + dense_weight = [] + mlp_l0_weight_list = [[] for _ in range(margs.num_experts)] + mlp_l0_bias_list = [[] for _ in range(margs.num_experts)] + mlp_l1_weight_list = [[] for _ in range(margs.num_experts)] + mlp_l1_bias_list = [[] for _ in range(margs.num_experts)] + router_weight = [] + + # Dense modules + for tp_rank, model in enumerate(models[0]): + layer = get_transformer_block(model).layers[layer_num] + qkv_weight.append(layer.self_attention.linear_qkv.weight.data) + dense_weight.append(layer.self_attention.linear_proj.weight.data) + if md.linear_bias: + qkv_bias.append(layer.self_attention.linear_qkv.bias.data) + layer = get_transformer_block(models[0][0]).layers[layer_num] + router_weight = layer.mlp.router.weight.data + + # MoE modules + num_experts_per_rank = margs.num_experts // ep_size + for ep_rank, tp_models in enumerate(models): + for tp_rank, model in enumerate(tp_models): + layer = get_transformer_block(model).layers[layer_num] + for local_expert_idx in range(num_experts_per_rank): + expert_idx = int(ep_rank * num_experts_per_rank + local_expert_idx) + mlp_l0_weight_list[expert_idx].append(layer.mlp.experts.local_experts[local_expert_idx].linear_fc1.weight.data) + mlp_l1_weight_list[expert_idx].append(layer.mlp.experts.local_experts[local_expert_idx].linear_fc2.weight.data) + if md.linear_bias: + mlp_l0_bias_list[expert_idx].append(layer.mlp.experts.local_experts[local_expert_idx].linear_fc1.bias.data) + + if md.linear_bias: + # Get non-parallel tensors from tp_rank 0 + layer = get_transformer_block(tp_models[0]) + for local_expert_idx in range(num_experts_per_rank): + expert_idx = int(ep_rank * num_experts_per_rank + local_expert_idx) + mlp_l1_bias_list[expert_idx].append(layer.mlp.experts.local_experts[local_expert_idx].linear_fc2.bias.data) + + mlp_l0_weight_w_list = [[] for _ in range(margs.num_experts)] + mlp_l0_weight_v_list = [[] for _ in range(margs.num_experts)] + # Concat along the tensor parallel dimension + for expert_idx in range(margs.num_experts): + mlp_l0_weight = mlp_l0_weight_list[expert_idx] + if md.swiglu: + for tp_rank in range(tp_size): + mlp_l0_weight[tp_rank] = torch.chunk(mlp_l0_weight[tp_rank], 2, dim=0) + mlp_l0_weight_w_list[expert_idx] = torch.cat([w[0] for w in mlp_l0_weight], dim=0) + mlp_l0_weight_v_list[expert_idx] = torch.cat([w[1] for w in mlp_l0_weight], dim=0) + else: + mlp_l0_weight_list[expert_idx] = torch.cat(mlp_l0_weight, dim=0) + mlp_l1_weight_list[expert_idx] = torch.cat(mlp_l1_weight_list[expert_idx], dim=1) + + # Stack along the expert parallel dimension + if md.swiglu: + message["mlp l0 weight W"] = torch.stack(mlp_l0_weight_w_list) + message["mlp l0 weight V"] = torch.stack(mlp_l0_weight_v_list) + else: + message["mlp l0 weight"] = torch.stack(mlp_l0_weight_list) + message["mlp l1 weight"] = torch.stack(mlp_l1_weight_list) + + # Concat 
along TP and stack along EP to biases + if md.linear_bias: + mlp_l0_bias_w_list = [[] for _ in range(margs.num_experts)] + mlp_l0_bias_v_list = [[] for _ in range(margs.num_experts)] + # Concat along the tensor parallel dimension + for expert_idx in range(margs.num_experts): + mlp_l0_bias = mlp_l0_bias_list[expert_idx] + if md.swiglu: + for tp_rank in range(tp_size): + mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0) + mlp_l0_bias_w_list[expert_idx] = torch.cat([w[0] for w in mlp_l0_bias], dim=0) + mlp_l0_bias_v_list[expert_idx] = torch.cat([w[1] for w in mlp_l0_bias], dim=0) + else: + mlp_l0_bias_list[expert_idx] = torch.cat(mlp_l0_bias, dim=0) + assert len(mlp_l1_bias_list[expert_idx]) == 1 + mlp_l1_bias_list[expert_idx] = mlp_l1_bias_list[expert_idx][0] + + # Stack along the expert parallel dimension + if md.swiglu: + message["mlp l0 bias W"] = torch.stack(mlp_l0_bias_w_list) + message["mlp l0 bias V"] = torch.stack(mlp_l0_bias_v_list) + else: + message["mlp l0 bias"] = torch.stack(mlp_l0_bias_list) + message["mlp l1 bias"] = torch.stack(mlp_l1_bias_list) + + # Simple concat of the rest + message["qkv weight"] = torch.cat(qkv_weight, dim=0) + message["dense weight"] = torch.cat(dense_weight, dim=1) + if md.linear_bias: + message["qkv bias"] = torch.cat(qkv_bias, dim=0) + + # Do nothing to router + message["router weight"] = router_weight + + total_layer_num = 0 + for vp_rank in range(vp_size): + mpu.set_virtual_pipeline_model_parallel_rank(vp_rank) + for pp_rank in range(pp_size): + if pp_rank > 0: + mpu.set_pipeline_model_parallel_rank(pp_rank) + if vp_rank == 0: + all_models.append(get_models(tp_size, ep_size, md.params_dtype)) + models = all_models[pp_rank][vp_rank] + for layer_num in range(len(get_transformer_block(models[0][0]).layers)): + message = {} + + if margs.num_experts: + get_message_for_moe_model(message) + else: + get_message_for_dense_model(message) + + queue_put(f"transformer layer {total_layer_num}", message) + + total_layer_num = total_layer_num + 1 + + # Send final norm from tp_rank 0 + message = { + "weight": get_transformer_block(models[0][0]).final_layernorm.weight.data, + } + if norm_has_bias: + message["bias"] = get_transformer_block(models[0][0]).final_layernorm.bias.data + queue_put("final norm", message) + + if md.output_layer: + message = { + "weight": torch.cat( + [models[0][tp_rank].output_layer.weight.data for tp_rank in range(tp_size)], + dim = 0) + } + queue_put("output layer", message) + + + # Send BERT lm head and binary head if it exists + if md.model_type == 'BERT': + message = { + "weight": models[0][0].pooler.dense.weight.data, + "bias": models[0][0].pooler.dense.bias.data + } + queue_put("pooler", message) + + message = { + "dense weight": models[0][0].lm_head.dense.weight.data, + "dense bias": models[0][0].lm_head.dense.bias.data, + "norm weight": models[0][0].lm_head.layer_norm.weight.data, + } + if norm_has_bias: + message["norm bias"] = models[0][0].lm_head.layer_norm.bias.data + queue_put("lm head", message) + + if md.bert_binary_head: + message = { + "weight": models[0][0].binary_head.weight.data, + "bias": models[0][0].binary_head.bias.data + } + queue_put("binary head", message) + queue.put("done") + +def load_checkpoint(queue, args): + try: + _load_checkpoint(queue, args) + except: + queue.put("exit") + raise diff --git a/chatlearn/tools/megatron_to_hf.py b/chatlearn/tools/megatron_to_hf.py index 7a990744..59576ccd 100644 --- a/chatlearn/tools/megatron_to_hf.py +++ b/chatlearn/tools/megatron_to_hf.py @@ -76,6 +76,20 @@ 
def add_checkpointing_args(parser): "Path to Megatron-LM" ), ) + parser.add_argument( + "--use_legacy_models", + action="store_true", + help=( + "Whether using legacy models. Default: False." + ) + ) + parser.add_argument( + "--validate_checkpoint", + action="store_true", + help=( + "Whether validating converted checkpoint. Default: False." + ) + ) return parser @@ -97,6 +111,16 @@ def add_checkpointing_args(parser): "mlp.dense_4h_to_h.weight" ] + +mcore_to_transformers = { + "self_attention.linear_proj":".self_attn.o_proj.", + "linear_fc1_1":".w1.", + "linear_fc1_2":".w3.", + "linear_fc2":".w2.", + "mlp.router":".block_sparse_moe.gate.", + "self_attention.rotary_emb":".self_attn.rotary_emb.inv_freq" # unneeded for MoE +} + def recursive_print(name, val, spaces=0): """ Recursively print the structure of a checkpoint. This function is taken from `convert_megatron_gpt2_checkpoint.py` @@ -383,13 +407,259 @@ def convert_checkpoint_from_megatron_to_transformers(args): os.system(f"cp {fn} {args.save_path}") +def convert_checkpoint_from_mcore_to_transformers(args): + """ + Convert NVIDIA MCore checkpoint to HuggingFace Transformers checkpoint. It saves the converted checkpoint into shards + using HuggingFace Transformers checkpoint sharding functionality. + + Args: + args (argparse.Namespace): the arguments to the script + """ + # Load Megatron-Core checkpoint arguments from the state dict + possible_sub_dirs = ["mp_rank_00", "mp_rank_00_000", "mp_rank_00_000_000"] + for root, dirnames, _ in os.walk(args.load_path): + for dirname in dirnames: + if dirname in possible_sub_dirs: + rank0_checkpoint_name = glob.glob(os.path.join(root, dirname) + "/*.pt") + args.load_path = root + rank0_checkpoint_path = rank0_checkpoint_name[0] + + print(f"Loading Megatron-Core checkpoint arguments from: {rank0_checkpoint_path}") + state_dict = torch.load(rank0_checkpoint_path, map_location="cpu") + megatron_args = state_dict.get("args", None) + if megatron_args is None: + raise ValueError( + "Megatron-Core checkpoint does not contain arguments. This utility only supports Megatron-Core checkpoints" + " containing all the megatron arguments. This is because it loads all config related to model" + " architecture, the tensor and pipeline model parallel size from the checkpoint insead of user having to" + " manually specify all the details. Please save Megatron-Core checkpoint along with all the megatron" + " arguments to use this utility." 
+ ) + + # params dtype + if args.target_params_dtype == "fp16": + dtype = torch.float16 + elif args.target_params_dtype == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + output_state_dict = {} + + checkpoint_version = state_dict.get("checkpoint_version", 0.0) + assert checkpoint_version >= 3.0 + tp_size = megatron_args.tensor_model_parallel_size + pp_size = megatron_args.pipeline_model_parallel_size + ep_size = megatron_args.expert_model_parallel_size + assert tp_size == 1 and pp_size == 1 and ep_size == 1 + + # Possible keys for MoE models: + # 'embedding.word_embeddings.weight', + # 'decoder.layers.0.self_attention.linear_proj.weight', + # 'decoder.layers.0.self_attention.linear_proj._extra_state', + # 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight', + # 'decoder.layers.0.self_attention.linear_qkv.weight', + # 'decoder.layers.0.self_attention.linear_qkv._extra_state', + # 'decoder.layers.0.pre_mlp_layernorm.weight', + # 'decoder.layers.0.mlp.router.weight', + # 'decoder.layers.0.mlp.experts.local_experts.0.linear_fc1.weight', + # 'decoder.layers.0.mlp.experts.local_experts.0.linear_fc1._extra_state', + # 'decoder.layers.0.mlp.experts.local_experts.0.linear_fc2.weight', + # ..., + # 'decoder.final_layernorm.weight', + # 'output_layer.weight', + # 'output_layer._extra_state', + # 'decoder' + # The regex to extract layer names. + layer_re = re.compile(r"decoder.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z_]+)") + expert_re = re.compile(r"decoder.layers\.(\d+)\.([a-z0-9_.]+)\.(\d+)\.([a-z0-9_.]+)\.(weight|bias|_extra_state)") + + # Convert. + print("Converting") + + # Embeddings + print("Converting embeddings") + tp_state_dicts = get_megatron_sharded_states(args, tp_size, pp_size, 0) + state_dict = tp_state_dicts[0]['model'] + + # Convert and store the position embeddings. + position_embeddings = state_dict.get("embedding.position_embeddings.weight", None) + if position_embeddings: + output_state_dict["transformer.position_embeddings.weight"] = position_embeddings.to(dtype) + + # Convert and store the word embeddings. + word_embedding = state_dict.get("embedding.word_embeddings.weight", None) + output_state_dict["model.embed_tokens.weight"] = word_embedding.to(dtype) + + # Transformer Layers + print("Converting transformer layers") + + def process_dense(layer_match_res, output_state_dict): + # The name of the operation. + op_name = layer_match_res.group(2) + # Is it a weight or a bias? + weight_or_bias = layer_match_res.group(3) + + # Ignore them + if weight_or_bias in ('bias', '_extra_state'): + return + params = val.to(dtype) + + # For norm(s), simply store the norm. + if weight_or_bias.endswith("norm_weight"): # e.g. self_attention.linear_qkv.layer_norm_weight + ln_name = "input_layernorm" + output_state_dict[layer_name + "." + ln_name + ".weight"] = params + elif op_name.endswith("norm") and weight_or_bias == 'weight': # e.g. pre_mlp_layernorm.weight + ln_name = "post_attention_layernorm" + output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = params + + # Transpose the QKV matrix. 
+ elif op_name == "self_attention.linear_qkv" \ + and weight_or_bias == "weight": + q_proj, k_proj, v_proj = split_attn_state(params, megatron_args) + if args.model_type == "llama": + output_state_dict[layer_name + ".self_attn.q_proj.weight"] = q_proj + output_state_dict[layer_name + ".self_attn.k_proj.weight"] = k_proj + output_state_dict[layer_name + ".self_attn.v_proj.weight"] = v_proj + + # Store other weights such as router + elif weight_or_bias == "weight": + out_name = mcore_to_transformers[op_name] + output_state_dict[layer_name + out_name + "weight"] = params + + # Copy the Rotary Embedding + else: + out_name = mcore_to_transformers[op_name] + output_state_dict[layer_name + out_name] = params + + def process_moe(expert_match_res, output_state_dict): + # The prefix of the expert + expert_prefix = expert_match_res.group(2) + # The index of the expert + expert_idx = expert_match_res.group(3) + # the name of the operation + op_name = expert_match_res.group(4) + # Is it a weight or a bias? + weight_or_bias = expert_match_res.group(5) + + # Ignore them + if weight_or_bias in ('bias', '_extra_state'): + return + params = val.to(dtype) + + expert_name = f".block_sparse_moe.experts.{expert_idx}" + if 'linear_fc1' in op_name: + linear_fc1_1, linear_fc1_2 = torch.split(params, params.size(0)//2, 0) + out_name = mcore_to_transformers[op_name+'_1'] + output_state_dict[layer_name + expert_name + out_name + "weight"] = linear_fc1_1 + out_name = mcore_to_transformers[op_name+'_2'] + output_state_dict[layer_name + expert_name + out_name + "weight"] = linear_fc1_2 + elif 'linear_fc2' in op_name: + out_name = mcore_to_transformers[op_name] + output_state_dict[layer_name + expert_name + out_name + "weight"] = params + else: + assert False, f"Unrecognized MoE module {expert_prefix}.{expert_idx}.{op_name}" + + # Extract the layers. + for key, val in state_dict.items(): + # Match the name. + layer_match_res = layer_re.match(key) + expert_match_res = expert_re.match(key) + # Continue if that's not a layer + if layer_match_res is None: + continue + if val is None: + continue + + # The index of the layer. + layer_idx = int(layer_match_res.group(1)) + # The name of the layer. + layer_name = f"model.layers.{layer_idx}" + + if expert_match_res: # Deal with sparse layers + process_moe(expert_match_res, output_state_dict) + else: # Deal with dense layers + process_dense(layer_match_res, output_state_dict) + + if megatron_args.num_layers != (layer_idx + 1): + raise ValueError(f"Expected {megatron_args.num_layers} layers but found {layer_idx + 1}") + + # The final norm. + print("Converting final norm") + params = state_dict.get("decoder.final_layernorm.weight", None) + output_state_dict["model.norm.weight"] = params.to(dtype) + + # For LM head, transformers' wants the matrix to weight embeddings. + print("Converting LM head") + params = state_dict.get('output_layer.weight', None) + output_state_dict["lm_head.weight"] = params.to(dtype) + + print("Saving checkpoint...") + + # Print the structure of converted state dict. + if args.print_checkpoint_structure: + recursive_print(None, output_state_dict) + + # Store the state_dict to file. 
+ max_shard_size = int(args.max_shard_size) if args.max_shard_size.isdigit() else args.max_shard_size + shards, index = shard_checkpoint(output_state_dict, max_shard_size=max_shard_size) + + # Save the model + if not os.path.exists(args.save_path): + os.system(f'mkdir -p {args.save_path}') + for shard_file, shard in shards.items(): + torch.save(shard, os.path.join(args.save_path, shard_file)) + + if index is None: + print(f"Model weights saved in {os.path.join(args.save_path, WEIGHTS_NAME)}") + else: + save_index_file = os.path.join(args.save_path, WEIGHTS_INDEX_NAME) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + print( + f"The model is bigger than the maximum size per checkpoint ({args.max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + # Saving config and tokenzier files + for fn in glob.glob(args.vocab_dir + "/*"): + if (fn.endswith(".json") or fn.endswith("tokenizer.model") or fn.endswith(".py")) and not fn.endswith(".index.json"): + os.system(f"cp {fn} {args.save_path}") + + # It should be done! + print("Conversion from Megatron-Core to Transformers is done!") + +# pylint: disable=import-outside-toplevel +def validate_loading_checkpoints(args): + from transformers import AutoModelForCausalLM + _, model_loading_info = AutoModelForCausalLM.from_pretrained(args.save_path, output_loading_info=True) + if len(model_loading_info["missing_keys"]) > 0: + assert False, f"Invalid model checkpoint on missing_keys: {model_loading_info['missing_keys']}" + if len(model_loading_info["unexpected_keys"]) > 0: + assert False, f"Invalid model checkpoint on unexpected_keys: {model_loading_info['unexpected_keys']}" + if len(model_loading_info["mismatched_keys"]) > 0: + assert False, f"Invalid model checkpoint on mismatched_keys: {model_loading_info['mismatched_keys']}" + if len(model_loading_info["error_msgs"]) > 0: + assert False, f"Invalid model checkpoint on error_msgs: {model_loading_info['error_msgs']}" + def main(): parser = argparse.ArgumentParser() parser = add_checkpointing_args(parser) args = parser.parse_args() if args.megatron_path: sys.path.append(args.megatron_path) - convert_checkpoint_from_megatron_to_transformers(args) + + if args.use_legacy_models: + convert_checkpoint_from_megatron_to_transformers(args) + else: + convert_checkpoint_from_mcore_to_transformers(args) + + if args.validate_checkpoint: + print("Validating converted checkpoints...") + validate_loading_checkpoints(args) + print("Validation success!") if __name__ == "__main__": diff --git a/chatlearn/utils/arguments.py b/chatlearn/utils/arguments.py index b2aab309..d5ecc829 100644 --- a/chatlearn/utils/arguments.py +++ b/chatlearn/utils/arguments.py @@ -207,7 +207,7 @@ class ModelConfig(BaseConfig): tensor_model_parallel_size: int = None #: [optional] pipeline model parallel size pipeline_model_parallel_size: int = None - #: [optional] expert model parallel size for Megatron-Core + #: [optional] expert model parallel size expert_model_parallel_size: int = None #: [optional] zero size zero_size: int = None diff --git a/chatlearn/utils/megatron_import_helper.py b/chatlearn/utils/megatron_import_helper.py index 50a74bfe..c369d02d 100644 --- a/chatlearn/utils/megatron_import_helper.py +++ b/chatlearn/utils/megatron_import_helper.py @@ -160,12 +160,14 @@ from 
megatron.optimizer import DistributedOptimizer from megatron.optimizer.optimizer import MegatronOptimizer from megatron.optimizer.optimizer import MixedPrecisionOptimizer + from megatron.optimizer.optimizer import ChainedOptimizer from megatron.optimizer.optimizer import Float16OptimizerWithFloat16Params except ImportError: from megatron.core.optimizer import get_megatron_optimizer from megatron.core.optimizer import DistributedOptimizer from megatron.core.optimizer.optimizer import MegatronOptimizer from megatron.core.optimizer.optimizer import MixedPrecisionOptimizer + from megatron.core.optimizer.optimizer import ChainedOptimizer from megatron.core.optimizer.optimizer import Float16OptimizerWithFloat16Params # DistributedDataParallel @@ -214,13 +216,14 @@ reduce_scatter_to_sequence_parallel_region ) -# pylint: enable=unused-import +try: + from megatron.training import save_checkpoint_and_time as megatron_save_checkpoint_and_time +except ImportError: + from megatron.training.training import save_checkpoint_and_time as megatron_save_checkpoint_and_time def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler): try: - from megatron.training import save_checkpoint_and_time as save_checkpoint_and_time_v1 # pylint: disable=import-outside-toplevel - save_checkpoint_and_time_v1(iteration, model, optimizer, opt_param_scheduler) - except ImportError: - from megatron.training.training import save_checkpoint_and_time as save_checkpoint_and_time_v2# pylint: disable=import-outside-toplevel - save_checkpoint_and_time_v2(iteration, model, optimizer, opt_param_scheduler, 0, None) + megatron_save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler) + except TypeError: # missing required positional arguments for new Megatron version + megatron_save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler, 0, None) diff --git a/chatlearn/utils/megatron_utils.py b/chatlearn/utils/megatron_utils.py index a1816c10..5f778ccc 100644 --- a/chatlearn/utils/megatron_utils.py +++ b/chatlearn/utils/megatron_utils.py @@ -170,15 +170,20 @@ def load_checkpoint(*_args, **kwargs): args = get_args() target_tp = args.tensor_model_parallel_size target_pp = args.pipeline_model_parallel_size + target_ep = args.expert_model_parallel_size state_dict, _, _ = _load_base_checkpoint(args.load, rank0=True) args.iteration = state_dict['iteration'] checkpoint_args = state_dict['args'] checkpoint_tp = checkpoint_args.tensor_model_parallel_size checkpoint_pp = checkpoint_args.pipeline_model_parallel_size - if target_tp != checkpoint_tp or target_pp != checkpoint_pp: + checkpoint_ep = checkpoint_args.expert_model_parallel_size + if target_tp != checkpoint_tp or target_pp != checkpoint_pp or target_ep != checkpoint_ep: script_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tools/megatron_checkpoint_utils.py") save_dir = args.load[:-1] if args.load.endswith("/") else args.load - save_dir = save_dir + f"-transform-tp{target_tp}-pp{target_pp}" + if target_ep is None or target_ep == 1: + save_dir = save_dir + f"-transform-tp{target_tp}-pp{target_pp}" + else: + save_dir = save_dir + f"-transform_tp{target_tp}-pp{target_pp}-ep{target_ep}" if not os.path.exists(save_dir): # use last rank so we can determin model_type by whether last pipeline stage contains pooler_head if torch.distributed.get_rank() == (torch.distributed.get_world_size() - 1): diff --git a/examples/megatron/configs/mixtral/base.yaml b/examples/megatron/configs/mixtral/base.yaml new file mode 100644 index 
00000000..c1f0ae8f --- /dev/null +++ b/examples/megatron/configs/mixtral/base.yaml @@ -0,0 +1,83 @@ +# mixtral-8x7b config +add_position_embedding: False +use_rotary_position_embeddings: True +untie_embeddings_and_output_weights: True +tokenizer_type: Llama2Tokenizer +exit_on_missing_checkpoint: True +normalization: RMSNorm +masked_softmax_fusion: False +apply_query_key_layer_scaling: False +use_checkpoint_args: False +add_bias_linear: False +swiglu: True +attention_softmax_in_fp32: True +transformer_impl: transformer_engine +bf16: True + + +trainer_engine: ${trainer_engine:rlhf} +init_shuffle_prompts: ${init_shuffle_prompts:0} +# dpo loss +use_ipo: ${use_ipo:False} +dpo_weight: ${dpo_weight:0.1} + +train_to_compare_num_responses: ${train_to_compare_num_responses:1} +num_inference_per_prompt: ${num_inference_per_prompt:1} +tokenizer_model: ${tokenizer_model} +max_position_embeddings: ${max_position_embedding:4096} +seq_length: ${seq_length:1024} +fix_kl_coef: ${fix_kl_coef:True} +log_dir: ${log_dir} +exp_name: ${exp_name:test} +tensorboard_dir: ${tensorboard_dir} +loss_on_prompts: ${loss_on_prompts:False} +numerical_stable: True + +build_path: ${build_path:build} + + +init_kl_coef: ${init_kl_coef:0.02} +target: 6 +horizon: 10000 +gamma: 1 +lam: 0.95 +cliprange: 0.2 +cliprange_value: ${cliprange_value:0.1} +scale_reward: "None" + +cliprange_reward: 100 + +max_new_tokens: ${max_new_tokens:512} + + +ngram_coef: ${ngram_coef:1} +lm_coef: ${lm_coef:0} +math_coef: ${math_coef:0} +raw_reward_coeff: ${raw_reward_coeff:1} + +clipped_value_only: ${clipped_value_only:1} +finetune: True + + +save: ${save_dir} +save_interval: 1000 +gradient_accumulation_fusion: 0 +max_tokens_to_oom: 99999999 + + +hysteresis: 2 +use_flash_attn: 1 +do_math_eval: 0 +log_entropy: False +adaptive_parallel_strategy_on_checkpoint: True +log_interval: 1 +distributed_timeout_minutes: 30 +make_vocab_size_divisible_by: 32 +use_legacy_models: ${use_legacy_models:False} # disable legacy mode for MoE models +use_dist_ckpt: ${use_dist_ckpt:True} # use dist_ckpt for MoE models + +# MoE +moe_router_load_balancing_type: ${moe_router_load_balancing_type:"aux_loss"} +moe_aux_loss_coeff: ${moe_aux_loss_coeff:1e-2} +moe_grouped_gemm: ${moe_grouped_gemm:True} +moe_token_dispatcher_type: ${moe_token_dispatcher_type:"alltoall"} diff --git a/examples/megatron/configs/mixtral/base_inference.yaml b/examples/megatron/configs/mixtral/base_inference.yaml new file mode 100644 index 00000000..e8693e33 --- /dev/null +++ b/examples/megatron/configs/mixtral/base_inference.yaml @@ -0,0 +1,16 @@ +includes: + - base.yaml + + +temperature: 1.0 +seed: 42 +no_load_optim: True +no_load_rng: True +no_load_args: True +no_load_scheduler: True +log_num_zeros_in_grad: True +attention_dropout: 0.0 +hidden_dropout: 0.0 +retro_encoder_attention_dropout: 0.0 +retro_encoder_hidden_dropout: 0.0 +inference_batch_times_seqlen_threshold: ${inference_batch_times_seqlen_threshold:4096} diff --git a/examples/megatron/configs/mixtral/base_train.yaml b/examples/megatron/configs/mixtral/base_train.yaml new file mode 100644 index 00000000..3c17b5ad --- /dev/null +++ b/examples/megatron/configs/mixtral/base_train.yaml @@ -0,0 +1,10 @@ +includes: + - base.yaml + + +distributed_backend: nccl +train_iters: 12000 + +clip_grad: ${clip_grad:0.5} +log_interval: 1 +log_num_zeros_in_grad: True diff --git a/examples/megatron/configs/mixtral/dpo.yaml b/examples/megatron/configs/mixtral/dpo.yaml new file mode 100644 index 00000000..d66150f1 --- /dev/null +++ 
b/examples/megatron/configs/mixtral/dpo.yaml @@ -0,0 +1,46 @@ +runtime_env: + platform: DLC + excludes: + - "*pt" + - "logs" + - "tensorboards" + - ".nfs*" + + +models: + reference: + model_config_file: reference.yaml + num_gpu: ${num_gpu_ref:16} + trainable: False + generation_batch_size: ${ref_generation_batch_size:4} + free_memory: ${free_memory_reference:False} + + ppo_policy: + model_config_file: ppo_policy.yaml + num_gpu: ${num_gpu_ppo_policy:16} + trainable: True + lora: + enable_lora: ${enable_lora_policy:False} + lora_dim: 64 + lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear + column_only_qkv: False + lora_dropout: 0.05 + free_memory: ${free_memory_ppo_policy:False} + +runtime: + colocation: + - ppo_policy,reference + train_micro_batch_size: ${train_micro_batch_size:2} + train_global_batch_size: ${train_global_batch_size:512} + num_episode: ${num_episode:200} + sample_per_episode: ${sample_per_episode:1024} + num_training_epoch: 1 + save_episode_interval: ${save_episode_interval:100} + data_path: ${data_path} + training_data_num_limit: ${training_data_num_limit:-1} + eval_data_num_limit: ${eval_data_num_limit:128} + eval_episode_interval: ${eval_episode_interval:100} + data_checkpoint_path: ${data_checkpoint_path} + output_dir: ${output_dir} + free_sync_collective_group: ${free_sync_collective_group:False} + exp_name: ${exp_name:chatlearn} diff --git a/examples/megatron/configs/mixtral/old_policy_inference.yaml b/examples/megatron/configs/mixtral/old_policy_inference.yaml new file mode 100644 index 00000000..9faa42ac --- /dev/null +++ b/examples/megatron/configs/mixtral/old_policy_inference.yaml @@ -0,0 +1,15 @@ +includes: + - base_inference.yaml + - policy_shared.yaml + + +top_p: ${policy_top_p:0.9} +top_k: ${policy_top_k:0} +temperature: ${policy_temperature:1.0} + +eval_temperature: 0.01 +use_attn_acc: ${use_attn_acc:False} +eval_top_k: 1 +eval_top_p: 0 + +pipeline_model_parallel_size: ${policy_pp:1} diff --git a/examples/megatron/configs/mixtral/old_value_inference.yaml b/examples/megatron/configs/mixtral/old_value_inference.yaml new file mode 100644 index 00000000..54fc4dfc --- /dev/null +++ b/examples/megatron/configs/mixtral/old_value_inference.yaml @@ -0,0 +1,5 @@ +includes: + - base_inference.yaml + - reward_shared.yaml + +pipeline_model_parallel_size: ${value_pp:1} diff --git a/examples/megatron/configs/mixtral/online_dpo.yaml b/examples/megatron/configs/mixtral/online_dpo.yaml new file mode 100644 index 00000000..1d8cb509 --- /dev/null +++ b/examples/megatron/configs/mixtral/online_dpo.yaml @@ -0,0 +1,63 @@ +runtime_env: + platform: DLC + excludes: + - "*pt" + - "logs" + - "tensorboards" + - ".nfs*" + + +models: + policy: + model_config_file: old_policy_inference.yaml + num_gpu: ${num_gpu_policy:16} + trainable: False + batch_generation: + ranking: ${batch_generation_ranking:False} + min_prompt_length: ${batch_generation_min_prompt_length:0} + free_memory: ${free_memory_policy:False} + + reference: + model_config_file: reference.yaml + num_gpu: ${num_gpu_ref:16} + trainable: False + generation_batch_size: ${ref_generation_batch_size:4} + free_memory: ${free_memory_reference:False} + + reward: + model_config_file: reward_inference.yaml + num_gpu: ${num_gpu_reward:16} + trainable: False + free_memory: ${free_memory_reward:False} + + ppo_policy: + model_config_file: ppo_policy.yaml + num_gpu: ${num_gpu_ppo_policy:16} + trainable: True + lora: + enable_lora: ${enable_lora_policy:False} + lora_dim: 64 + lora_layer: 
ColumnParallelLinear,LinearLayer,RowParallelLinear + column_only_qkv: False + lora_dropout: 0.05 + free_memory: ${free_memory_ppo_policy:False} + +runtime: + colocation: + - policy,ppo_policy,reward,reference + generation_batch_size: ${generation_batch_size:4} + train_micro_batch_size: ${train_micro_batch_size:2} + train_global_batch_size: ${train_global_batch_size:512} + num_episode: ${num_episode:200} + sample_per_episode: ${sample_per_episode:1024} + num_training_epoch: 1 + save_episode_interval: ${save_episode_interval:100} + data_path: ${data_path} + training_data_num_limit: ${training_data_num_limit:-1} + eval_data_num_limit: ${eval_data_num_limit:128} + eval_episode_interval: ${eval_episode_interval:100} + data_checkpoint_path: ${data_checkpoint_path} + output_dir: ${output_dir} + free_sync_collective_group: ${free_sync_collective_group:False} + exp_name: ${exp_name:chatlearn} + validate_param_sync: ${validate_param_sync:False} diff --git a/examples/megatron/configs/mixtral/policy_shared.yaml b/examples/megatron/configs/mixtral/policy_shared.yaml new file mode 100644 index 00000000..7b314f7b --- /dev/null +++ b/examples/megatron/configs/mixtral/policy_shared.yaml @@ -0,0 +1,13 @@ +load: ${policy_inference_load} +load_iteration: ${policy_load_iteration} +num_layers: ${policy_num_layers} +hidden_size: ${policy_hidden_size} +num_attention_heads: ${policy_num_attention_heads} +ffn_hidden_size: ${policy_ffn_hidden_size} +num_experts: ${policy_num_experts} +moe_router_topk: ${policy_moe_router_topk} +tensor_model_parallel_size: ${policy_tp:1} +expert_model_parallel_size: ${policy_ep:8} +group_query_attention: ${group_query_attention:True} +num_query_groups: ${policy_num_query_groups} +use_distributed_optimizer: True diff --git a/examples/megatron/configs/mixtral/ppo_policy.yaml b/examples/megatron/configs/mixtral/ppo_policy.yaml new file mode 100644 index 00000000..66e145eb --- /dev/null +++ b/examples/megatron/configs/mixtral/ppo_policy.yaml @@ -0,0 +1,39 @@ +includes: + - base_train.yaml + - policy_shared.yaml + + +bf16: True +use_checkpoint_opt_param_scheduler: False +adam_beta1: 0.9 +adam_beta2: 0.95 +num_workers: 8 +init_method_std: 0.006 + +# dropout +attention_dropout: ${attention_dropout:0.1} +hidden_dropout: ${hidden_dropout:0.1} +retro_encoder_hidden_dropout: ${retro_encoder_hidden_dropout:0.1} +retro_encoder_attention_dropout: ${retro_encoder_attention_dropout:0.1} + +recompute_granularity: selective + +log_num_zeros_in_grad: True +no_load_optim: True +no_load_rng: True +no_load_args: True +no_load_scheduler: True + + +lr_decay_iters: 12000 +lr_warmup_iters: 100 +lr: ${policy_lr:2.4e-7} +min_lr: ${policy_min_lr:1e-9} +lr_decay_style: ${policy_lr_decay_style:linear} +weight_decay: 0.01 +pipeline_model_parallel_size: ${ppo_policy_pp:1} +sequence_parallel: ${sequence_parallel:True} + +recompute_activations: ${policy_recompute_activations:False} +recompute_granularity: ${policy_recompute_granularity:None} +moe_layer_recompute: ${policy_moe_layer_recompute:False} diff --git a/examples/megatron/configs/mixtral/ppo_value.yaml b/examples/megatron/configs/mixtral/ppo_value.yaml new file mode 100644 index 00000000..5424f897 --- /dev/null +++ b/examples/megatron/configs/mixtral/ppo_value.yaml @@ -0,0 +1,30 @@ +includes: + - base_train.yaml + - reward_shared.yaml + +pipeline_model_parallel_size: ${ppo_value_pp:1} +lr_decay_iters: 12000 +lr_warmup_iters: 100 +lr: ${value_lr:5e-6} +min_lr: ${value_min_lr:5e-7} +lr_decay_style: ${value_lr_decay_style:linear} +weight_decay: 0.01 +log_interval: 
1 + +use_checkpoint_opt_param_scheduler: False +adam_beta1: 0.9 +adam_beta2: 0.95 +num_workers: 8 +init_method_std: 0.006 + +recompute_granularity: selective + +no_load_optim: True +no_load_rng: True +no_load_args: True +no_load_scheduler: True +sequence_parallel: True + +recompute_activations: ${value_recompute_activations:False} +recompute_granularity: ${value_recompute_granularity:None} +moe_layer_recompute: ${value_moe_layer_recompute:False} diff --git a/examples/megatron/configs/mixtral/reference.yaml b/examples/megatron/configs/mixtral/reference.yaml new file mode 100644 index 00000000..96cb77a2 --- /dev/null +++ b/examples/megatron/configs/mixtral/reference.yaml @@ -0,0 +1,6 @@ +includes: + - base_inference.yaml + - policy_shared.yaml + +parallel_output: True +pipeline_model_parallel_size: ${ref_pp:1} diff --git a/examples/megatron/configs/mixtral/reward_inference.yaml b/examples/megatron/configs/mixtral/reward_inference.yaml new file mode 100644 index 00000000..31e15f8b --- /dev/null +++ b/examples/megatron/configs/mixtral/reward_inference.yaml @@ -0,0 +1,6 @@ +includes: + - base_inference.yaml + - reward_shared.yaml + +reward_bias: 0 +pipeline_model_parallel_size: ${reward_pp:1} diff --git a/examples/megatron/configs/mixtral/reward_shared.yaml b/examples/megatron/configs/mixtral/reward_shared.yaml new file mode 100644 index 00000000..a50ded56 --- /dev/null +++ b/examples/megatron/configs/mixtral/reward_shared.yaml @@ -0,0 +1,18 @@ +load: ${reward_load} +load_iteration: ${reward_load_iteration} +# enable use_distributed_optimizer will raise error for reward/value, disable temply +use_distributed_optimizer: False + +num_layers: ${reward_num_layers} +hidden_size: ${reward_hidden_size} +num_attention_heads: ${reward_num_attention_heads} +ffn_hidden_size: ${reward_ffn_hidden_size} +num_experts: ${reward_num_experts} +moe_router_topk: ${reward_moe_router_topk} +tensor_model_parallel_size: ${reward_tp:1} +expert_model_parallel_size: ${reward_ep:8} +group_query_attention: ${group_query_attention:True} +num_query_groups: ${reward_num_query_groups} + +save_inference: True +save_inference_interval: 10 diff --git a/examples/megatron/configs/mixtral/rlhf.yaml b/examples/megatron/configs/mixtral/rlhf.yaml new file mode 100644 index 00000000..dbbfb17d --- /dev/null +++ b/examples/megatron/configs/mixtral/rlhf.yaml @@ -0,0 +1,83 @@ +runtime_env: + platform: DLC + excludes: + - "*pt" + - "logs" + - "tensorboards" + - ".nfs*" + + +models: + policy: + model_config_file: old_policy_inference.yaml + num_gpu: ${num_gpu_policy:16} + trainable: False + batch_generation: + ranking: ${batch_generation_ranking:False} + min_prompt_length: ${batch_generation_min_prompt_length:0} + free_memory: ${free_memory_policy:False} + + reference: + model_config_file: reference.yaml + num_gpu: ${num_gpu_ref:16} + trainable: False + generation_batch_size: ${ref_generation_batch_size:4} + free_memory: ${free_memory_reference:False} + + reward: + model_config_file: reward_inference.yaml + num_gpu: ${num_gpu_reward:16} + trainable: False + free_memory: ${free_memory_reward:False} + + value: + model_config_file: old_value_inference.yaml + num_gpu: ${num_gpu_value:16} + trainable: False + free_memory: ${free_memory_value:False} + + ppo_policy: + model_config_file: ppo_policy.yaml + num_gpu: ${num_gpu_ppo_policy:16} + trainable: True + lora: + enable_lora: ${enable_lora_policy:False} + lora_dim: 64 + lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear + column_only_qkv: False + lora_dropout: 0.05 + free_memory: 
${free_memory_ppo_policy:False} + + ppo_value: + model_config_file: ppo_value.yaml + num_gpu: ${num_gpu_ppo_value:16} + trainable: True + lora: + enable_lora: ${enable_lora_value:False} + lora_dim: 64 + lora_layer: ColumnParallelLinear,LinearLayer,RowParallelLinear + column_only_qkv: False + lora_dropout: 0.05 + free_memory: ${free_memory_ppo_value:False} + +runtime: + colocation: + - policy,ppo_policy,reward,reference,value,ppo_value + generation_batch_size: ${generation_batch_size:4} + train_micro_batch_size: ${train_micro_batch_size:2} + train_global_batch_size: ${train_global_batch_size:512} + num_episode: ${num_episode:200} + sample_per_episode: ${sample_per_episode:1024} + num_training_epoch: 1 + save_episode_interval: ${save_episode_interval:100} + data_path: ${data_path} + eval_data_path: ${eval_data_path} + training_data_num_limit: ${training_data_num_limit:-1} + eval_data_num_limit: ${eval_data_num_limit:128} + eval_episode_interval: ${eval_episode_interval:100} + data_checkpoint_path: ${data_checkpoint_path} + free_sync_collective_group: ${free_sync_collective_group:False} + exp_name: ${exp_name:chatlearn} + output_dir: ${output_dir} + debug: ${debug:False} + validate_param_sync: ${validate_param_sync:False} diff --git a/examples/megatron/configs/mixtral/test_policy.yaml b/examples/megatron/configs/mixtral/test_policy.yaml new file mode 100644 index 00000000..aa828c62 --- /dev/null +++ b/examples/megatron/configs/mixtral/test_policy.yaml @@ -0,0 +1,25 @@ +runtime_env: + platform: DLC + excludes: + - "*pt" + - "logs" + - "tensorboards" + - ".nfs*" + + +models: + policy: + model_config_file: old_policy_inference.yaml + num_gpu: ${num_gpu:1} + gpu_per_process: 1 + trainable: False + batch_generation: + ranking: ${batch_generation_ranking:False} + min_prompt_length: ${batch_generation_min_prompt_length:0} + +runtime: + generation_batch_size: ${generation_batch_size:4} + data_path: ${data_path} + eval_data_path: ${eval_data_path} + output_dir: ${output_dir} + exp_name: ${exp_name:chatlearn} diff --git a/examples/megatron/data/sft_dataset.py b/examples/megatron/data/sft_dataset.py index abc65c5d..df1c7ef1 100644 --- a/examples/megatron/data/sft_dataset.py +++ b/examples/megatron/data/sft_dataset.py @@ -57,6 +57,10 @@ def __init__(self, data_path, max_seq_length): self.pad_id = self.tokenizer.pad_token_id if hasattr(self.tokenizer, 'pad_token_id') else self.tokenizer.pad_id self.bos_id = self.tokenizer.bos_token_id if hasattr(self.tokenizer, 'bos_token_id') else self.tokenizer.bos_id self.eos_id = self.tokenizer.eod + # The pad_id for Llama2Tokenizer is -1. It will cause out-of-bound index for embedding calculation. + # Thus, we hardcode it as 0 here. + if self.pad_id == -1: + self.pad_id = 0 def __len__(self): return len(self.dataset) diff --git a/examples/megatron/models/forward_step.py b/examples/megatron/models/forward_step.py index 0a772c05..8da905d6 100644 --- a/examples/megatron/models/forward_step.py +++ b/examples/megatron/models/forward_step.py @@ -151,7 +151,7 @@ def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask, recv_buffer = None if inference_config_master is not None and "DPO_labels" in inference_config_master: for key, value in inference_config_master.items(): - inference_config[key] = value[start:end, ...] + inference_config[key] = value[start:end, ...] 
if value is not None else None output = _forward_step_helper(model, tokens2use, position_ids2use, attention_mask, recv_buffer=recv_buffer, pooling_sequence_index=pooling_sequence_index2use, diff --git a/examples/megatron/scripts/base_env.sh b/examples/megatron/scripts/base_env.sh index aa1bccce..97d29d1c 100644 --- a/examples/megatron/scripts/base_env.sh +++ b/examples/megatron/scripts/base_env.sh @@ -130,6 +130,24 @@ elif [[ "$model_size" == "llama2-70B" ]]; then export reward_ffn_hidden_size=28672 export reward_num_query_groups=8 export group_query_attention=True +elif [[ "$model_size" == "mixtral-8x7B" ]]; then + export policy_num_layers=32 + export policy_hidden_size=4096 + export policy_num_attention_heads=32 + export policy_num_query_groups=8 + export policy_ffn_hidden_size=14336 + export policy_num_experts=8 + export policy_moe_router_topk=2 + export reward_num_layers=32 + export reward_hidden_size=4096 + export reward_num_attention_heads=32 + export reward_num_query_groups=8 + export reward_ffn_hidden_size=14336 + export reward_num_experts=8 + export reward_moe_router_topk=2 + export max_position_embedding=32768 + export seq_length=2048 + export USE_LEGACY_MODELS=False else echo "unsupported model_size ${model_size}, please set your own model config" exit 1 diff --git a/examples/megatron/scripts/convert_hf_to_megatron.sh b/examples/megatron/scripts/convert_hf_to_megatron.sh index 4de6c5b1..23354139 100644 --- a/examples/megatron/scripts/convert_hf_to_megatron.sh +++ b/examples/megatron/scripts/convert_hf_to_megatron.sh @@ -21,7 +21,11 @@ export PYTHONPATH=${megatron}:${chatlearn} load_dir=${LOAD_PATH} save_dir=${SAVE_PATH} tokenizer_model=${TOKENIZER_MODEL} -model_size=${model_size:-llama2-7B} +if [[ $model == 'gpt_llama' ]]; then + model_size=${model_size:-llama2-7B} +elif [[ $model == 'mixtral' ]]; then + model_size=${model_size:-mixtral-8x7B} +fi export CUDA_DEVICE_MAX_CONNECTIONS=1 @@ -54,7 +58,7 @@ elif [[ ${model} == 'mixtral' ]]; then cd ${megatron} python tools/checkpoint/convert.py \ --model-type GPT \ - --loader loader_mixtral_hf \ + --loader mixtral_hf \ --saver mcore \ --target-tensor-parallel-size ${tp} \ --target-pipeline-parallel-size ${pp} \ diff --git a/examples/megatron/scripts/convert_megatron_to_hf.sh b/examples/megatron/scripts/convert_megatron_to_hf.sh index e18c8cd4..63ad7f65 100644 --- a/examples/megatron/scripts/convert_megatron_to_hf.sh +++ b/examples/megatron/scripts/convert_megatron_to_hf.sh @@ -1,8 +1,9 @@ #!/bin/bash # Convert LLaMA model from megatron format to huggingface format. 
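A small self-contained sketch of the guard added to _with_pipelining_forward_step in forward_step.py above: entries of inference_config_master may be None, so the per-micro-batch slice is only taken for non-None values (the dictionary contents below are made up for illustration):

    import torch

    # Made-up stand-in for inference_config_master; the None-valued entry is hypothetical.
    inference_config_master = {
        "DPO_labels": torch.arange(16).reshape(8, 2),
        "optional_entry": None,
    }
    start, end = 0, 4  # bounds of the current micro-batch
    inference_config = {}
    for key, value in inference_config_master.items():
        inference_config[key] = value[start:end, ...] if value is not None else None
    assert inference_config["DPO_labels"].shape == (4, 2)
    assert inference_config["optional_entry"] is None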
set -ex
+set -o pipefail

-# config
+# path config
chatlearn=${CHATLEARN}
megatron=${MEGATRON}
load_path=${LOAD_PATH}
@@ -11,13 +12,29 @@ vocab_path=${VOCAB_PATH}
target_params_dtype=${target_params_dtype:-bf16}
temp_path=${save_path}/temp

+# model config
+# can be `gpt_llama' for GPT or Llama, or `mixtral' for Mixtral
+model=${MODEL:-'gpt_llama'}
+
# Whether to use legacy models, default: True
use_legacy_models=${USE_LEGACY_MODELS:-"True"}
-
-if [[ ${use_legacy_models} = "False" ]]; then
-    ckpt_format="mcore"
+if [[ ${use_legacy_models} == "False" ]]; then
+    if [[ ${model} == 'gpt_llama' ]]; then
+        # TODO: migrate to mcore
+        loader_ckpt_format="mcore"
+        saver_ckpt_format="megatron"
+    elif [[ ${model} == 'mixtral' ]]; then
+        loader_ckpt_format="mcore_mixtral"
+        saver_ckpt_format="mcore"
+    else
+        echo -e "\033[31m Unrecognized model ${model} \033[0m"
+        exit -1
+    fi
+    MCORE_ARGS=""
else
-    ckpt_format="megatron"
+    loader_ckpt_format="megatron"
+    saver_ckpt_format="megatron"
+    MCORE_ARGS="--use_legacy_models"
fi
set +x
@@ -25,18 +42,40 @@ set +x
# convert parallel strategy
START_TIME=$SECONDS
-cd ${megatron}
-
-if [[ ! -d "${temp_path}" ]]; then
-    python tools/checkpoint/convert.py \
-    --model-type GPT \
-    --loader ${ckpt_format} \
-    --saver "megatron" \
-    --target-tensor-parallel-size 1 \
-    --target-pipeline-parallel-size 1 \
-    --load-dir ${load_path} \
-    --save-dir ${temp_path} \
-    --megatron-path ${megatron}
+if [[ ! -d ${temp_path} ]]; then
+    if [[ ${model} == 'gpt_llama' ]]; then
+        cd ${megatron}
+        python tools/checkpoint/convert.py \
+            --model-type GPT \
+            --loader ${loader_ckpt_format} \
+            --saver ${saver_ckpt_format} \
+            --target-tensor-parallel-size 1 \
+            --target-pipeline-parallel-size 1 \
+            --load-dir ${load_path} \
+            --save-dir ${temp_path} \
+            --megatron-path ${megatron}
+    elif [[ ${model} == 'mixtral' ]]; then
+        cd ${chatlearn}
+        export PYTHONPATH=${chatlearn}:${megatron}:${megatron}/tools/checkpoint:${PYTHONPATH}
+        python chatlearn/tools/convert.py \
+            --model-type GPT \
+            --loader ${loader_ckpt_format} \
+            --saver-prefix tools.checkpoint.saver \
+            --saver ${saver_ckpt_format} \
+            --target-tensor-parallel-size 1 \
+            --target-pipeline-parallel-size 1 \
+            --target-expert-parallel-size 1 \
+            --load-dir ${load_path} \
+            --save-dir ${temp_path} \
+            --megatron-path ${megatron}
+    else
+        echo -e "\033[31m Unrecognized model ${model} \033[0m"
+        exit -1
+    fi
+fi
+
+if [[ $? != 0 ]]; then
+    exit 1
fi # convert to hf format @@ -46,7 +85,8 @@ python chatlearn/tools/megatron_to_hf.py \ --save_path ${save_path} \ --target_params_dtype ${target_params_dtype} \ --vocab_dir ${vocab_path} \ - --megatron_path ${megatron} + --megatron_path ${megatron} \ + ${MCORE_ARGS} # clear temp path rm -r $temp_path diff --git a/examples/megatron/scripts/train_dpo_mixtral.sh b/examples/megatron/scripts/train_dpo_mixtral.sh new file mode 100644 index 00000000..2d81a268 --- /dev/null +++ b/examples/megatron/scripts/train_dpo_mixtral.sh @@ -0,0 +1,53 @@ +#!/bin/bash +set -x + +[ -z "$model_size" ] && export model_size=mixtral-8x7B + +# Get the directory of the current script +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +source ${DIR}/base_env.sh + +export trainer_engine=dpo + +# clip +export clip_grad=5.0 + +# desable dropout +export attention_dropout=0.0 +export hidden_dropout=0.0 +export retro_encoder_hidden_dropout=0.0 +export retro_encoder_attention_dropout=0.0 + + +if [[ "$model_size" == "mixtral-8x7B" ]]; then + export policy_tp=1 + export policy_ep=8 + export ppo_policy_pp=4 + export ref_pp=4 + export train_global_batch_size=128 + export ref_generation_batch_size=2 + export train_micro_batch_size=1 + export policy_recompute_activations=True + export policy_moe_layer_recompute=True +fi + +configs=$CHATLEARN/examples/megatron/configs/mixtral/dpo.yaml + +[ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} +[ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ +[ -z "$sample_per_episode" ] && sample_per_episode=1024 + +output_dir=${output_dir}/${exp_name} +export data_checkpoint_path=${output_dir}/data_checkpoint +mkdir -p $output_dir +log_file=${output_dir}/log_${RANK}.log + +policy_inference_load=${POLICY_LOAD} \ +tokenizer_model=${TOKENIZER_MODEL} \ +num_gpu=${num_gpu} \ +data_path=${DATASET_PATH} \ +sample_per_episode=${sample_per_episode} \ +python entry/train_dpo.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} + + diff --git a/examples/megatron/scripts/train_online_dpo_mixtral.sh b/examples/megatron/scripts/train_online_dpo_mixtral.sh new file mode 100644 index 00000000..06e83008 --- /dev/null +++ b/examples/megatron/scripts/train_online_dpo_mixtral.sh @@ -0,0 +1,86 @@ +#!/bin/bash +set -x + +[ -z "$model_size" ] && export model_size=mixtral-8x7B + +# Get the directory of the current script +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +source ${DIR}/base_env.sh + +# megatron +# TODO: support vllm +backend=${1:-megatron} + +if [[ "$backend" != "megatron" ]]; then + echo "ERROR: expect megatron backend for Mixtral models, while current backend is "$backend + exit 1 +fi + +if [[ "$backend" == "megatron" ]]; then + configs=$CHATLEARN/examples/megatron/configs/mixtral/online_dpo.yaml +else + export ENABLE_VLLM=True + if [ -z "$tokenizer_load" ];then + echo "please set path to hf tokenizer for vllm backend, download from huggingface source." 
+ exit 1 + fi + configs=$CHATLEARN/examples/megatron/configs/mixtral/online_dpo_vllm.yaml +fi + +export trainer_engine=online_dpo + +[ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} +[ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ +[ -z "$sample_per_episode" ] && sample_per_episode=1024 +[ -z "$tokenizer_load" ] && export tokenizer_load=path-to-hf-tokenizer-for-vllm-backend + +output_dir=$output_dir/$exp_name +export data_checkpoint_path=${output_dir}/data_checkpoint + +export train_to_compare_num_responses=8 +export num_inference_per_prompt=8 + + +if [[ "$model_size" == "mixtral-8x7B" ]]; then + export policy_tp=1 + export policy_ep=8 + export ppo_policy_pp=4 + export reward_tp=1 + export reward_ep=8 + export ppo_value_pp=4 + export ref_pp=4 + export policy_pp=4 + export reward_pp=4 + export value_pp=4 + export train_global_batch_size=32 + export generation_batch_size=8 + export ref_generation_batch_size=8 + export train_micro_batch_size=1 + export policy_recompute_activations=True + export policy_moe_layer_recompute=True + export value_recompute_activations=True + export value_moe_layer_recompute=True + export free_memory_policy=True + export free_memory_reference=True + export free_memory_reward=True + export free_memory_value=True + export free_memory_ppo_policy=True + export free_memory_ppo_value=True + export seq_length=2048 + export max_new_tokens=1024 +fi + +mkdir -p ${output_dir} +log_file=${output_dir}/log_${RANK}.log + +policy_inference_load=${POLICY_LOAD} \ +reward_load_iteration=${REWARD_LOAD_ITERATION} \ +reward_load=${REWARD_LOAD} \ +tokenizer_model=${TOKENIZER_MODEL} \ +num_gpu=${num_gpu} \ +data_path=${DATASET_PATH} \ +eval_data_path=${EVAL_DATASET_PATH} \ +sample_per_episode=${sample_per_episode} \ +python entry/train_online_dpo.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} + diff --git a/examples/megatron/scripts/train_reward_mixtral.sh b/examples/megatron/scripts/train_reward_mixtral.sh new file mode 100644 index 00000000..3694cd80 --- /dev/null +++ b/examples/megatron/scripts/train_reward_mixtral.sh @@ -0,0 +1,165 @@ +#!/bin/bash +set -x + +[ -z "$MASTER_ADDR" ] && export MASTER_ADDR=localhost +[ -z "$WORLD_SIZE" ] && export WORLD_SIZE=1 +[ -z "$GPUS_PER_NODE" ] && export GPUS_PER_NODE=8 +[ -z "$RANK" ] && export RANK=0 +[ -z "$MASTER_PORT" ] && export MASTER_PORT=12456 + +DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE} \ + --nnodes ${WORLD_SIZE} \ + --node_rank ${RANK} \ + --master_addr ${MASTER_ADDR} \ + --master_port ${MASTER_PORT}" + +# check the path +[[ -z "${MEGATRON}" ]] && { echo "MEGATRON path is not set"; exit 1; } +[[ -z "${CHATLEARN}" ]] && { echo "CHATLEARN path is not set"; exit 1; } +[[ -z "${LOAD_PATH}" ]] && { echo "LOAD_PATH is not set"; exit 1; } +[[ -z "${TOKENIZER_MODEL}" ]] && { echo "TOKENIZER_MODEL is not set"; exit 1; } +[[ -z "${DATASET_PATH}" ]] && { echo "DATASET_PATH is not set"; exit 1; } + + +export PYTHONPATH=${PYTHONPATH}:${MEGATRON}:${CHATLEARN}/examples/megatron:${CHATLEARN} + +[ -z "$model_size" ] && export model_size="mixtral-8x7B" + +if [[ $model_size == "mixtral-8x7B" ]]; then + NUM_LAYERS=32 + HIDDEN_SIZE=4096 + NUM_ATTN_HEADS=32 + FFN_HIDDEN_SIZE=14336 + MAX_POSITION_EMBEDDINGS=32768 + NUM_QUERY_GROUPS=8 + NUM_EXPERTS=8 + MOE_ROUTER_TOPK=2 + seq_length=2048 + tp=1 + pp=4 + ep=8 + mb=1 + gbs=8 +else + echo "Unrecognized model_size ${model_size}, choose from 'mixtral-8x7B'." 
+ exit -1 +fi + +DIR=$(pwd) +DATETIME=$(date +'date_%y-%m-%d_time_%H-%M-%S') +mkdir -p $DIR/logs + +NODE_RANK=$RANK +NNODES=$WORLD_SIZE + + +dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp)) +gbs=$(($gbs * $dp)) + + +[ -z "$CHECKPOINT_PATH" ] && CHECKPOINT_PATH=${CHATLEARN}/output/reward/mixtral_hh_reward_$(date +%F)_gpt_${model_size}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_ep${ep}_mb${mb}_seqlen${seq_length} + +mkdir -p $CHECKPOINT_PATH + +MODEL_ARGS=" +--disable-bias-linear \ +--seq-length $seq_length \ +--max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \ +--num-layers ${NUM_LAYERS} \ +--hidden-size ${HIDDEN_SIZE} \ +--ffn-hidden-size ${FFN_HIDDEN_SIZE} \ +--num-attention-heads ${NUM_ATTN_HEADS} \ +--init-method-std 0.01 \ +--attention-dropout 0.0 \ +--hidden-dropout 0.0 \ +--normalization RMSNorm \ +--position-embedding-type rope \ +--swiglu \ +--untie-embeddings-and-output-weights \ +--group-query-attention \ +--num-query-groups ${NUM_QUERY_GROUPS} \ +--no-masked-softmax-fusion \ +--no-position-embedding \ +--transformer-impl transformer_engine \ +--attention-softmax-in-fp32 " + +MOE_ARGS=" +--num-experts ${NUM_EXPERTS} \ +--moe-router-topk ${MOE_ROUTER_TOPK} \ +--moe-router-load-balancing-type aux_loss \ +--moe-aux-loss-coeff 1e-2 \ +--moe-token-dispatcher-type alltoall \ +--overlap-param-gather \ +--overlap-grad-reduce \ +--moe-layer-recompute" + +DATA_ARGS=" +--tokenizer-type Llama2Tokenizer \ +--tokenizer-model ${TOKENIZER_MODEL} \ +--data-path $DATASET_PATH/train.jsonl $DATASET_PATH/dev.jsonl $DATASET_PATH/dev.jsonl \ +--split 98,2,0 \ +--dataloader-type cyclic " + +TRAINING_ARGS=" +--micro-batch-size $mb \ +--global-batch-size $gbs \ +--lr 3e-6 \ +--train-iters 1000 \ +--lr-decay-iters 1000 \ +--lr-decay-style cosine \ +--min-lr 1.0e-12 \ +--weight-decay 0.1 \ +--lr-warmup-iters 300 \ +--clip-grad 1.0 \ +--bf16 \ +--exit-on-missing-checkpoint \ +--use-checkpoint-args \ +--adam-beta1 0.9 \ +--adam-beta2 0.95 \ +--use-flash-attn \ +--finetune \ +--recompute-activations \ +--max-response 2 \ +--select-max-response firstk " + +MODEL_PARALLEL_ARGS=" +--tensor-model-parallel-size $tp \ +--pipeline-model-parallel-size $pp \ +--expert-model-parallel-size $ep \ +--use-distributed-optimizer \ +--sequence-parallel \ +--distributed-timeout-minutes 60 \ +" + +LOGGING_ARGS=" +--log-interval 1 \ +--eval-iters 20 \ +--eval-interval 1000 \ +--save-interval 1000 \ +--save $CHECKPOINT_PATH \ +--load $LOAD_PATH \ +--auto-detect-ckpt-format \ +--num-workers 8 \ +--no-load-rng \ +--no-load-optim \ +--tensorboard-dir $CHECKPOINT_PATH \ +--tensorboard-log-interval 10 \ +--log-timers-to-tensorboard \ +--log-batch-size-to-tensorboard \ +--log-validation-ppl-to-tensorboard \ +" + +log_file=$CHECKPOINT_PATH/stderr_$NODE_RANK.log + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +cd ${CHATLEARN}/examples/megatron/ + +torchrun $DISTRIBUTED_ARGS \ + entry/train_reward.py \ + ${MODEL_ARGS[@]} \ + ${MOE_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${LOGGING_ARGS[@]} 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} diff --git a/examples/megatron/scripts/train_rlhf_mixtral.sh b/examples/megatron/scripts/train_rlhf_mixtral.sh new file mode 100644 index 00000000..49f31f8b --- /dev/null +++ b/examples/megatron/scripts/train_rlhf_mixtral.sh @@ -0,0 +1,88 @@ +#!/bin/bash +set -x + +[ -z "$model_size" ] && export model_size=mixtral-8x7B + +# Get the directory of the current script +DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +source ${DIR}/base_env.sh + +# megatron +# 
TODO: support vllm +backend=${1:-megatron} +if [[ "$backend" != "megatron" ]]; then + echo "ERROR: expect megatron backend for Mixtral models, while current backend is "$backend + exit 1 +fi + + +config_dir=${CHATLEARN}/examples/megatron/configs/ + +if [[ "$backend" == "megatron" ]]; then + configs=${config_dir}/mixtral/rlhf.yaml +else + export ENABLE_VLLM=True + if [ -z "$tokenizer_load" ];then + echo "please set path to hf tokenizer for vllm backend, download from huggingface source." + exit 1 + fi + configs=${config_dir}/mixtral/vllm_rlhf.yaml +fi + +export trainer_engine=rlhf + +[ -z "$exp_name" ] && export exp_name=$(date +%F)-${model_size}-${trainer_engine} +[ -z "$output_dir" ] && export output_dir=${CHATLEARN}/output/ +[ -z "$sample_per_episode" ] && sample_per_episode=1024 +[ -z "$tokenizer_load" ] && export tokenizer_load=path-to-hf-tokenizer-for-vllm-backend + +output_dir=${output_dir}/${exp_name} +export data_checkpoint_path=${output_dir}/data_checkpoint + + +if [[ "$model_size" == "mixtral-8x7B" ]]; then + export policy_tp=1 + export policy_ep=8 + export ppo_policy_pp=4 + export reward_tp=1 + export reward_ep=8 + export ppo_value_pp=4 + export ref_pp=4 + export policy_pp=4 + export reward_pp=4 + export value_pp=4 + export train_global_batch_size=32 + export generation_batch_size=8 + export ref_generation_batch_size=8 + export train_micro_batch_size=1 + export policy_recompute_activations=True + export policy_moe_layer_recompute=True + export value_recompute_activations=True + export value_moe_layer_recompute=True + export free_memory_policy=True + export free_memory_reference=True + export free_memory_reward=True + export free_memory_value=True + export free_memory_ppo_policy=True + export free_memory_ppo_value=True + export seq_length=2048 + export max_new_tokens=1024 +else + echo "Unrecognized model_size ${model_size}, choose from 'mixtral-8x7B'." 
+ exit -1 +fi + +mkdir -p ${output_dir} +log_file=${output_dir}/log_${RANK}.log +echo $log_file + +policy_inference_load=${POLICY_LOAD} \ +reward_load_iteration=${REWARD_LOAD_ITERATION} \ +reward_load=${REWARD_LOAD} \ +tokenizer_model=${TOKENIZER_MODEL} \ +num_gpu=${num_gpu} \ +data_path=${DATASET_PATH} \ +eval_data_path=${EVAL_DATASET_PATH} \ +sample_per_episode=${sample_per_episode} \ +python entry/train_rlhf.py -c $configs 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} diff --git a/examples/megatron/scripts/train_sft_mixtral.sh b/examples/megatron/scripts/train_sft_mixtral.sh new file mode 100644 index 00000000..e3eb7fcc --- /dev/null +++ b/examples/megatron/scripts/train_sft_mixtral.sh @@ -0,0 +1,161 @@ +#!/bin/bash +set -x + +[ -z "$MASTER_ADDR" ] && export MASTER_ADDR=localhost +[ -z "$WORLD_SIZE" ] && export WORLD_SIZE=1 +[ -z "$GPUS_PER_NODE" ] && export GPUS_PER_NODE=8 +[ -z "$RANK" ] && export RANK=0 +[ -z "$MASTER_PORT" ] && export MASTER_PORT=12456 + +DISTRIBUTED_ARGS="--nproc_per_node ${GPUS_PER_NODE} \ + --nnodes ${WORLD_SIZE} \ + --node_rank ${RANK} \ + --master_addr ${MASTER_ADDR} \ + --master_port ${MASTER_PORT}" + +# check the path +[[ -z "${MEGATRON}" ]] && { echo "MEGATRON path is not set"; exit 1; } +[[ -z "${CHATLEARN}" ]] && { echo "CHATLEARN path is not set"; exit 1; } +[[ -z "${LOAD_PATH}" ]] && { echo "LOAD_PATH is not set"; exit 1; } +[[ -z "${TOKENIZER_MODEL}" ]] && { echo "TOKENIZER_MODEL is not set"; exit 1; } +[[ -z "${DATASET_PATH}" ]] && { echo "DATASET_PATH is not set"; exit 1; } + + +export PYTHONPATH=${PYTHONPATH}:${MEGATRON}:${CHATLEARN}/examples/megatron:${CHATLEARN} + +[ -z "$model_size" ] && export model_size="mixtral-8x7B" + +if [[ $model_size == "mixtral-8x7B" ]]; then + NUM_LAYERS=32 + HIDDEN_SIZE=4096 + NUM_ATTN_HEADS=32 + FFN_HIDDEN_SIZE=14336 + MAX_POSITION_EMBEDDINGS=32768 + NUM_QUERY_GROUPS=8 + NUM_EXPERTS=8 + MOE_ROUTER_TOPK=2 + seq_length=2048 + tp=1 + pp=4 + ep=8 + mb=1 + gbs=8 +else + echo "Unrecognized model_size ${model_size}, choose from 'mixtral-8x7B'." 
+ exit -1 +fi + +DIR=$(pwd) +DATETIME=$(date +'date_%y-%m-%d_time_%H-%M-%S') +mkdir -p $DIR/logs + +NODE_RANK=$RANK +NNODES=$WORLD_SIZE + + +dp=$(($WORLD_SIZE * $GPUS_PER_NODE / $tp / $pp)) +gbs=$(($gbs * $dp)) + + +[ -z "$CHECKPOINT_PATH" ] && CHECKPOINT_PATH=${CHATLEARN}/output/sft/mixtral_hh_sft_$(date +%F)_gpt_${model_size}_${NNODES}w${GPUS_PER_NODE}g_tp${tp}_pp${pp}_ep${ep}_mb${mb}_seqlen${seq_length} + +mkdir -p $CHECKPOINT_PATH + +MODEL_ARGS=" +--disable-bias-linear \ +--seq-length $seq_length \ +--max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \ +--num-layers ${NUM_LAYERS} \ +--hidden-size ${HIDDEN_SIZE} \ +--ffn-hidden-size ${FFN_HIDDEN_SIZE} \ +--num-attention-heads ${NUM_ATTN_HEADS} \ +--init-method-std 0.01 \ +--attention-dropout 0.0 \ +--hidden-dropout 0.0 \ +--normalization RMSNorm \ +--position-embedding-type rope \ +--swiglu \ +--untie-embeddings-and-output-weights \ +--group-query-attention \ +--num-query-groups ${NUM_QUERY_GROUPS} \ +--no-masked-softmax-fusion \ +--no-position-embedding \ +--transformer-impl transformer_engine \ +--attention-softmax-in-fp32 " + +MOE_ARGS=" +--num-experts ${NUM_EXPERTS} \ +--moe-router-topk ${MOE_ROUTER_TOPK} \ +--moe-router-load-balancing-type aux_loss \ +--moe-aux-loss-coeff 1e-2 \ +--moe-token-dispatcher-type alltoall \ +--overlap-param-gather \ +--overlap-grad-reduce " + +DATA_ARGS=" +--tokenizer-type Llama2Tokenizer \ +--tokenizer-model ${TOKENIZER_MODEL} \ +--data-path $DATASET_PATH/train.jsonl $DATASET_PATH/train.jsonl $DATASET_PATH/train.jsonl \ +--split 98,2,0 \ +--dataloader-type cyclic " + +TRAINING_ARGS=" +--micro-batch-size $mb \ +--global-batch-size $gbs \ +--lr 1e-5 \ +--train-iters 1000 \ +--lr-decay-iters 1000 \ +--lr-decay-style cosine \ +--min-lr 1.0e-7 \ +--weight-decay 0.01 \ +--lr-warmup-iters 50 \ +--clip-grad 1.0 \ +--bf16 \ +--exit-on-missing-checkpoint \ +--use-checkpoint-args \ +--adam-beta1 0.9 \ +--adam-beta2 0.999 \ +--use-flash-attn \ +--finetune " + +MODEL_PARALLEL_ARGS=" +--tensor-model-parallel-size $tp \ +--pipeline-model-parallel-size $pp \ +--expert-model-parallel-size $ep \ +--use-distributed-optimizer \ +--sequence-parallel \ +--distributed-timeout-minutes 60 \ +" + +LOGGING_ARGS=" +--log-interval 1 \ +--eval-iters 10 \ +--eval-interval 100 \ +--save-interval 1000 \ +--save $CHECKPOINT_PATH \ +--load $LOAD_PATH \ +--auto-detect-ckpt-format \ +--num-workers 8 \ +--no-load-rng \ +--no-load-optim \ +--tensorboard-dir $CHECKPOINT_PATH \ +--tensorboard-log-interval 10 \ +--log-timers-to-tensorboard \ +--log-batch-size-to-tensorboard \ +--log-validation-ppl-to-tensorboard \ +" + +log_file=$CHECKPOINT_PATH/stderr_$NODE_RANK.log + +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +cd ${CHATLEARN}/examples/megatron/ + +torchrun $DISTRIBUTED_ARGS \ + entry/train_sft.py \ + ${MODEL_ARGS[@]} \ + ${MOE_ARGS[@]} \ + ${DATA_ARGS[@]} \ + ${TRAINING_ARGS[@]} \ + ${MODEL_PARALLEL_ARGS[@]} \ + ${LOGGING_ARGS[@]} 2>&1 | tee -a ${log_file} ; exit ${PIPESTATUS[0]} diff --git a/examples/megatron/tests/run_policy_generation.sh b/examples/megatron/tests/run_policy_generation.sh index fc50bf77..91f799a7 100644 --- a/examples/megatron/tests/run_policy_generation.sh +++ b/examples/megatron/tests/run_policy_generation.sh @@ -3,8 +3,10 @@ set -x [ -z "$MEGATRON" ] && export MEGATRON=path-to-megatron [ -z "$CHATLEARN" ] && export CHATLEARN=path-to-chatlearn -[ -z "$TP" ] && export TP=4 +[ -z "$num_gpu" ] && export num_gpu=8 [ -z "$PP" ] && export PP=2 +[ -z "$EP" ] && export EP=1 +[ -z "$TP" ] && export TP=4 [ -z "$VOCAB_FILE" ] && 
export VOCAB_FILE=path-to-tokenizer [ -z "$LOAD" ] && export LOAD=path-to-ckpt [ -z "$DATASET_PATH" ] && export DATASET_PATH=path-to-dataset-json @@ -39,6 +41,14 @@ elif [[ $model_size == "llama2"* ]]; then configs=configs/llama2/test_vllm_policy.yaml fi export tokenizer_model=$VOCAB_FILE +elif [[ $model_size == "mixtral"* ]]; then + if [[ "$backend" == "megatron" ]]; then + configs=configs/mixtral/test_policy.yaml + else + echo "ERROR: mixtral model support megatron backend currently." + exit 1 + fi + export tokenizer_model=$VOCAB_FILE else echo "unexpected model_type $model_size." exit 1 @@ -55,10 +65,12 @@ log_file=${output_dir}/log_${RANK}.log export batch_generation_min_prompt_length=32 -generation_batch_size=64 \ + +generation_batch_size=${generation_batch_size:-64} \ num_gpu=${num_gpu:-8} \ -policy_tp=$TP \ policy_pp=$PP \ +policy_ep=$EP \ +policy_tp=$TP \ eval_data_path=$DATASET_PATH \ policy_inference_load=$LOAD \ python tests/test_policy_generation.py -c $configs 2>&1 | tee ${log_file}.log ; exit ${PIPESTATUS[0]} diff --git a/examples/megatron/tests/test_checkpoint_conversion.py b/examples/megatron/tests/test_checkpoint_conversion.py new file mode 100644 index 00000000..1625bd39 --- /dev/null +++ b/examples/megatron/tests/test_checkpoint_conversion.py @@ -0,0 +1,54 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""test checkpoint conversion between megatron (legacy or mcore) and huggingface""" + +import argparse + +import torch +from transformers import AutoModel + +def extract_name_and_params(model): + name_list = [] + param_list = [] + model_named_parameters = model.named_parameters() + for name, param in model_named_parameters: + name_list.append(name) + param_list.append(param) + return name_list, param_list + +def compare_checkpoint(src_path, dst_path): + src_model = AutoModel.from_pretrained(src_path) + dst_model = AutoModel.from_pretrained(dst_path) + src_model_names, src_model_params = extract_name_and_params(src_model) + dst_model_names, dst_model_params = extract_name_and_params(dst_model) + assert src_model_names == dst_model_names + for i, (src_param, dst_param) in enumerate(zip(src_model_params, dst_model_params)): + print(f"Comparing {src_model_names[i]}") + assert torch.equal(src_param, dst_param), f"Parameter {src_model_names[i]} is not equal for two models." 
+    return True
+
+def main():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--src-path', type=str, required=True,
+                        help='source huggingface checkpoint path')
+    parser.add_argument('--dst-path', type=str, required=True,
+                        help='destination huggingface checkpoint path')
+    args = parser.parse_args()
+
+    return compare_checkpoint(args.src_path, args.dst_path)
+
+if __name__ == '__main__':
+    main()
diff --git a/examples/megatron/tests/test_checkpoint_conversion.sh b/examples/megatron/tests/test_checkpoint_conversion.sh
new file mode 100644
index 00000000..62935c09
--- /dev/null
+++ b/examples/megatron/tests/test_checkpoint_conversion.sh
@@ -0,0 +1,48 @@
+#!/bin/bash
+set -exo pipefail
+
+export CHATLEARN=${CHATLEARN:-"path-to-chatlearn"}
+export MEGATRON=${MEGATRON:-"path-to-megatron-lm"}
+export LOAD_PATH=${LOAD_PATH:-"path-to-hf-ckpt"}
+export TEMP_PATH=${TEMP_PATH:-"path-to-converted-mg-ckpt"}
+export SAVE_PATH=${SAVE_PATH:-"path-to-converted-back-hf-ckpt"}
+export VOCAB_PATH=${VOCAB_PATH:-"path-to-vocabulary"}
+export TOKENIZER_MODEL=${TOKENIZER_MODEL:-"path-to-tokenizer-model"}
+
+export MODEL=${MODEL:-"mixtral"}
+export USE_LEGACY_MODELS=${USE_LEGACY_MODELS:-"False"}
+
+# Step 1: Convert to Megatron checkpoint
+
+cd $CHATLEARN/examples/megatron/
+
+TP=1 \
+PP=4 \
+EP=8 \
+LOAD_PATH=${LOAD_PATH} \
+SAVE_PATH=${TEMP_PATH} \
+bash scripts/convert_hf_to_megatron.sh
+
+# Step 2: Convert to HuggingFace checkpoint
+
+LOAD_PATH=${TEMP_PATH} \
+SAVE_PATH=${SAVE_PATH} \
+VOCAB_PATH=${VOCAB_PATH} \
+target_params_dtype=bf16 \
+bash scripts/convert_megatron_to_hf.sh
+
+# Step 3: Compare converted hf ckpt against the original hf ckpt
+
+python3 tests/test_checkpoint_conversion.py \
+    --src-path ${LOAD_PATH} \
+    --dst-path ${SAVE_PATH}
+
+if [[ $? != 0 ]]; then
+    echo -e "\033[31m Test failed! \033[0m"
+    exit -1
+fi
+
+rm -rf ${TEMP_PATH}
+rm -rf ${SAVE_PATH}
+
+echo "Test success!"
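tests/test_checkpoint_conversion.py above checks the HF -> Megatron -> HF round trip with torch.equal, i.e. bitwise equality, which is appropriate here since both conversion steps keep bf16 parameters. If a conversion path ever introduces a dtype cast, a tolerance-based variant along the following lines could be used instead; this is only a sketch under that assumption, not code from the repository:

    import torch
    from transformers import AutoModel

    def compare_checkpoint_approx(src_path, dst_path, rtol=0.0, atol=0.0):
        """Compare two HF checkpoints in fp32 with an optional tolerance."""
        src = dict(AutoModel.from_pretrained(src_path).named_parameters())
        dst = dict(AutoModel.from_pretrained(dst_path).named_parameters())
        assert sorted(src) == sorted(dst), "Parameter names differ between checkpoints."
        for name, src_param in src.items():
            dst_param = dst[name]
            assert torch.allclose(src_param.float(), dst_param.float(), rtol=rtol, atol=atol), \
                f"Parameter {name} differs beyond tolerance."
        return True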