diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index 2aed749b..ed21b5d9 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -29,8 +29,8 @@ from chatlearn.data.sampler import SingleDataSampler, EpisodeDataSampler from chatlearn.checkpoint.checkpoint_manager import CheckpointManager from chatlearn.utils import future -from chatlearn.utils.dist_utils import bucket_tensors, coalesced_comm_dense -from chatlearn.utils.dist_utils import bucket_tensors_two_stage_generator, coalesced_comm_dense_two_stage +from chatlearn.utils.dist_utils import bucket_tensors, bucket_tensor_generator, coalesced_comm_dense +from chatlearn.utils.dist_utils import bucket_tensors_two_stage, bucket_tensors_two_stage_generator, coalesced_comm_dense_two_stage from chatlearn.utils.global_vars import get_args from chatlearn.utils.global_vars import set_global_variables from chatlearn.utils.logger import log_rank_0, debug_rank_0, setup_logger @@ -766,10 +766,8 @@ def allgather_routed_expert_parameter(self, group_name, pipe_stage=0): self._expert_sync_buffer.pop(name, "Not Found.") self._expert_sync_buffer[name] = param - def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0): - """ - :meta private: - """ + def _broadcast_parameter_opt_level_0(self, rank, src_rank, group_name, pipe_stage=0): + debug_rank_0(">>>>>>>>>>>>>>>>broadcast parameter at memory optimization level 0") tensors = [] for name, param in self._parameters_to_sync[pipe_stage]: if self._expert_sync_buffer and name in self._expert_sync_buffer and \ @@ -788,32 +786,113 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0): for param in sparse_bucket: col.broadcast(param, src_rank, group_name) + self.empty_cache() - def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): - """ - Arguments: - to_rank: receive rank in mapping from trainer to inference model. - buffer_rank: index which tensors of sync buffer to be sended in stage2. - rank: destination rank in communication group which enumerate receive ranks. - src_rank: source rank in communication group. always 0. - group_name: communication group name. - pipe_stage: pipeline stage. default 0. - stage2: bool. whether stage2 or not. default False. - Example: trainer_tp = 4, inference_tp = 8. pipeline_size = 1 - stage1: [(from_rank, to_rank), ...] = [(0, 8), (1, 10), (2, 12), (3, 14)] - stage2: [(from_rank, to_rank), ...] = [(8, 9), (10, 11), (12, 13), (14, 15)] + def _broadcast_parameter_opt_level_1(self, rank, src_rank, group_name, pipe_stage=0): + debug_rank_0(">>>>>>>>>>>>>>broadcast parameter at memory optimization level 1") + def tensor_generator(): + for name, param in self._parameters_to_sync[pipe_stage]: + if self._expert_sync_buffer and name in self._expert_sync_buffer: + yield self._expert_sync_buffer[name] + # move self._expert_sync_buffer[name] to cpu mem to save gpu mem + cpu_expert = self._expert_sync_buffer[name].cpu() + del self._expert_sync_buffer[name] + self._expert_sync_buffer[name] = cpu_expert + else: + yield param.data - For stage1 pair (0, 8): - 1. call broadcast func: (0 -> 0). src_rank: 0, rank: 0. - 2. call broadcast func: (0 -> 8). src_rank: 0, rank: 1. 
+        bucket_generator = bucket_tensor_generator(tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
+        dense_bucket_num = 0
+        sparse_bucket_num = 0
+        tensor_changed = rank != src_rank
+        for bucket_or_tensor, is_dense in bucket_generator:
+            if is_dense:
+                coalesced_comm_dense(bucket_or_tensor, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed)
+                dense_bucket_num += 1
+            else:
+                col.broadcast(bucket_or_tensor, src_rank, group_name)
+                sparse_bucket_num += 1
 
-        After (0, 8), to_rank 8 received tensor slices of 8 and 9.
+        debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, sparse_bucket {sparse_bucket_num}", self._logger)
+        self.empty_cache()
 
-        For stage2 pair (8, 9):
-        1. call broadcast func: (8 -> 8). src_rank: 0, rank: 0.
-        2. call broadcast func: (8 -> 9). src_rank: 0, rank: 1.
-        In (8 -> 8), we need to send tp_slice of 'to_rank' 9, so set buffer_rank 9 to fetch tensors in sync buffer.
+    def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
         """
+        :meta private:
+        """
+        if self.runtime_args.sync_memory_optimization_level == 0:
+            self._broadcast_parameter_opt_level_0(rank, src_rank, group_name, pipe_stage)
+        else:
+            self._broadcast_parameter_opt_level_1(rank, src_rank, group_name, pipe_stage)
+
+    def _broadcast_parameter_two_stage_opt_level_0(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
+        debug_rank_0(">>>>>>>>>>>>>>broadcast parameter at memory optimization level 0")
+        tensor_changed = rank != src_rank
+
+        if stage2:
+            if tensor_changed:
+                parameters_to_sync = self._parameters_to_recv[to_rank]
+            else:
+                parameters_to_sync = self._parameters_to_send
+        else:
+            del self._sync_buffer
+            self._sync_buffer = defaultdict(list)
+            parameters_to_sync = self._parameters_to_sync
+
+        tensors = []
+        buffer_num = []
+        if stage2 and not tensor_changed and self._sync_buffer:  # pylint: disable=too-many-nested-blocks
+            idx = 0
+            for name, param in parameters_to_sync[pipe_stage]:
+                self._logger.debug(
+                    f"Adding {name} to sync for if branch from "
+                    f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}"
+                )
+                tensors.append(self._sync_buffer[buffer_rank % self.tp_num_mapping][idx])
+                buffer_num.append(1)
+                idx += 1
+            del self._sync_buffer[buffer_rank % self.tp_num_mapping]
+        else:
+            for name, param in parameters_to_sync[pipe_stage]:
+                self._logger.debug(
+                    f"Adding {name} to sync for else branch from "
+                    f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}"
+                )
+                param_data = param.data
+                if rank and self._buffer_num and not stage2:
+                    assert name in self._buffer_num, f"{name} not in self._buffer_num for rank {rank}"
+                    buffer_num.append(self._buffer_num[name])
+                elif stage2:
+                    buffer_num.append(1)
+                else:
+                    # regroup src_tensor by tp_rank.
+                    param_data = self._synchronizer.regroup_params_to_sync(name, param_data, self._tp_division[name])
+                    buffer_num.append(1)
+                tensors.append(param_data)
+
+        assert len(tensors) > 0
+        dense_buckets, sparse_bucket = bucket_tensors_two_stage(
+            tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb,
+            buffer_num=None if stage2 else buffer_num, tensor_changed=tensor_changed and not stage2)
+        debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, sparse_bucket {len(sparse_bucket)}", self._logger)
+
+        for bucket in dense_buckets:
+            index = 0 if stage2 else (to_rank % self.tp_num_mapping)
+            all_buffers = coalesced_comm_dense_two_stage(
+                bucket, col.broadcast, rank,
+                extra_args=(src_rank, group_name), tensor_changed=tensor_changed,
+                stage2=stage2, index=index)
+            if tensor_changed and not stage2:
+                for key, value in all_buffers.items():
+                    self._sync_buffer[key] += value
+
+        for param in sparse_bucket:
+            col.broadcast(param, src_rank, group_name)
+
+        self.empty_cache()
+
+    def _broadcast_parameter_two_stage_opt_level_1(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
+        debug_rank_0(">>>>>>>>>>>>>>broadcast parameter at memory optimization level 1")
         tensor_changed = rank != src_rank
 
         if stage2:
@@ -904,6 +983,36 @@ def tensor_generator():
         self.empty_cache()
 
+    def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
+        """
+        Arguments:
+        to_rank: receive rank in mapping from trainer to inference model.
+        buffer_rank: index of which tensors in the sync buffer to send in stage2.
+        rank: destination rank in the communication group, which enumerates receive ranks.
+        src_rank: source rank in communication group. always 0.
+        group_name: communication group name.
+        pipe_stage: pipeline stage. default 0.
+        stage2: bool. whether this is stage2. default False.
+        Example: trainer_tp = 4, inference_tp = 8. pipeline_size = 1
+            stage1: [(from_rank, to_rank), ...] = [(0, 8), (1, 10), (2, 12), (3, 14)]
+            stage2: [(from_rank, to_rank), ...] = [(8, 9), (10, 11), (12, 13), (14, 15)]
+
+        For stage1 pair (0, 8):
+        1. call broadcast func: (0 -> 0). src_rank: 0, rank: 0.
+        2. call broadcast func: (0 -> 8). src_rank: 0, rank: 1.
+
+        After (0, 8), to_rank 8 received tensor slices of 8 and 9.
+
+        For stage2 pair (8, 9):
+        1. call broadcast func: (8 -> 8). src_rank: 0, rank: 0.
+        2. call broadcast func: (8 -> 9). src_rank: 0, rank: 1.
+        In (8 -> 8), we need to send the tp_slice of 'to_rank' 9, so buffer_rank is set to 9 to fetch tensors from the sync buffer.
+        """
+        if self.runtime_args.sync_memory_optimization_level == 0:
+            self._broadcast_parameter_two_stage_opt_level_0(to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage, stage2)
+        else:
+            self._broadcast_parameter_two_stage_opt_level_1(to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage, stage2)
+
     def send_parameter(self, dst_rank, group_name, pipe_stage=0):
         """
         :meta private:
         """
diff --git a/chatlearn/runtime/engine.py b/chatlearn/runtime/engine.py
index 330c9c7f..3195d03b 100644
--- a/chatlearn/runtime/engine.py
+++ b/chatlearn/runtime/engine.py
@@ -296,9 +296,9 @@ def learn(self):
         self.timers("sync_parameters").start()
         self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync)
         self.timers("sync_parameters").stop()
+        logger.info(f"{LOG_START} " + get_full_proc_memory_info('After first param sync'))
         logger.info(
-            f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} " \
-            + get_full_proc_memory_info('After first param sync')
+            f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} "
         )
         self._data_loader = data_loader
         for episode_id in range(self._start_episode, self.runtime_args.num_episode):
diff --git a/chatlearn/utils/arguments.py b/chatlearn/utils/arguments.py
index c7d89e50..4e5fe5bb 100644
--- a/chatlearn/utils/arguments.py
+++ b/chatlearn/utils/arguments.py
@@ -312,6 +312,8 @@ class RuntimeConfig(BaseConfig):
     param_sync_max_workers: int = None
     #: communication type to regroup routed experts, allgather/alltoall
    routed_expert_regrouping_comm_type: str = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL
+    #: memory optimization level for parameter synchronization, 0/1: 1 saves GPU memory, 0 pursues faster execution
+    sync_memory_optimization_level: int = 1
     #: max number of relay episodes, if `max_relay_episode` is set to -1, then relay all episodes
     #: if `max_relay_episode` is set to 0, then relay is disabled
     max_relay_episode: int = 0
@@ -510,6 +512,7 @@ def _validate_params(self):
         assert self.runtime_args.stream_data_loader_type.lower() in ["fixed", "dynamic"]
         assert self.runtime_args.cpu_schedule_strategy in [strategy.value for strategy in RAY_PG_STRATEGY]
         assert self.runtime_args.param_sync_comm_type in list(PARAM_SYNC_COMM_TYPE)
+        assert self.runtime_args.sync_memory_optimization_level in [0, 1]
         for model_name, model_args in self.models.items():
             if model_args.num_gpu >= 1:
                 if model_args.gpu_per_process is None:
diff --git a/chatlearn/utils/dist_utils.py b/chatlearn/utils/dist_utils.py
index df682d12..d78972cd 100644
--- a/chatlearn/utils/dist_utils.py
+++ b/chatlearn/utils/dist_utils.py
@@ -53,6 +53,74 @@ def bucket_tensors(tensors, bucket_size_mb):
     return dense_buckets, sparse_bucket
 
 
+def bucket_tensor_generator(tensor_generator, bucket_size_mb):
+    """Group tensors into chunks. We separate sparse and dense tensors,
+    each chunk containing tensors of the same type up to a certain byte limit in total size.
+
+    Args:
+        tensor_generator (Callable): A callable returning a generator of tensors to be separated into chunks.
+        bucket_size_mb (int): The limit of each chunk in megabytes.
+
+    Yield:
+        (bucket, True) for each dense bucket of tensors of the same type within the size limit.
+        (tensor, False) for each sparse tensor.
+    """
+    size_limit = bucket_size_mb * 1024 * 1024
+    buf_dict = defaultdict(lambda: [[], 0])
+    for tensor in tensor_generator():
+        if tensor.is_sparse:
+            yield tensor, False
+            continue
+        t = tensor.type()
+        size = tensor.numel() * tensor.element_size()
+        buf_and_size = buf_dict[t]
+        if size_limit > 0 and buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:  # pylint: disable=chained-comparison
+            yield buf_and_size[0], True
+            buf_and_size = buf_dict[t] = [[], 0]
+        buf_and_size[0].append(tensor)
+        buf_and_size[1] += size
+    for buf, _ in buf_dict.values():
+        if len(buf) > 0:
+            yield buf, True
+
+
+def bucket_tensors_two_stage(tensors, bucket_size_mb, buffer_num=None, tensor_changed=False):
+    """Group tensors into chunks. We separate sparse and dense tensors,
+    each chunk containing tensors of the same type up to a certain byte limit in total size.
+    Args:
+        tensors (Sequence): A sequence of tensors to be separated into chunks.
+        bucket_size_mb (int): The limit of each chunk in megabytes.
+    Return:
+        dense_buckets: Blocks of tensors of the same type, each within the size limit.
+        sparse_bucket: A list of sparse tensors.
+    """
+    size_limit = bucket_size_mb * 1024 * 1024
+    buf_dict = defaultdict(lambda: [[], 0])
+    dense_buckets = []
+    sparse_bucket = []
+    for idx, tensor in enumerate(tensors):
+        buffer_multiple = 1 if buffer_num is None else buffer_num[idx]
+        if tensor.is_sparse:
+            sparse_bucket.append(tensor)
+            continue
+        t = tensor.type()
+        # expand buffer size of dst ranks which recv tensor from trainer.
+        size = tensor.numel() * tensor.element_size() * buffer_multiple
+        buf_and_size = buf_dict[t]
+        if size_limit > 0 and buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:  # pylint: disable=chained-comparison
+            dense_buckets.append(buf_and_size[0])
+            buf_and_size = buf_dict[t] = [[], 0]
+        buf_and_size[0].append((torch.empty(size=[tensor.numel() * buffer_multiple],
+                                            dtype=tensor.dtype,
+                                            device=tensor.device) if (tensor_changed and buffer_multiple > 1) else tensor,
+                                [size // tensor.element_size(), buffer_multiple, tensor]))
+        buf_and_size[1] += size
+    for buf, size in buf_dict.values():
+        if len(buf) > 0:
+            dense_buckets.append(buf)
+    return dense_buckets, sparse_bucket
+
+
 def bucket_tensors_two_stage_generator(tensor_generator, bucket_size_mb, stage2=False, tensor_changed=False):
     """Group tensors into chunks. We seperate sparse and dense tensor,
     each containing tensors of same type up to certain byte limit in total size.
@@ -130,8 +198,11 @@ def coalesced_comm_dense(bucket, comm_call, extra_args, tensor_changed=True): flat_tensors = _flatten_dense_tensors(bucket) comm_call(flat_tensors, *extra_args) if tensor_changed: + all_buffers = _unflatten_dense_tensors(flat_tensors, bucket) + del flat_tensors for tensor, synced in zip( - bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + bucket, all_buffers + ): tensor.copy_(synced) diff --git a/examples/megatron/configs/llama2/grpo_math_vllm.yaml b/examples/megatron/configs/llama2/grpo_math_vllm.yaml index 64f54bfd..2a7b933b 100644 --- a/examples/megatron/configs/llama2/grpo_math_vllm.yaml +++ b/examples/megatron/configs/llama2/grpo_math_vllm.yaml @@ -66,3 +66,4 @@ runtime: max_relay_episode: 1 exp_name: ${exp_name:chatlearn} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/online_dpo_vllm.yaml b/examples/megatron/configs/llama2/online_dpo_vllm.yaml index e64f78ee..f274ee6b 100644 --- a/examples/megatron/configs/llama2/online_dpo_vllm.yaml +++ b/examples/megatron/configs/llama2/online_dpo_vllm.yaml @@ -60,3 +60,4 @@ runtime: output_dir: ${output_dir} exp_name: ${exp_name:chatlearn} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/vllm_param_sync.yaml b/examples/megatron/configs/llama2/vllm_param_sync.yaml index 9177fe87..4ab36bba 100644 --- a/examples/megatron/configs/llama2/vllm_param_sync.yaml +++ b/examples/megatron/configs/llama2/vllm_param_sync.yaml @@ -49,3 +49,4 @@ runtime: exp_name: ${exp_name:chatlearn} debug: ${debug:False} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/vllm_rlhf.yaml b/examples/megatron/configs/llama2/vllm_rlhf.yaml index b57602b3..34253461 100644 --- a/examples/megatron/configs/llama2/vllm_rlhf.yaml +++ b/examples/megatron/configs/llama2/vllm_rlhf.yaml @@ -82,3 +82,4 @@ runtime: exp_name: ${exp_name:chatlearn} debug: ${debug:False} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1}
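Note (not part of the patch): the new `sync_memory_optimization_level` switch selects between the original list-based parameter sync (level 0), which builds the full tensor and bucket lists before communicating, and the level-1 path, which streams parameters through the new `bucket_tensor_generator` and, per the base_module changes above, offloads routed-expert sync buffers to CPU after they are broadcast. The sketch below is a minimal, illustrative usage of that generator on toy tensors; the shapes, bucket size, and prints are invented for the demo, and it assumes this patch is applied.

```python
# Minimal usage sketch of bucket_tensor_generator (illustrative only).
# Dense tensors come back grouped into size-bounded buckets, while sparse
# tensors are yielded individually, which is how the level-1 broadcast
# path consumes them one bucket at a time.
import torch

from chatlearn.utils.dist_utils import bucket_tensor_generator


def toy_tensors():
    # Stand-in for the per-parameter generator built inside
    # _broadcast_parameter_opt_level_1: tensors are produced lazily, so the
    # full parameter list is never materialized at once.
    for _ in range(4):
        yield torch.zeros(1024, 1024)   # ~4 MB each in fp32
    yield torch.eye(8).to_sparse()      # sparse tensors bypass bucketing


for bucket_or_tensor, is_dense in bucket_tensor_generator(toy_tensors, bucket_size_mb=8):
    if is_dense:
        # A list of same-dtype tensors whose total size fits the 8 MB limit;
        # the sync code coalesces and broadcasts one such bucket at a time.
        print("dense bucket with", len(bucket_or_tensor), "tensors")
    else:
        print("sparse tensor of shape", tuple(bucket_or_tensor.shape))
```

The example YAMLs expose the same switch as `sync_memory_optimization_level: ${sync_memory_optimization_level:1}`; the default of 1 keeps the generator-based path, while 0 dispatches to `_broadcast_parameter_opt_level_0` and `_broadcast_parameter_two_stage_opt_level_0`.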