From 8572a2451be519fc5fd6d1350b1cefcc8317cfed Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Tue, 17 Dec 2024 20:49:38 +0800 Subject: [PATCH 01/13] support megatron ckpt for llama. --- chatlearn/models/vllm/hooks/loader.py | 3 +++ chatlearn/utils/vllm_utils.py | 18 ++++++++++-------- .../configs/llama2/vllm_policy_inference.yaml | 3 +++ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/chatlearn/models/vllm/hooks/loader.py b/chatlearn/models/vllm/hooks/loader.py index 4a4d06c5..e35f8077 100644 --- a/chatlearn/models/vllm/hooks/loader.py +++ b/chatlearn/models/vllm/hooks/loader.py @@ -22,6 +22,7 @@ from vllm.model_executor.model_loader.loader import device_loading_context, _initialize_model from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.models import llama from vllm.model_executor.models import qwen2 from chatlearn.utils.vllm_import_helper import LlamaForCausalLM @@ -87,6 +88,8 @@ def load_model(self, *, model_config, if self.load_config.model_loader_extra_config.get("need_load_ckpt", True): qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict qwen2.Qwen2ForCausalLM.load_weights = load_weights + llama.LlamaForCausalLM.load_state_dict = load_state_dict + llama.LlamaForCausalLM.load_weights = load_weights model.load_weights(self.load_config.model_loader_extra_config) else: # For accurate performance evaluation, we assign diff --git a/chatlearn/utils/vllm_utils.py b/chatlearn/utils/vllm_utils.py index c04880b0..a48bb01a 100644 --- a/chatlearn/utils/vllm_utils.py +++ b/chatlearn/utils/vllm_utils.py @@ -812,6 +812,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version # Convert. print("Start to convert...") + prefix_name = "model" if is_vllm_v2() else "model.model" # Embeddings print("Converting embeddings") @@ -838,7 +839,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version elif word_embeddings is not None: # After training with megatron, word_embeddings is stored differently word_embeddings = word_embeddings.to(hf_config.torch_dtype) - output_state_dict["model.model.embed_tokens.weight"] = word_embeddings + output_state_dict[f"{prefix_name}.embed_tokens.weight"] = word_embeddings # Reset the vocab size hf_config.vocab_size = word_embeddings.shape[0] @@ -879,7 +880,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version # Is it a weight or a bias? weight_or_bias = m.group(3) # The name of the layer. - layer_name = f"model.model.layers.{layer_idx}" + layer_name = f"{prefix_name}.layers.{layer_idx}" params = val.to(hf_config.torch_dtype) @@ -935,7 +936,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version if "final_norm.weight" in params or "final_layernorm.weight" in params: print("Converting final layernorm") final_norm_weight = params["final_norm.weight"] if "final_norm.weight" in params else params["final_layernorm.weight"] - output_state_dict["model.model.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) + output_state_dict[f"{prefix_name}.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) # For LM head, transformers' wants the matrix to weight embeddings. 
params = get_element_from_dict_by_path(tp_state_dicts[tp_rank], 'model.language_model.output_layer.weight') @@ -943,7 +944,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version assert not params, "weight name of lm_head expect 'model.language_model.output_layer.weight'." elif params is not None: print("Converting LM head") - output_state_dict["model.lm_head.weight"] = params.to(hf_config.torch_dtype) + output_state_dict["lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"] = params.to(hf_config.torch_dtype) # It should be done! print("Conversion from Megatron-LM to Transformers is done!") @@ -993,6 +994,7 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # Convert. print("Start to convert...") + prefix_name = "model" if is_vllm_v2() else "model.model" # Embeddings print("Converting embeddings") @@ -1025,7 +1027,7 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # Convert and store the word embeddings. word_embeddings = tp_state_dicts[tp_rank]['model'].get("embedding.word_embeddings.weight", None) word_embeddings = word_embeddings.to(hf_config.torch_dtype) - output_state_dict["model.model.embed_tokens.weight"] = word_embeddings + output_state_dict[f"{prefix_name}.embed_tokens.weight"] = word_embeddings # Reset the vocab size hf_config.vocab_size = word_embeddings.shape[0] @@ -1051,7 +1053,7 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # Is it a weight or a bias? weight_or_bias = layer_match_res.group(3) # The name of the layer - layer_name = f"model.model.layers.{layer_idx}" + layer_name = f"{prefix_name}.layers.{layer_idx}" params = val.to(hf_config.torch_dtype) @@ -1112,12 +1114,12 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # The final layernorm. print("Converting final layernorm") final_norm_weight = tp_state_dicts[0]['model'].get("decoder.final_layernorm.weight", None) - output_state_dict["model.model.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) + output_state_dict[f"{prefix_name}.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) # For LM head, transformers' wants the matrix to weight embeddings. print("Converting LM head") params = tp_state_dicts[tp_rank]['model'].get('output_layer.weight', None) - output_state_dict["model.lm_head.weight"] = params.to(hf_config.torch_dtype) + output_state_dict["lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"] = params.to(hf_config.torch_dtype) # It should be done! print("Conversion from Megatron-Core to Transformers is done!") diff --git a/examples/megatron/configs/llama2/vllm_policy_inference.yaml b/examples/megatron/configs/llama2/vllm_policy_inference.yaml index 1ad565fa..c1f3f4c0 100644 --- a/examples/megatron/configs/llama2/vllm_policy_inference.yaml +++ b/examples/megatron/configs/llama2/vllm_policy_inference.yaml @@ -45,3 +45,6 @@ vllm_prompt_key: prompt tensor_model_parallel_size: ${policy_tp} pipeline_model_parallel_size: ${policy_pp} + +enforce_eager: ${enforce_eager:False} +vllm_load_format: ${vllm_load_format:dummy} From deb85d9819feb876fa933d1e1380cb9224498593 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Thu, 19 Dec 2024 07:11:47 +0000 Subject: [PATCH 02/13] fix pylint. 
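Besides the pylint fixes, this patch adds two import-time hooks (async_llm_engine.py and llm.py) and registers them in chatlearn/models/vllm/hooks/__init__.py. They follow the same pattern as the existing loader hook: define a replacement function, then assign it over the vLLM attribute when the hook module is imported (e.g. `llm.LLM.__init__ = init` below). A minimal, self-contained sketch of that pattern, using an illustrative stand-in class rather than the real vLLM API:

# Illustrative only: `FakeLLM` stands in for a vLLM class; the hooks in this patch
# actually target vllm.entrypoints.llm.LLM and vllm.engine.async_llm_engine.AsyncLLMEngine.
class FakeLLM:
    def __init__(self, model: str):
        self.model = model

def patched_init(self, model: str):
    # Replacement keeps the original signature but adds extra setup.
    self.model = model
    self.request_counter = 0

# Assigned at import time of the hook module, mirroring `llm.LLM.__init__ = init`.
FakeLLM.__init__ = patched_init

assert FakeLLM("dummy-model").request_counter == 0
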
--- chatlearn/models/vllm/hooks/__init__.py | 2 + .../models/vllm/hooks/async_llm_engine.py | 55 ++++++++ chatlearn/models/vllm/hooks/llm.py | 89 ++++++++++++ chatlearn/models/vllm_module_v2.py | 132 +++++++----------- chatlearn/runtime/decorator.py | 4 +- 5 files changed, 200 insertions(+), 82 deletions(-) create mode 100644 chatlearn/models/vllm/hooks/async_llm_engine.py create mode 100644 chatlearn/models/vllm/hooks/llm.py diff --git a/chatlearn/models/vllm/hooks/__init__.py b/chatlearn/models/vllm/hooks/__init__.py index 2a4036bc..9bedfe50 100644 --- a/chatlearn/models/vllm/hooks/__init__.py +++ b/chatlearn/models/vllm/hooks/__init__.py @@ -25,6 +25,8 @@ from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: from chatlearn.models.vllm.hooks import input_preprocess + from chatlearn.models.vllm.hooks import async_llm_engine + from chatlearn.models.vllm.hooks import llm from chatlearn.models.vllm.hooks import loader else: if importlib.util.find_spec("vllm"): diff --git a/chatlearn/models/vllm/hooks/async_llm_engine.py b/chatlearn/models/vllm/hooks/async_llm_engine.py new file mode 100644 index 00000000..ee84f581 --- /dev/null +++ b/chatlearn/models/vllm/hooks/async_llm_engine.py @@ -0,0 +1,55 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.3 del init_ray_cluster in AsyncLLMEngine.""" + +from typing import Dict, Optional + +# pylint: disable=unused-import,wildcard-import,unused-argument +from vllm.config import EngineConfig +from vllm.engine import async_llm_engine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.metrics_types import StatLoggerBase +from vllm.usage.usage_lib import UsageContext + +@classmethod +def from_engine_args( + cls, + engine_args: AsyncEngineArgs, + engine_config: Optional[EngineConfig] = None, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, +) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. + print(f"debug aaaaa async_llm_engine.") + if engine_config is None: + engine_config = engine_args.create_engine_config() + + executor_class = cls._get_executor_cls(engine_config) + + # Create the async LLM engine. 
+ engine = cls( + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + start_engine_loop=start_engine_loop, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine + +async_llm_engine.AsyncLLMEngine.from_engine_args = from_engine_args diff --git a/chatlearn/models/vllm/hooks/llm.py b/chatlearn/models/vllm/hooks/llm.py new file mode 100644 index 00000000..e597e096 --- /dev/null +++ b/chatlearn/models/vllm/hooks/llm.py @@ -0,0 +1,89 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.3 llm init with AsyncLLMEngine and AsyncEngineArgs.""" + +from typing import Any, Dict, Optional + +# pylint: disable=unused-import,wildcard-import,unused-argument +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints import llm +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Counter + +def init( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, +) -> None: + ''' + LLM constructor. + + Note: if enforce_eager is unset (enforce_eager is None) + it defaults to False. 
+ ''' + print(f"debug aaaaa llm") + + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, + mm_processor_kwargs=mm_processor_kwargs, + **kwargs, + ) + + self.llm_engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS).engine + self.request_counter = Counter() + +llm.LLM.__init__ = init diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index 2daf2cc4..34a803da 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -18,6 +18,7 @@ import inspect import os import sys +import json from typing import Optional import torch @@ -25,10 +26,12 @@ from vllm import SamplingParams from vllm.config import EngineConfig from vllm.config import LoadFormat +from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import RayWorkerWrapper from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs from vllm.usage.usage_lib import UsageContext +from vllm.utils import Counter from vllm.utils import FlexibleArgumentParser from chatlearn.utils.global_vars import set_vllm_actors @@ -46,7 +49,7 @@ def __init__(self, *args, **kwargs): if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called os.environ['VLLM_HOST_IP'] = self.get_address() - self.engine = None + self.llm_engine = None self.tokenizer = None def setup(self): @@ -56,87 +59,47 @@ def setup(self): tokenizer.tokenizer = tokenizer self.tokenizer = tokenizer - def _init_args(self, args): - # scheduler config - args.max_num_seqs = self.module_args.generation_batch_size - args.max_num_batched_tokens = self.model_args.get("max_num_batched_tokens") - args.num_scheduler_steps = self.model_args.get("num_scheduler_steps", 1) - - # model config - args.max_seq_len = self.model_args.get("seq_length") - - # logger config - args.disable_log_requests = True - - # load format: 'dummy' for megatron ckpt or mock weight; others for hf ckpt. 
- args.load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY) - if args.load_format == LoadFormat.DUMMY: - args.model_loader_extra_config = self.model_args - self.model_args["need_load_ckpt"] = self.src_parameter_model is None - - # engine config - args.enforce_eager = self.model_args.get("enforce_eager", False) - def setup_vllm(self, workers): # setup vllm engine in rank 0 os.environ['VLLM_HOST_IP'] = self.get_address() set_vllm_actors(workers) - parser = FlexibleArgumentParser() - parser = AsyncEngineArgs.add_cli_args(parser) - backup_sys_argv = sys.argv + dtype = self.model_args.get("dtype", "bfloat16") if self.model_args.get("fp16", False): dtype = "float16" - vllm_sys_argv = ["", - f"--model={self.model_args['tokenizer']}", - f"--tensor_parallel_size={self.module_args.tensor_model_parallel_size}", - f"--pipeline_parallel_size={self.module_args.pipeline_model_parallel_size}", - f"--dtype={dtype}", - "--worker_use_ray", - "--disable_custom_all_reduce"] - sys.argv = vllm_sys_argv - args = parser.parse_args() - self._init_args(args) - engine_args = AsyncEngineArgs.from_cli_args(args) - self.engine = self.from_engine_args(engine_args) - - sys.argv = backup_sys_argv - self.tokenizer = self.engine.engine.tokenizer - def from_engine_args( - self, - engine_args: AsyncEngineArgs, - engine_config: Optional[EngineConfig] = None, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers = None, - ) -> "AsyncLLMEngine": - """Creates an async LLM engine from the engine arguments.""" - # Create the engine configs. - - if engine_config is None: - engine_config = engine_args.create_engine_config() - - executor_class = AsyncLLMEngine._get_executor_cls(engine_config) - - # Create the async LLM engine. - engine = AsyncLLMEngine( - **engine_config.to_dict(), - executor_class=executor_class, - log_requests=not engine_args.disable_log_requests, - log_stats=not engine_args.disable_log_stats, - start_engine_loop=start_engine_loop, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - return engine - - async def generate_one_sample(self, prompt, sampling_param, request_id): - results_generator = self.engine.generate(prompt, sampling_param, request_id) - final_output = None - async for request_output in results_generator: - final_output = request_output - return final_output + load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY) + if load_format == LoadFormat.DUMMY: + self.model_args["need_load_ckpt"] = self.src_parameter_model is None + model_loader_extra_config = self.model_args + else: + model_loader_extra_config = None + + self.llm = LLM( + model=self.model_args['tokenizer'], + tokenizer=self.model_args['tokenizer'], + max_seq_len_to_capture=self.model_args.get("seq_length"), + # load model: 'dummy' for megatron ckpt or mock weight; others for hf ckpt. 
+ load_format=load_format, + model_loader_extra_config=model_loader_extra_config, + # parallelism strategy + tensor_parallel_size=self.module_args.tensor_model_parallel_size, + pipeline_parallel_size=self.module_args.pipeline_model_parallel_size, + dtype=dtype, + # scheduling strategy + max_num_seqs=self.module_args.generation_batch_size, + max_num_batched_tokens = self.model_args.get("max_num_batched_tokens", None), + num_scheduler_steps=self.model_args.get("num_scheduler_steps", 1), + gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90), + # logger + disable_log_requests=self.model_args.get("disable_log_requests", True), + disable_log_stats=self.model_args.get("disable_log_stats", True), + trust_remote_code=True, + # TODO(jiangle.jl): support non-eager mode. + enforce_eager=True, + disable_custom_all_reduce=True, + distributed_executor_backend="ray") + self.tokenizer = self.llm.llm_engine.tokenizer def _get_sampling_params(self, is_eval): temperature = 0.0 @@ -173,7 +136,7 @@ def _get_sampling_params(self, is_eval): sampling_params.use_beam_search = self.model_args.get("use_beam_search") return sampling_params - def convert_v1_inputs(self, prompts, prompt_token_ids): + def _convert_v1_inputs(self, prompts, prompt_token_ids): num_requests = len(prompts) assert num_requests == len(prompt_token_ids), \ ("The lengths of prompts and prompt_token_ids must be the same.") @@ -193,6 +156,9 @@ def convert_v1_inputs(self, prompts, prompt_token_ids): return inputs + async def generate_all(self, prompts, sampling_params): + pass + async def generate_vllm(self, query, is_eval): prompt_key = self.model_args.get("vllm_prompt_key", "prompt") input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids") @@ -202,6 +168,8 @@ async def generate_vllm(self, query, is_eval): seq_len = self.model_args.get("seq_length") final_outputs = [] tasks = [] + parsed_prompts = [] + sampling_params = [] for i, prompt in enumerate(prompts): request_id = i prompt_token_ids = prompts_token_ids[i] @@ -215,14 +183,18 @@ async def generate_vllm(self, query, is_eval): max_tokens = self.model_args.get("max_new_tokens") assert max_tokens < seq_len, "max_new_tokens must less than seq length." 
sampling_param.max_tokens = max_tokens - inputs = self.convert_v1_inputs( + item = self._convert_v1_inputs( prompts=[prompt], prompt_token_ids=[prompt_token_ids], )[0] + parsed_prompts.append(item) + sampling_params.append(sampling_param) - task = asyncio.create_task(self.generate_one_sample(inputs, sampling_param, request_id)) - tasks.append(task) - outputs = await asyncio.gather(*tasks) + outputs = self.llm.generate( + parsed_prompts, + sampling_params, + use_tqdm=True, + ) final_outputs = sorted(outputs, key=lambda x: int(x.request_id)) return final_outputs @@ -253,7 +225,7 @@ def __init__(self, *args, **kwargs): if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called os.environ['VLLM_HOST_IP'] = self.get_address() - self.engine = None + self.llm_engine = None def peak_memory(self): """ diff --git a/chatlearn/runtime/decorator.py b/chatlearn/runtime/decorator.py index 0291ee76..56d89972 100644 --- a/chatlearn/runtime/decorator.py +++ b/chatlearn/runtime/decorator.py @@ -164,7 +164,7 @@ def get_kwarg(key): # for model with DP/EP, we need to return results from all ranks # for model with TP/PP, only return the results from last rank if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \ - or isinstance(self, VLLMModuleV2): + or isinstance(self, VLLMModuleV2): final_results = concat_along_batch(results) else: if 'iteration' in inspect.signature(func).parameters: @@ -176,7 +176,7 @@ def get_kwarg(key): # for model with DP/EP, we need to return results from all ranks # for model with TP/PP, only return the results from last rank if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \ - or isinstance(self, VLLMModuleV2): + or isinstance(self, VLLMModuleV2): final_results = ret else: if 'iteration' in inspect.signature(func).parameters: From afa41b7ffc651dc89db875759575efa07011be1c Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 27 Dec 2024 15:06:13 +0800 Subject: [PATCH 03/13] fix broadcast gpu oom --- chatlearn/models/base_module.py | 38 ++++++++++++++++++++------------- chatlearn/utils/dist_utils.py | 19 +++++++++-------- 2 files changed, 33 insertions(+), 24 deletions(-) diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index b09df943..14d07aad 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -29,7 +29,7 @@ from chatlearn.data.sampler import SingleDataSampler, EpisodeDataSampler from chatlearn.checkpoint.checkpoint_manager import CheckpointManager from chatlearn.utils import future -from chatlearn.utils.dist_utils import bucket_tensors, coalesced_comm_dense +from chatlearn.utils.dist_utils import bucket_tensor_generator, coalesced_comm_dense from chatlearn.utils.dist_utils import bucket_tensors_two_stage_generator, coalesced_comm_dense_two_stage from chatlearn.utils.global_vars import get_args from chatlearn.utils.global_vars import set_global_variables @@ -766,23 +766,31 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0): """ :meta private: """ - tensors = [] - for name, param in self._parameters_to_sync[pipe_stage]: - if self._expert_sync_buffer and name in self._expert_sync_buffer and self._synchronizer.is_parameter_changed: - tensors.append(self._expert_sync_buffer[name]) - else: - tensors.append(param.data) + def tensor_generator(): + for name, param in self._parameters_to_sync[pipe_stage]: + if self._expert_sync_buffer and 
name in self._expert_sync_buffer: + yield self._expert_sync_buffer[name] + # move self._expert_sync_buffer[name] to cpu mem to save gpu mem + cpu_expert = self._expert_sync_buffer[name].cpu() + del self._expert_sync_buffer[name] + self._expert_sync_buffer[name] = cpu_expert + else: + yield param.data - assert len(tensors) > 0 - dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb) - debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger) + bucket_generator = bucket_tensor_generator(tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb) + dense_bucket_num = 0 + sparse_bucket_num = 0 tensor_changed = rank != src_rank + for bucket_or_tensor, is_dense in bucket_generator: + if is_dense: + coalesced_comm_dense(bucket_or_tensor, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) + dense_bucket_num += 1 + else: + col.broadcast(param, src_rank, group_name) + sparse_bucket_num += 1 - for bucket in dense_buckets: - coalesced_comm_dense(bucket, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) - - for param in sparse_bucket: - col.broadcast(param, src_rank, group_name) + debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, spase_bucket {sparse_bucket_num}", self._logger) + self.empty_cache() def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): """ diff --git a/chatlearn/utils/dist_utils.py b/chatlearn/utils/dist_utils.py index df682d12..bfaad2d0 100644 --- a/chatlearn/utils/dist_utils.py +++ b/chatlearn/utils/dist_utils.py @@ -19,12 +19,12 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -def bucket_tensors(tensors, bucket_size_mb): +def bucket_tensor_generator(tensor_generator, bucket_size_mb): """Group tensors into chunks. We seperate sparse and dense tensor, each containing tensors of same type up to certain byte limit in total size. Args: - tensors (Sequence): A sequence of tensors to be separated into chunks. + tensor_generator (Generator): A generator of tensors to be separated into chunks. size_limit (int): The limit of each chunk in bytes. 
Return: @@ -33,24 +33,21 @@ def bucket_tensors(tensors, bucket_size_mb): """ size_limit = bucket_size_mb * 1024 * 1024 buf_dict = defaultdict(lambda: [[], 0]) - dense_buckets = [] - sparse_bucket = [] for tensor in tensors: if tensor.is_sparse: - sparse_bucket.append(tensor) + yield tensor, False continue t = tensor.type() size = tensor.numel() * tensor.element_size() buf_and_size = buf_dict[t] if size_limit > 0 and buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: # pylint: disable=chained-comparison - dense_buckets.append(buf_and_size[0]) + yield buf_and_size[0], True buf_and_size = buf_dict[t] = [[], 0] buf_and_size[0].append(tensor) buf_and_size[1] += size for buf, _ in buf_dict.values(): if len(buf) > 0: - dense_buckets.append(buf) - return dense_buckets, sparse_bucket + yield buf, True def bucket_tensors_two_stage_generator(tensor_generator, bucket_size_mb, stage2=False, tensor_changed=False): @@ -128,10 +125,14 @@ def coalesced_comm_dense(bucket, comm_call, extra_args, tensor_changed=True): coalesced communication for dense parameters """ flat_tensors = _flatten_dense_tensors(bucket) + del bucket comm_call(flat_tensors, *extra_args) if tensor_changed: + all_buffers = _unflatten_dense_tensors(flat_tensors, bucket) + del flat_tensors for tensor, synced in zip( - bucket, _unflatten_dense_tensors(flat_tensors, bucket)): + bucket, all_buffers + ): tensor.copy_(synced) From 7e3fcd04cc3d2031486b631fd2cfe0687d70de6f Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 27 Dec 2024 15:14:55 +0800 Subject: [PATCH 04/13] fix pylint --- chatlearn/models/base_module.py | 2 +- chatlearn/models/vllm_module_v2.py | 1 + chatlearn/utils/dist_utils.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index 1ff98e1d..b2040862 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -790,7 +790,7 @@ def tensor_generator(): coalesced_comm_dense(bucket_or_tensor, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) dense_bucket_num += 1 else: - col.broadcast(param, src_rank, group_name) + col.broadcast(bucket_or_tensor, src_rank, group_name) sparse_bucket_num += 1 debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, spase_bucket {sparse_bucket_num}", self._logger) diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index 2576b306..67c419ca 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -48,6 +48,7 @@ def __init__(self, *args, **kwargs): if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called os.environ['VLLM_HOST_IP'] = self.get_address() + self.tokenizer = None self._model = None diff --git a/chatlearn/utils/dist_utils.py b/chatlearn/utils/dist_utils.py index bfaad2d0..4c14cca7 100644 --- a/chatlearn/utils/dist_utils.py +++ b/chatlearn/utils/dist_utils.py @@ -33,7 +33,7 @@ def bucket_tensor_generator(tensor_generator, bucket_size_mb): """ size_limit = bucket_size_mb * 1024 * 1024 buf_dict = defaultdict(lambda: [[], 0]) - for tensor in tensors: + for tensor in tensor_generator: if tensor.is_sparse: yield tensor, False continue From 7c75953b8d1df6ab5ff6ba7856501cbd6b045484 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 27 Dec 2024 15:26:23 +0800 Subject: [PATCH 05/13] fix comment --- chatlearn/utils/dist_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
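The small docstring fix below (Return vs. Yield) reflects the larger change from patch 03: bucket_tensors became bucket_tensor_generator, which yields (bucket_or_tensor, is_dense) pairs as buckets fill up instead of returning two complete lists, so only the current bucket of parameters needs to be resident at a time. A minimal standalone sketch of that consumption pattern, using illustrative integer byte sizes instead of torch tensors (sparse handling and per-dtype grouping omitted):

from typing import Iterator, List, Tuple

def bucket_generator(sizes: List[int], size_limit: int) -> Iterator[Tuple[List[int], bool]]:
    # Group incoming sizes into dense buckets no larger than size_limit bytes,
    # yielding each bucket as soon as it is full instead of collecting them all.
    bucket, total = [], 0
    for size in sizes:
        if total + size > size_limit and bucket:
            yield bucket, True
            bucket, total = [], 0
        bucket.append(size)
        total += size
    if bucket:
        yield bucket, True

for chunk, is_dense in bucket_generator([4, 4, 8, 2], size_limit=8):
    print(chunk, is_dense)   # [4, 4] True  ->  [8] True  ->  [2] True
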
diff --git a/chatlearn/utils/dist_utils.py b/chatlearn/utils/dist_utils.py index 4c14cca7..c3e65dc4 100644 --- a/chatlearn/utils/dist_utils.py +++ b/chatlearn/utils/dist_utils.py @@ -27,7 +27,7 @@ def bucket_tensor_generator(tensor_generator, bucket_size_mb): tensor_generator (Generator): A generator of tensors to be separated into chunks. size_limit (int): The limit of each chunk in bytes. - Return: + Yield: dense_buckets: Blocks of tensors of same type and within size_limit. sparse_bucket: A list of sparse tensors """ From 44b55445149f61d140daedbce6b0469128468dd4 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 27 Dec 2024 23:14:43 +0800 Subject: [PATCH 06/13] fix generator call --- chatlearn/utils/dist_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatlearn/utils/dist_utils.py b/chatlearn/utils/dist_utils.py index c3e65dc4..e5976a3d 100644 --- a/chatlearn/utils/dist_utils.py +++ b/chatlearn/utils/dist_utils.py @@ -33,7 +33,7 @@ def bucket_tensor_generator(tensor_generator, bucket_size_mb): """ size_limit = bucket_size_mb * 1024 * 1024 buf_dict = defaultdict(lambda: [[], 0]) - for tensor in tensor_generator: + for tensor in tensor_generator(): if tensor.is_sparse: yield tensor, False continue From 3c7dbeaf7b0fa13334939f6283045fbd18e1a613 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Sat, 28 Dec 2024 17:48:09 +0800 Subject: [PATCH 07/13] fix ut --- chatlearn/utils/dist_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/chatlearn/utils/dist_utils.py b/chatlearn/utils/dist_utils.py index e5976a3d..507b27d6 100644 --- a/chatlearn/utils/dist_utils.py +++ b/chatlearn/utils/dist_utils.py @@ -125,7 +125,6 @@ def coalesced_comm_dense(bucket, comm_call, extra_args, tensor_changed=True): coalesced communication for dense parameters """ flat_tensors = _flatten_dense_tensors(bucket) - del bucket comm_call(flat_tensors, *extra_args) if tensor_changed: all_buffers = _unflatten_dense_tensors(flat_tensors, bucket) From a9e4cd6d74e8d9dc0b0e6fa81279c44872677ec7 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 3 Jan 2025 13:40:37 +0800 Subject: [PATCH 08/13] add level option for users --- chatlearn/models/base_module.py | 70 ++++++++++++++++++++--------- chatlearn/runtime/engine.py | 3 ++ chatlearn/schedule/model_manager.py | 7 +++ chatlearn/utils/arguments.py | 2 + chatlearn/utils/dist_utils.py | 31 +++++++++++++ 5 files changed, 91 insertions(+), 22 deletions(-) diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index b2040862..a09afd8a 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -25,11 +25,12 @@ from ray.util.collective.collective_group.nccl_collective_group import NCCLGroup from torch.utils.data import DataLoader from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.cuda import max_memory_allocated from chatlearn.data.sampler import SingleDataSampler, EpisodeDataSampler from chatlearn.checkpoint.checkpoint_manager import CheckpointManager from chatlearn.utils import future -from chatlearn.utils.dist_utils import bucket_tensor_generator, coalesced_comm_dense +from chatlearn.utils.dist_utils import bucket_tensors, bucket_tensor_generator, coalesced_comm_dense from chatlearn.utils.dist_utils import bucket_tensors_two_stage_generator, coalesced_comm_dense_two_stage from chatlearn.utils.global_vars import get_args from chatlearn.utils.global_vars import set_global_variables @@ -770,31 +771,56 @@ def broadcast_parameter(self, rank, src_rank, group_name, 
pipe_stage=0): """ :meta private: """ - def tensor_generator(): + if self.runtime_args.sync_memory_optimization_level == 0: + log_rank_0(">>>>>>>>>>>>>>>>optimization level 0") + tensors = [] for name, param in self._parameters_to_sync[pipe_stage]: - if self._expert_sync_buffer and name in self._expert_sync_buffer: - yield self._expert_sync_buffer[name] - # move self._expert_sync_buffer[name] to cpu mem to save gpu mem - cpu_expert = self._expert_sync_buffer[name].cpu() - del self._expert_sync_buffer[name] - self._expert_sync_buffer[name] = cpu_expert + if self._expert_sync_buffer and name in self._expert_sync_buffer and \ + (self._synchronizer and self._synchronizer.is_parameter_changed): + tensors.append(self._expert_sync_buffer[name]) else: - yield param.data + tensors.append(param.data) - bucket_generator = bucket_tensor_generator(tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb) - dense_bucket_num = 0 - sparse_bucket_num = 0 - tensor_changed = rank != src_rank - for bucket_or_tensor, is_dense in bucket_generator: - if is_dense: - coalesced_comm_dense(bucket_or_tensor, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) - dense_bucket_num += 1 - else: - col.broadcast(bucket_or_tensor, src_rank, group_name) - sparse_bucket_num += 1 + assert len(tensors) > 0 + dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb) + debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger) + tensor_changed = rank != src_rank - debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, spase_bucket {sparse_bucket_num}", self._logger) - self.empty_cache() + for bucket in dense_buckets: + coalesced_comm_dense(bucket, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) + + for param in sparse_bucket: + col.broadcast(param, src_rank, group_name) + log_rank_0(f"memory footprint peak: {max_memory_allocated() / 1024 ** 3}") + self.empty_cache() + else: + log_rank_0(">>>>>>>>>>>>>>optimization level 1") + def tensor_generator(): + for name, param in self._parameters_to_sync[pipe_stage]: + if self._expert_sync_buffer and name in self._expert_sync_buffer: + yield self._expert_sync_buffer[name] + # move self._expert_sync_buffer[name] to cpu mem to save gpu mem + cpu_expert = self._expert_sync_buffer[name].cpu() + del self._expert_sync_buffer[name] + self._expert_sync_buffer[name] = cpu_expert + else: + yield param.data + + bucket_generator = bucket_tensor_generator(tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb) + dense_bucket_num = 0 + sparse_bucket_num = 0 + tensor_changed = rank != src_rank + for bucket_or_tensor, is_dense in bucket_generator: + if is_dense: + coalesced_comm_dense(bucket_or_tensor, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) + dense_bucket_num += 1 + else: + col.broadcast(bucket_or_tensor, src_rank, group_name) + sparse_bucket_num += 1 + + debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, spase_bucket {sparse_bucket_num}", self._logger) + log_rank_0(f"memory footprint peak: {max_memory_allocated() / 1024 ** 3}") + self.empty_cache() def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): """ diff --git a/chatlearn/runtime/engine.py b/chatlearn/runtime/engine.py index 2a600702..bf6fdf95 100644 --- a/chatlearn/runtime/engine.py +++ b/chatlearn/runtime/engine.py @@ 
-293,7 +293,10 @@ def learn(self): self.runtime_args.max_relay_episode, self.runtime_args.relay_episode_offset) logger.info(f"{LOG_START} " + get_full_proc_memory_info('Before first param sync')) + self.timers("sync_parameters").start() self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync) + self.timers("sync_parameters").stop() + logger.info(f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])}") logger.info(f"{LOG_START} " + get_full_proc_memory_info('After first param sync')) self._data_loader = data_loader for episode_id in range(self._start_episode, self.runtime_args.num_episode): diff --git a/chatlearn/schedule/model_manager.py b/chatlearn/schedule/model_manager.py index 313d9bf9..707ca394 100644 --- a/chatlearn/schedule/model_manager.py +++ b/chatlearn/schedule/model_manager.py @@ -17,6 +17,7 @@ import concurrent.futures from collections import defaultdict from concurrent.futures import ThreadPoolExecutor +import time import ray import ray.experimental.state.api @@ -174,18 +175,24 @@ def sync_parameters(self, episode_offset=0, requires_grad=None, validate=False): episode_offset % sync_group.frequency == 0: sync_group: ParameterSyncGroup = sync_group + start = time.perf_counter() src_model, dst_model = sync_group.src_model, sync_group.dst_model refs = src_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False) future.wait(refs) refs = dst_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False) future.wait(refs) + logger.info(f"============In sync_parameters, onload {sync_group} elapsed {time.perf_counter() - start} s") + start = time.perf_counter() sync_group.sync(requires_grad, validate) + logger.info(f"============In sync_parameters, synchronizing {sync_group} elapsed {time.perf_counter() - start} s") + start = time.perf_counter() refs = src_model.offload() future.wait(refs) refs = dst_model.offload() future.wait(refs) + logger.info(f"============In sync_parameters, offload {sync_group} elapsed {time.perf_counter() - start} s") def set_func_decorator(self, model): if is_decorated(model.name): diff --git a/chatlearn/utils/arguments.py b/chatlearn/utils/arguments.py index c7d89e50..277f31e4 100644 --- a/chatlearn/utils/arguments.py +++ b/chatlearn/utils/arguments.py @@ -312,6 +312,8 @@ class RuntimeConfig(BaseConfig): param_sync_max_workers: int = None #: communication type to regroup routed experts, allgather/alltoall routed_expert_regrouping_comm_type: str = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL + #: memory optimization level in synchronization to decide whether save gpu memory or persue faster execution runtime, 0/1 + sync_memory_optimization_level: int = 0 #: max number of relay episodes, if `max_relay_episode` is set to -1, then relay all episodes #: if `max_relay_episode` is set to 0, then relay is disabled max_relay_episode: int = 0 diff --git a/chatlearn/utils/dist_utils.py b/chatlearn/utils/dist_utils.py index 507b27d6..1a79f9fb 100644 --- a/chatlearn/utils/dist_utils.py +++ b/chatlearn/utils/dist_utils.py @@ -18,6 +18,37 @@ import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +def bucket_tensors(tensors, bucket_size_mb): + """Group tensors into chunks. We seperate sparse and dense tensor, + each containing tensors of same type up to certain byte limit in total size. + Args: + tensors (Sequence): A sequence of tensors to be separated into chunks. 
+ size_limit (int): The limit of each chunk in bytes. + Return: + dense_buckets: Blocks of tensors of same type and within size_limit. + sparse_bucket: A list of sparse tensors + """ + size_limit = bucket_size_mb * 1024 * 1024 + buf_dict = defaultdict(lambda: [[], 0]) + dense_buckets = [] + sparse_bucket = [] + for tensor in tensors: + if tensor.is_sparse: + sparse_bucket.append(tensor) + continue + t = tensor.type() + size = tensor.numel() * tensor.element_size() + buf_and_size = buf_dict[t] + if size_limit > 0 and buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: # pylint: disable=chained-comparison + dense_buckets.append(buf_and_size[0]) + buf_and_size = buf_dict[t] = [[], 0] + buf_and_size[0].append(tensor) + buf_and_size[1] += size + for buf, _ in buf_dict.values(): + if len(buf) > 0: + dense_buckets.append(buf) + return dense_buckets, sparse_bucket + def bucket_tensor_generator(tensor_generator, bucket_size_mb): """Group tensors into chunks. We seperate sparse and dense tensor, From fb98d4d4c7e695222be4dfa814cbfe7e56c6721d Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 3 Jan 2025 08:18:11 +0000 Subject: [PATCH 09/13] add option for sync memory opt level --- chatlearn/models/base_module.py | 209 ++++++++++++++++++++++---------- chatlearn/utils/arguments.py | 3 +- chatlearn/utils/dist_utils.py | 37 ++++++ 3 files changed, 182 insertions(+), 67 deletions(-) diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index a09afd8a..af594cc1 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -31,7 +31,7 @@ from chatlearn.checkpoint.checkpoint_manager import CheckpointManager from chatlearn.utils import future from chatlearn.utils.dist_utils import bucket_tensors, bucket_tensor_generator, coalesced_comm_dense -from chatlearn.utils.dist_utils import bucket_tensors_two_stage_generator, coalesced_comm_dense_two_stage +from chatlearn.utils.dist_utils import bucket_tensors_two_stage, bucket_tensors_two_stage_generator, coalesced_comm_dense_two_stage from chatlearn.utils.global_vars import get_args from chatlearn.utils.global_vars import set_global_variables from chatlearn.utils.logger import log_rank_0, debug_rank_0, setup_logger @@ -767,86 +767,133 @@ def allgather_routed_expert_parameter(self, group_name, pipe_stage=0): self._expert_sync_buffer.pop(name, "Not Found.") self._expert_sync_buffer[name] = param + def _broadcast_parameter_opt_level_0(self, rank, src_rank, group_name, pipe_stage=0): + debug_rank_0(">>>>>>>>>>>>>>>>broadcast parameter at memory optimization level 0") + tensors = [] + for name, param in self._parameters_to_sync[pipe_stage]: + if self._expert_sync_buffer and name in self._expert_sync_buffer and \ + (self._synchronizer and self._synchronizer.is_parameter_changed): + tensors.append(self._expert_sync_buffer[name]) + else: + tensors.append(param.data) + + assert len(tensors) > 0 + dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb) + debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger) + tensor_changed = rank != src_rank + + for bucket in dense_buckets: + coalesced_comm_dense(bucket, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) + + for param in sparse_bucket: + col.broadcast(param, src_rank, group_name) + self.empty_cache() + + def _broadcast_parameter_opt_level_1(self, rank, src_rank, group_name, pipe_stage=0): + 
debug_rank_0(">>>>>>>>>>>>>>broadcast parameter at memory optimization level 1") + def tensor_generator(): + for name, param in self._parameters_to_sync[pipe_stage]: + if self._expert_sync_buffer and name in self._expert_sync_buffer: + yield self._expert_sync_buffer[name] + # move self._expert_sync_buffer[name] to cpu mem to save gpu mem + cpu_expert = self._expert_sync_buffer[name].cpu() + del self._expert_sync_buffer[name] + self._expert_sync_buffer[name] = cpu_expert + else: + yield param.data + + bucket_generator = bucket_tensor_generator(tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb) + dense_bucket_num = 0 + sparse_bucket_num = 0 + tensor_changed = rank != src_rank + for bucket_or_tensor, is_dense in bucket_generator: + if is_dense: + coalesced_comm_dense(bucket_or_tensor, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) + dense_bucket_num += 1 + else: + col.broadcast(bucket_or_tensor, src_rank, group_name) + sparse_bucket_num += 1 + + debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, spase_bucket {sparse_bucket_num}", self._logger) + self.empty_cache() + def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0): """ :meta private: """ if self.runtime_args.sync_memory_optimization_level == 0: - log_rank_0(">>>>>>>>>>>>>>>>optimization level 0") - tensors = [] - for name, param in self._parameters_to_sync[pipe_stage]: - if self._expert_sync_buffer and name in self._expert_sync_buffer and \ - (self._synchronizer and self._synchronizer.is_parameter_changed): - tensors.append(self._expert_sync_buffer[name]) - else: - tensors.append(param.data) + self._broadcast_parameter_opt_level_0(rank, src_rank, group_name, pipe_stage) + else: + self._broadcast_parameter_opt_level_1(rank, src_rank, group_name, pipe_stage) - assert len(tensors) > 0 - dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb) - debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger) - tensor_changed = rank != src_rank + def _broadcast_parameter_two_stage_opt_level_0(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): + debug_rank_0(">>>>>>>>>>>>>>broadcast parameter at memory optimization level 0") + tensor_changed = rank != src_rank - for bucket in dense_buckets: - coalesced_comm_dense(bucket, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) + if stage2: + if tensor_changed: + parameters_to_sync = self._parameters_to_recv[to_rank] + else: + parameters_to_sync = self._parameters_to_send + else: + del self._sync_buffer + self._sync_buffer = defaultdict(list) + parameters_to_sync = self._parameters_to_sync - for param in sparse_bucket: - col.broadcast(param, src_rank, group_name) - log_rank_0(f"memory footprint peak: {max_memory_allocated() / 1024 ** 3}") - self.empty_cache() + tensors = [] + buffer_num = [] + if stage2 and not tensor_changed and self._sync_buffer:# pylint: disable=too-many-nested-blocks + idx = 0 + for name, param in parameters_to_sync[pipe_stage]: + self._logger.debug( + f"Adding {name} to sync for if branch from " + f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}" + ) + tensors.append(self._sync_buffer[buffer_rank % self.tp_num_mapping][idx]) + buffer_num.append(1) + idx += 1 + del self._sync_buffer[buffer_rank % self.tp_num_mapping] else: - log_rank_0(">>>>>>>>>>>>>>optimization level 1") - def tensor_generator(): - for name, 
param in self._parameters_to_sync[pipe_stage]: - if self._expert_sync_buffer and name in self._expert_sync_buffer: - yield self._expert_sync_buffer[name] - # move self._expert_sync_buffer[name] to cpu mem to save gpu mem - cpu_expert = self._expert_sync_buffer[name].cpu() - del self._expert_sync_buffer[name] - self._expert_sync_buffer[name] = cpu_expert - else: - yield param.data - - bucket_generator = bucket_tensor_generator(tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb) - dense_bucket_num = 0 - sparse_bucket_num = 0 - tensor_changed = rank != src_rank - for bucket_or_tensor, is_dense in bucket_generator: - if is_dense: - coalesced_comm_dense(bucket_or_tensor, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed) - dense_bucket_num += 1 + for name, param in parameters_to_sync[pipe_stage]: + self._logger.debug( + f"Adding {name} to sync for else branch from " + f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}" + ) + param_data = param.data + if rank and self._buffer_num and not stage2: + assert name in self._buffer_num, f"{name} in self._buffer_num for rank {rank}" + buffer_num.append(self._buffer_num[name]) + elif stage2: + buffer_num.append(1) else: - col.broadcast(bucket_or_tensor, src_rank, group_name) - sparse_bucket_num += 1 + # regroup src_tensor by tp_rank. + param_data = self._synchronizer.regroup_params_to_sync(name, param_data, self._tp_division[name]) + buffer_num.append(1) + tensors.append(param_data) - debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, spase_bucket {sparse_bucket_num}", self._logger) - log_rank_0(f"memory footprint peak: {max_memory_allocated() / 1024 ** 3}") - self.empty_cache() + assert len(tensors) > 0 + dense_buckets, sparse_bucket = bucket_tensors_two_stage( + tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb, + buffer_num=None if stage2 else buffer_num, tensor_changed=tensor_changed and not stage2) + debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, sparse_bucket {len(sparse_bucket)}", self._logger) - def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): - """ - Arguments: - to_rank: receive rank in mapping from trainer to inference model. - buffer_rank: index which tensors of sync buffer to be sended in stage2. - rank: destination rank in communication group which enumerate receive ranks. - src_rank: source rank in communication group. always 0. - group_name: communication group name. - pipe_stage: pipeline stage. default 0. - stage2: bool. whether stage2 or not. default False. - Example: trainer_tp = 4, inference_tp = 8. pipeline_size = 1 - stage1: [(from_rank, to_rank), ...] = [(0, 8), (1, 10), (2, 12), (3, 14)] - stage2: [(from_rank, to_rank), ...] = [(8, 9), (10, 11), (12, 13), (14, 15)] + for bucket in dense_buckets: + index = 0 if stage2 else (to_rank % self.tp_num_mapping) + all_buffers = coalesced_comm_dense_two_stage( + bucket, col.broadcast, rank, + extra_args=(src_rank, group_name), tensor_changed=tensor_changed, + stage2=stage2, index=index) + if tensor_changed and not stage2: + for key, value in all_buffers.items(): + self._sync_buffer[key] += value - For stage1 pair (0, 8): - 1. call broadcast func: (0 -> 0). src_rank: 0, rank: 0. - 2. call broadcast func: (0 -> 8). src_rank: 0, rank: 1. + for param in sparse_bucket: + col.broadcast(param, src_rank, group_name) - After (0, 8), to_rank 8 received tensor slices of 8 and 9. 
+ self.empty_cache() - For stage2 pair (8, 9): - 1. call broadcast func: (8 -> 8). src_rank: 0, rank: 0. - 2. call broadcast func: (8 -> 9). src_rank: 0, rank: 1. - In (8 -> 8), we need to send tp_slice of 'to_rank' 9, so set buffer_rank 9 to fetch tensors in sync buffer. - """ + def _broadcast_parameter_two_stage_opt_level_1(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): + debug_rank_0(">>>>>>>>>>>>>>broadcast parameter at memory optimization level 1") tensor_changed = rank != src_rank if stage2: @@ -937,6 +984,36 @@ def tensor_generator(): self.empty_cache() + def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): + """ + Arguments: + to_rank: receive rank in mapping from trainer to inference model. + buffer_rank: index which tensors of sync buffer to be sended in stage2. + rank: destination rank in communication group which enumerate receive ranks. + src_rank: source rank in communication group. always 0. + group_name: communication group name. + pipe_stage: pipeline stage. default 0. + stage2: bool. whether stage2 or not. default False. + Example: trainer_tp = 4, inference_tp = 8. pipeline_size = 1 + stage1: [(from_rank, to_rank), ...] = [(0, 8), (1, 10), (2, 12), (3, 14)] + stage2: [(from_rank, to_rank), ...] = [(8, 9), (10, 11), (12, 13), (14, 15)] + + For stage1 pair (0, 8): + 1. call broadcast func: (0 -> 0). src_rank: 0, rank: 0. + 2. call broadcast func: (0 -> 8). src_rank: 0, rank: 1. + + After (0, 8), to_rank 8 received tensor slices of 8 and 9. + + For stage2 pair (8, 9): + 1. call broadcast func: (8 -> 8). src_rank: 0, rank: 0. + 2. call broadcast func: (8 -> 9). src_rank: 0, rank: 1. + In (8 -> 8), we need to send tp_slice of 'to_rank' 9, so set buffer_rank 9 to fetch tensors in sync buffer. 
+ """ + if self.runtime_args.sync_memory_optimization_level == 0: + self._broadcast_parameter_two_stage_opt_level_0(to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage, stage2) + else: + self._broadcast_parameter_two_stage_opt_level_1(to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage, stage2) + def send_parameter(self, dst_rank, group_name, pipe_stage=0): """ :meta private: diff --git a/chatlearn/utils/arguments.py b/chatlearn/utils/arguments.py index 277f31e4..4e5fe5bb 100644 --- a/chatlearn/utils/arguments.py +++ b/chatlearn/utils/arguments.py @@ -313,7 +313,7 @@ class RuntimeConfig(BaseConfig): #: communication type to regroup routed experts, allgather/alltoall routed_expert_regrouping_comm_type: str = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL #: memory optimization level in synchronization to decide whether save gpu memory or persue faster execution runtime, 0/1 - sync_memory_optimization_level: int = 0 + sync_memory_optimization_level: int = 1 #: max number of relay episodes, if `max_relay_episode` is set to -1, then relay all episodes #: if `max_relay_episode` is set to 0, then relay is disabled max_relay_episode: int = 0 @@ -512,6 +512,7 @@ def _validate_params(self): assert self.runtime_args.stream_data_loader_type.lower() in ["fixed", "dynamic"] assert self.runtime_args.cpu_schedule_strategy in [strategy.value for strategy in RAY_PG_STRATEGY] assert self.runtime_args.param_sync_comm_type in list(PARAM_SYNC_COMM_TYPE) + assert self.runtime_args.sync_memory_optimization_level in [0, 1] for model_name, model_args in self.models.items(): if model_args.num_gpu >= 1: if model_args.gpu_per_process is None: diff --git a/chatlearn/utils/dist_utils.py b/chatlearn/utils/dist_utils.py index 1a79f9fb..ed851288 100644 --- a/chatlearn/utils/dist_utils.py +++ b/chatlearn/utils/dist_utils.py @@ -81,6 +81,43 @@ def bucket_tensor_generator(tensor_generator, bucket_size_mb): yield buf, True +def bucket_tensors_two_stage(tensors, bucket_size_mb, buffer_num=None, tensor_changed=False): + """Group tensors into chunks. We seperate sparse and dense tensor, + each containing tensors of same type up to certain byte limit in total size. + Args: + tensors (Sequence): A sequence of tensors to be separated into chunks. + size_limit (int): The limit of each chunk in bytes. + Return: + dense_buckets: Blocks of tensors of same type and within size_limit. + sparse_bucket: A list of sparse tensors + """ + size_limit = bucket_size_mb * 1024 * 1024 + buf_dict = defaultdict(lambda: [[], 0]) + dense_buckets = [] + sparse_bucket = [] + for idx, tensor in enumerate(tensors): + buffer_multiple = 1 if buffer_num is None else buffer_num[idx] + if tensor.is_sparse: + sparse_bucket.append(tensor) + continue + t = tensor.type() + # expand buffer size of dst ranks which recv tensor from trainer. 
+ size = tensor.numel() * tensor.element_size() * buffer_multiple + buf_and_size = buf_dict[t] + if size_limit > 0 and buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: # pylint: disable=chained-comparison + dense_buckets.append(buf_and_size[0]) + buf_and_size = buf_dict[t] = [[], 0] + buf_and_size[0].append((torch.empty(size=[tensor.numel() * buffer_multiple], + dtype=tensor.dtype, + device=tensor.device) if (tensor_changed and buffer_multiple > 1) else tensor, + [size // tensor.element_size(), buffer_multiple, tensor])) + buf_and_size[1] += size + for buf, size in buf_dict.values(): + if len(buf) > 0: + dense_buckets.append(buf) + return dense_buckets, sparse_bucket + + def bucket_tensors_two_stage_generator(tensor_generator, bucket_size_mb, stage2=False, tensor_changed=False): """Group tensors into chunks. We seperate sparse and dense tensor, each containing tensors of same type up to certain byte limit in total size. From 051db768a4ff7f4d2d4c0da1ba530e040d307f7f Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 3 Jan 2025 08:20:16 +0000 Subject: [PATCH 10/13] fix pylint --- chatlearn/models/base_module.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index af594cc1..ed21b5d9 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -25,7 +25,6 @@ from ray.util.collective.collective_group.nccl_collective_group import NCCLGroup from torch.utils.data import DataLoader from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors -from torch.cuda import max_memory_allocated from chatlearn.data.sampler import SingleDataSampler, EpisodeDataSampler from chatlearn.checkpoint.checkpoint_manager import CheckpointManager @@ -826,7 +825,7 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0): else: self._broadcast_parameter_opt_level_1(rank, src_rank, group_name, pipe_stage) - def _broadcast_parameter_two_stage_opt_level_0(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): + def _broadcast_parameter_two_stage_opt_level_0(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): debug_rank_0(">>>>>>>>>>>>>>broadcast parameter at memory optimization level 0") tensor_changed = rank != src_rank @@ -892,7 +891,7 @@ def _broadcast_parameter_two_stage_opt_level_0(self, to_rank, buffer_rank, rank, self.empty_cache() - def _broadcast_parameter_two_stage_opt_level_1(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): + def _broadcast_parameter_two_stage_opt_level_1(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False): debug_rank_0(">>>>>>>>>>>>>>broadcast parameter at memory optimization level 1") tensor_changed = rank != src_rank From 6c8ebeb8dacd480a7dc742ba7ed1c2a71afa7bbd Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 3 Jan 2025 08:21:43 +0000 Subject: [PATCH 11/13] rm unnecessary timer msg --- chatlearn/schedule/model_manager.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/chatlearn/schedule/model_manager.py b/chatlearn/schedule/model_manager.py index 707ca394..313d9bf9 100644 --- a/chatlearn/schedule/model_manager.py +++ b/chatlearn/schedule/model_manager.py @@ -17,7 +17,6 @@ import concurrent.futures from collections import defaultdict from concurrent.futures import ThreadPoolExecutor -import time import ray import ray.experimental.state.api @@ -175,24 +174,18 @@ def sync_parameters(self, 
episode_offset=0, requires_grad=None, validate=False): episode_offset % sync_group.frequency == 0: sync_group: ParameterSyncGroup = sync_group - start = time.perf_counter() src_model, dst_model = sync_group.src_model, sync_group.dst_model refs = src_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False) future.wait(refs) refs = dst_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False) future.wait(refs) - logger.info(f"============In sync_parameters, onload {sync_group} elapsed {time.perf_counter() - start} s") - start = time.perf_counter() sync_group.sync(requires_grad, validate) - logger.info(f"============In sync_parameters, synchronizing {sync_group} elapsed {time.perf_counter() - start} s") - start = time.perf_counter() refs = src_model.offload() future.wait(refs) refs = dst_model.offload() future.wait(refs) - logger.info(f"============In sync_parameters, offload {sync_group} elapsed {time.perf_counter() - start} s") def set_func_decorator(self, model): if is_decorated(model.name): From 4526338549774728af40c9b07bab27a3461c3bff Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Fri, 3 Jan 2025 17:06:23 +0800 Subject: [PATCH 12/13] fix empty lines --- chatlearn/utils/dist_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/chatlearn/utils/dist_utils.py b/chatlearn/utils/dist_utils.py index ed851288..d78972cd 100644 --- a/chatlearn/utils/dist_utils.py +++ b/chatlearn/utils/dist_utils.py @@ -18,12 +18,15 @@ import torch from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + def bucket_tensors(tensors, bucket_size_mb): """Group tensors into chunks. We seperate sparse and dense tensor, each containing tensors of same type up to certain byte limit in total size. + Args: tensors (Sequence): A sequence of tensors to be separated into chunks. size_limit (int): The limit of each chunk in bytes. + Return: dense_buckets: Blocks of tensors of same type and within size_limit. 
sparse_bucket: A list of sparse tensors From d8613520a23d40fed7abd449d5da770e87d0e728 Mon Sep 17 00:00:00 2001 From: Hao Lin Date: Mon, 6 Jan 2025 07:54:56 +0000 Subject: [PATCH 13/13] Add sync_memory_optimization_level for yamls --- chatlearn/runtime/engine.py | 4 ++-- examples/megatron/configs/llama2/grpo_math_vllm.yaml | 1 + examples/megatron/configs/llama2/online_dpo_vllm.yaml | 1 + examples/megatron/configs/llama2/vllm_param_sync.yaml | 1 + examples/megatron/configs/llama2/vllm_rlhf.yaml | 1 + 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/chatlearn/runtime/engine.py b/chatlearn/runtime/engine.py index 330c9c7f..3195d03b 100644 --- a/chatlearn/runtime/engine.py +++ b/chatlearn/runtime/engine.py @@ -296,9 +296,9 @@ def learn(self): self.timers("sync_parameters").start() self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync) self.timers("sync_parameters").stop() + logger.info(f"{LOG_START} " + get_full_proc_memory_info('After first param sync')) logger.info( - f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} " \ - + get_full_proc_memory_info('After first param sync') + f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} " ) self._data_loader = data_loader for episode_id in range(self._start_episode, self.runtime_args.num_episode): diff --git a/examples/megatron/configs/llama2/grpo_math_vllm.yaml b/examples/megatron/configs/llama2/grpo_math_vllm.yaml index 64f54bfd..2a7b933b 100644 --- a/examples/megatron/configs/llama2/grpo_math_vllm.yaml +++ b/examples/megatron/configs/llama2/grpo_math_vllm.yaml @@ -66,3 +66,4 @@ runtime: max_relay_episode: 1 exp_name: ${exp_name:chatlearn} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/online_dpo_vllm.yaml b/examples/megatron/configs/llama2/online_dpo_vllm.yaml index e64f78ee..f274ee6b 100644 --- a/examples/megatron/configs/llama2/online_dpo_vllm.yaml +++ b/examples/megatron/configs/llama2/online_dpo_vllm.yaml @@ -60,3 +60,4 @@ runtime: output_dir: ${output_dir} exp_name: ${exp_name:chatlearn} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/vllm_param_sync.yaml b/examples/megatron/configs/llama2/vllm_param_sync.yaml index 9177fe87..4ab36bba 100644 --- a/examples/megatron/configs/llama2/vllm_param_sync.yaml +++ b/examples/megatron/configs/llama2/vllm_param_sync.yaml @@ -49,3 +49,4 @@ runtime: exp_name: ${exp_name:chatlearn} debug: ${debug:False} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1} diff --git a/examples/megatron/configs/llama2/vllm_rlhf.yaml b/examples/megatron/configs/llama2/vllm_rlhf.yaml index b57602b3..34253461 100644 --- a/examples/megatron/configs/llama2/vllm_rlhf.yaml +++ b/examples/megatron/configs/llama2/vllm_rlhf.yaml @@ -82,3 +82,4 @@ runtime: exp_name: ${exp_name:chatlearn} debug: ${debug:False} validate_param_sync: ${validate_param_sync:False} + sync_memory_optimization_level: ${sync_memory_optimization_level:1}
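
Illustrative usage note (not part of the patch series). The sketch below assumes the
bucket_tensors_two_stage() helper added to chatlearn/utils/dist_utils.py above and an
importable chatlearn package; the tensor shapes, bucket size, and buffer_num values are
only examples.

    import torch
    from chatlearn.utils.dist_utils import bucket_tensors_two_stage

    # Three dense fp16 tensors. buffer_num gives the second tensor a 2x buffer,
    # e.g. a dst rank that receives that tensor from two trainer ranks.
    tensors = [torch.zeros(1024, dtype=torch.float16) for _ in range(3)]
    dense_buckets, sparse_bucket = bucket_tensors_two_stage(
        tensors,
        bucket_size_mb=128,      # each dense chunk is capped at 128 MB in total
        buffer_num=[1, 2, 1],    # per-tensor buffer multiples
        tensor_changed=True,     # allocate a fresh recv buffer when the multiple > 1
    )
    # Each dense-bucket entry is (buffer_tensor, [expanded_numel, multiple, original_tensor]).
    for bucket in dense_buckets:
        for buf, (numel, multiple, orig) in bucket:
            print(buf.shape, numel, multiple, orig.shape)
    assert not sparse_bucket  # no sparse tensors in this example

Per the RuntimeConfig comment in arguments.py, sync_memory_optimization_level=0 favors
saving GPU memory (_broadcast_parameter_two_stage_opt_level_0), while the new default of 1
favors faster execution (_broadcast_parameter_two_stage_opt_level_1); the yaml entries above
set it through the configs' usual ${sync_memory_optimization_level:1} substitution, so it
defaults to 1 and can be overridden per experiment.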