From 6cb15568fd922d21d1433a21a9f3a0e41f21d5d0 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Wed, 22 Jan 2025 15:27:01 +0800 Subject: [PATCH 01/19] upgrade to vllm0.6.6 --- chatlearn/models/vllm/hooks/__init__.py | 1 + .../models/vllm/hooks/async_llm_engine.py | 51 +-- .../models/vllm/hooks/input_preprocess.py | 56 ++- chatlearn/models/vllm/hooks/llm.py | 140 +++--- chatlearn/models/vllm/hooks/llm_engine.py | 29 ++ chatlearn/models/vllm/hooks/loader.py | 72 ++-- .../models/vllm/hooks/ray_gpu_executor.py | 405 ++++++++++-------- chatlearn/models/vllm/hooks/worker_base.py | 1 + chatlearn/models/vllm_module_v2.py | 156 ++++++- chatlearn/runtime/dist_actor.py | 11 +- chatlearn/utils/constant.py | 3 +- chatlearn/utils/vllm_import_helper.py | 5 +- .../megatron/models/vllm_policy_inference.py | 1 + 13 files changed, 608 insertions(+), 323 deletions(-) diff --git a/chatlearn/models/vllm/hooks/__init__.py b/chatlearn/models/vllm/hooks/__init__.py index 08856241..40e156eb 100644 --- a/chatlearn/models/vllm/hooks/__init__.py +++ b/chatlearn/models/vllm/hooks/__init__.py @@ -27,6 +27,7 @@ from chatlearn.models.vllm.hooks import input_preprocess from chatlearn.models.vllm.hooks import async_llm_engine from chatlearn.models.vllm.hooks import llm + from chatlearn.models.vllm.hooks import llm_engine from chatlearn.models.vllm.hooks import loader from chatlearn.models.vllm.hooks import worker_base else: diff --git a/chatlearn/models/vllm/hooks/async_llm_engine.py b/chatlearn/models/vllm/hooks/async_llm_engine.py index 45428241..c02699e9 100644 --- a/chatlearn/models/vllm/hooks/async_llm_engine.py +++ b/chatlearn/models/vllm/hooks/async_llm_engine.py @@ -17,7 +17,7 @@ from typing import Dict, Optional # pylint: disable=unused-import,wildcard-import,unused-argument,not-callable -from vllm.config import EngineConfig +from vllm.config import VllmConfig from vllm.engine import async_llm_engine from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.metrics_types import StatLoggerBase @@ -25,30 +25,33 @@ @classmethod def from_engine_args( - cls, - engine_args: AsyncEngineArgs, - engine_config: Optional[EngineConfig] = None, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, -) -> "AsyncLLMEngine": - """Creates an async LLM engine from the engine arguments.""" - # Create the engine configs. - if engine_config is None: - engine_config = engine_args.create_engine_config() + cls, + engine_args: AsyncEngineArgs, + engine_config: Optional[VllmConfig] = None, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. + if engine_config is None: + engine_config = engine_args.create_engine_config(usage_context) - executor_class = cls._get_executor_cls(engine_config) + executor_class = cls._get_executor_cls(engine_config) - # Create the async LLM engine. - engine = cls( - **engine_config.to_dict(), - executor_class=executor_class, - log_requests=not engine_args.disable_log_requests, - log_stats=not engine_args.disable_log_stats, - start_engine_loop=start_engine_loop, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - return engine + # if executor_class.uses_ray: + # initialize_ray_cluster(engine_config.parallel_config) + + # Create the async LLM engine. 
+ engine = cls( + vllm_config=engine_config, + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + start_engine_loop=start_engine_loop, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine async_llm_engine.AsyncLLMEngine.from_engine_args = from_engine_args diff --git a/chatlearn/models/vllm/hooks/input_preprocess.py b/chatlearn/models/vllm/hooks/input_preprocess.py index 0f9f4ae0..2feba476 100644 --- a/chatlearn/models/vllm/hooks/input_preprocess.py +++ b/chatlearn/models/vllm/hooks/input_preprocess.py @@ -19,13 +19,61 @@ # pylint: disable=unused-import,unused-argument from vllm.inputs import preprocess +from vllm.inputs.data import token_inputs - -source = inspect.getsource(preprocess.InputPreprocessor._extract_prompt_components) +source = inspect.getsource(preprocess.InputPreprocessor._prompt_to_llm_inputs) if 'parsed = parse_singleton_prompt(prompt)' in source: from vllm.inputs.parse import parse_singleton_prompt - def extract_prompt_components( + + def _prompt_to_llm_inputs( + self, + prompt, + request_id: str, + lora_request=None, + ): + """ + Extract the singleton inputs from a prompt. + + Arguments: + + * request_id + * prompt: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + + Returns: + + * :class:`SingletonInputs` instance + """ + parsed = parse_singleton_prompt(prompt) + + assert parsed["type"] == "tokens", \ + f"you must pass prompt_token_ids when add request to scheduler. while prompt {prompt}" + + if parsed["type"] == "tokens": + tokens_content = parsed["content"] + + prompt_token_ids = tokens_content["prompt_token_ids"] + token_type_ids = tokens_content.get("token_type_ids") + multi_modal_data = tokens_content.get("multi_modal_data") + mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return self._process_multimodal( + prompt_token_ids, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + + return token_inputs( + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + + def _prompt_to_llm_inputs_( self, prompt, request_id, @@ -59,4 +107,4 @@ def extract_prompt_components( return (prompt_text, prompt_token_ids, multi_modal_data, mm_processor_kwargs) - preprocess.InputPreprocessor._extract_prompt_components = extract_prompt_components + preprocess.InputPreprocessor._prompt_to_llm_inputs = _prompt_to_llm_inputs diff --git a/chatlearn/models/vllm/hooks/llm.py b/chatlearn/models/vllm/hooks/llm.py index 8824e940..a613d608 100644 --- a/chatlearn/models/vllm/hooks/llm.py +++ b/chatlearn/models/vllm/hooks/llm.py @@ -14,74 +14,98 @@ # ============================================================================== """Hooks of vllm-0.6.3 llm init with AsyncLLMEngine and AsyncEngineArgs.""" -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union # pylint: disable=unused-import,wildcard-import,unused-argument -from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.arg_utils import AsyncEngineArgs, HfOverrides, TaskOption, PoolerConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints import llm from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter def init( - self, - model: str, - tokenizer: Optional[str] = None, - tokenizer_mode: str 
= "auto", - skip_tokenizer_init: bool = False, - trust_remote_code: bool = False, - tensor_parallel_size: int = 1, - dtype: str = "auto", - quantization: Optional[str] = None, - revision: Optional[str] = None, - tokenizer_revision: Optional[str] = None, - seed: int = 0, - gpu_memory_utilization: float = 0.9, - swap_space: float = 4, - cpu_offload_gb: float = 0, - enforce_eager: Optional[bool] = None, - max_context_len_to_capture: Optional[int] = None, - max_seq_len_to_capture: int = 8192, - disable_custom_all_reduce: bool = False, - disable_async_output_proc: bool = False, - mm_processor_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, -) -> None: - ''' - LLM constructor. + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + allowed_local_media_path: str = "", + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + # After positional args are removed, move this right below `model` + task: TaskOption = "auto", + override_pooler_config: Optional[PoolerConfig] = None, + compilation_config: Optional[Union[int, Dict[str, Any]]] = None, + **kwargs, + ) -> None: + ''' + LLM constructor. - Note: if enforce_eager is unset (enforce_eager is None) - it defaults to False. - ''' - if "disable_log_stats" not in kwargs: - kwargs["disable_log_stats"] = True + Note: if enforce_eager is unset (enforce_eager is None) + it defaults to False. 
+ ''' - engine_args = AsyncEngineArgs( - model=model, - tokenizer=tokenizer, - tokenizer_mode=tokenizer_mode, - skip_tokenizer_init=skip_tokenizer_init, - trust_remote_code=trust_remote_code, - tensor_parallel_size=tensor_parallel_size, - dtype=dtype, - quantization=quantization, - revision=revision, - tokenizer_revision=tokenizer_revision, - seed=seed, - gpu_memory_utilization=gpu_memory_utilization, - swap_space=swap_space, - cpu_offload_gb=cpu_offload_gb, - enforce_eager=enforce_eager, - max_context_len_to_capture=max_context_len_to_capture, - max_seq_len_to_capture=max_seq_len_to_capture, - disable_custom_all_reduce=disable_custom_all_reduce, - disable_async_output_proc=disable_async_output_proc, - mm_processor_kwargs=mm_processor_kwargs, - **kwargs, - ) + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + if compilation_config is not None: + if isinstance(compilation_config, (int, dict)): + compilation_config_instance = CompilationConfig.from_cli( + str(compilation_config)) + else: + compilation_config_instance = compilation_config + else: + compilation_config_instance = None + + engine_args = AsyncEngineArgs( + model=model, + task=task, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + allowed_local_media_path=allowed_local_media_path, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, + hf_overrides=hf_overrides, + mm_processor_kwargs=mm_processor_kwargs, + override_pooler_config=override_pooler_config, + compilation_config=compilation_config_instance, + **kwargs, + ) + # Logic to switch between engines is done at runtime instead of import + # to avoid import order issues + self.engine_class = self.get_engine_class() + + # TODO(rob): enable mp by default (issue with fork vs spawn) + self.llm_engine = self.engine_class.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS) - self.llm_engine = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.LLM_CLASS).engine - self.request_counter = Counter() + self.request_counter = Counter() llm.LLM.__init__ = init diff --git a/chatlearn/models/vllm/hooks/llm_engine.py b/chatlearn/models/vllm/hooks/llm_engine.py index bff1d188..ff2c3b4d 100644 --- a/chatlearn/models/vllm/hooks/llm_engine.py +++ b/chatlearn/models/vllm/hooks/llm_engine.py @@ -15,9 +15,12 @@ """Hooks of vllm-0.5.1 llm_engine remove __reduce__ function.""" import inspect +from typing import Dict, Optional # pylint: disable=unused-import,wildcard-import,unused-argument from vllm.engine import llm_engine +from vllm.engine.metrics_types import StatLoggerBase +from vllm.usage.usage_lib import UsageContext source = inspect.getsource(llm_engine.LLMEngine.__reduce__) @@ -28,3 +31,29 @@ def __reduce__(self): pass del llm_engine.LLMEngine.__reduce__ + + +@classmethod +def from_engine_args( + cls, + engine_args, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, +) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. 
+ engine_config = engine_args.create_engine_config(usage_context) + # executor_class = cls._get_executor_cls(engine_config) + from vllm.executor.ray_gpu_executor import RayGPUExecutor + executor_class = RayGPUExecutor + # Create the LLM engine. + engine = cls( + vllm_config=engine_config, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + + return engine +llm_engine.LLMEngine.from_engine_args = from_engine_args \ No newline at end of file diff --git a/chatlearn/models/vllm/hooks/loader.py b/chatlearn/models/vllm/hooks/loader.py index 147e1b8a..f7274e9b 100644 --- a/chatlearn/models/vllm/hooks/loader.py +++ b/chatlearn/models/vllm/hooks/loader.py @@ -24,6 +24,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.models import llama from vllm.model_executor.models import qwen2, qwen2_moe +from vllm.config import VllmConfig from chatlearn.utils.vllm_import_helper import LlamaForCausalLM from chatlearn.utils.vllm_import_helper import QWenLMHeadModel @@ -75,41 +76,40 @@ def init(self, load_config): # add ckpt loading of megatron format -def load_model(self, *, model_config, - device_config, - lora_config, - parallel_config, - scheduler_config, - cache_config): - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model(model_config, self.load_config, - lora_config, cache_config, - scheduler_config) - if self.load_config.model_loader_extra_config.get("need_load_ckpt", True) and \ - self.load_config.model_loader_extra_config["load"] is not None: - qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict - qwen2.Qwen2ForCausalLM.load_weights = load_weights - qwen2_moe.Qwen2MoeForCausalLM.load_state_dict = load_state_dict - qwen2_moe.Qwen2MoeForCausalLM.load_weights = load_weights - llama.LlamaForCausalLM.load_state_dict = load_state_dict - llama.LlamaForCausalLM.load_weights = load_weights - model.load_weights(self.load_config.model_loader_extra_config) - else: - # For accurate performance evaluation, we assign +def load_model(self, vllm_config: VllmConfig):# -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(vllm_config=vllm_config) + # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. - initialize_dummy_weights(model) - - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - # When quant methods need to process weights after loading - # (for repacking, quantizing, etc), they expect parameters - # to be on the global target device. This scope is for the - # case where cpu offloading is used, where we will move the - # parameters onto device for processing and back off after. 
- with device_loading_context( - module, torch.device(device_config.device)): - quant_method.process_weights_after_loading(module) - return model.eval() + if self.load_config.model_loader_extra_config.get("need_load_ckpt", True) and \ + self.load_config.model_loader_extra_config["load"] is not None: + qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict + qwen2.Qwen2ForCausalLM.load_weights = load_weights + qwen2_moe.Qwen2MoeForCausalLM.load_state_dict = load_state_dict + qwen2_moe.Qwen2MoeForCausalLM.load_weights = load_weights + llama.LlamaForCausalLM.load_state_dict = load_state_dict + llama.LlamaForCausalLM.load_weights = load_weights + model.load_weights(self.load_config.model_loader_extra_config) + else: + # For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context( + module, torch.device(device_config.device)): + quant_method.process_weights_after_loading(module) + return model.eval() + + loader.DummyModelLoader.load_model = load_model diff --git a/chatlearn/models/vllm/hooks/ray_gpu_executor.py b/chatlearn/models/vllm/hooks/ray_gpu_executor.py index f7269930..78f02fd2 100644 --- a/chatlearn/models/vllm/hooks/ray_gpu_executor.py +++ b/chatlearn/models/vllm/hooks/ray_gpu_executor.py @@ -22,7 +22,7 @@ from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger from vllm.utils import (get_distributed_init_method, - get_ip, get_open_port, get_vllm_instance_id) + get_ip, get_open_port) from chatlearn.utils.global_vars import get_vllm_actors @@ -31,192 +31,223 @@ # modified based on https://github.com/vllm-project/vllm/blob/6aa6020f9bd4c1e414c10f7bd3a7c2555f1950b2/vllm/executor/ray_gpu_executor.py#L109 def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): # pylint: disable=unused-argument - - # The driver dummy worker does not actually use any resources. - # It holds the resource for the driver worker. - self.driver_dummy_worker: Optional[RayWorkerWrapper] = None - # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerWrapper] = [] - - # Used in ray compiled DAG: indexed first by PP rank, - # and then TP rank. In other words, the inner list is - # the TP group of workers for a PP rank. - self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] - - if self.parallel_config.ray_workers_use_nsight: - ray_remote_kwargs = self._configure_ray_workers_use_nsight( - ray_remote_kwargs) - - logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) - - # Create the workers. 
- driver_ip = get_ip() - # driver_actor_id = ray.get_runtime_context().get_actor_id() - vllm_workers = get_vllm_actors() - worker_wrapper_kwargs = self._get_worker_wrapper_args() - if self.use_ray_spmd_worker: - self.workers = vllm_workers - else: - for worker in vllm_workers: - # we cannot call remote func of actor if the actor is its self - worker_ip = ray.get(worker.get_node_ip.remote()) - # if worker._actor_id.hex() == driver_actor_id and self.driver_dummy_worker is None: - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - **worker_wrapper_kwargs) - else: - # Else, added to the list of workers. - self.workers.append(worker) - logger.debug("workers: %s", self.workers) - logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) - if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: - raise ValueError( - "Ray does not allocate any GPUs on the driver node. Consider " - "adjusting the Ray placement group or running the driver on a " - "GPU node.") - worker_ips = [ - ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] - for worker in self.workers - ] - ip_counts: Dict[str, int] = {} - for ip in worker_ips: - ip_counts[ip] = ip_counts.get(ip, 0) + 1 - - def sort_by_driver_then_worker_ip(worker): - """ - Sort the workers based on 3 properties: - 1. If the worker is on the same node as the driver (vllm engine), - it should be placed first. - 2. Then, if the worker is on a node with fewer workers, it should - be placed first. - 3. Finally, if the work is on a node with smaller IP address, it - should be placed first. - """ - ip = ray.get(worker.get_node_ip.remote()) - return (ip != driver_ip, ip_counts[ip], ip) - - # After sorting, the workers on the same node will be - # close to each other, and the workers on the driver - # node will be placed first. - self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) - - # Get the set of GPU IDs used on each node. - worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", - use_dummy_driver=True) - # worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids") - - node_workers = defaultdict(list) # node id -> list of worker ranks - node_gpus = defaultdict(list) # node id -> list of gpu ids - - for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): - node_workers[node_id].append(i) - # `gpu_ids` can be a list of strings or integers. - # convert them to integers for consistency. - # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), - # string sorting is not sufficient. - # see https://github.com/vllm-project/vllm/issues/5590 - gpu_ids = [int(x) for x in gpu_ids] - node_gpus[node_id].extend(gpu_ids) - for node_id, gpu_ids in node_gpus.items(): - node_gpus[node_id] = sorted(gpu_ids) - - all_ips = set(worker_ips + [driver_ip]) - n_ips = len(all_ips) - n_nodes = len(node_workers) - - if n_nodes != n_ips: - raise RuntimeError( - f"Every node should have a unique IP address. Got {n_nodes}" - f" nodes with node ids {list(node_workers.keys())} and " - f"{n_ips} unique IP addresses {all_ips}. Please check your" - " network configuration. If you set `VLLM_HOST_IP` or " - "`HOST_IP` environment variable, make sure it is unique for" - " each node.") - - VLLM_INSTANCE_ID = get_vllm_instance_id() - - # Set environment variables for the driver and workers. 
- all_args_to_update_environment_variables = [({ - "CUDA_VISIBLE_DEVICES": - ",".join(map(str, node_gpus[node_id])), - "VLLM_INSTANCE_ID": - VLLM_INSTANCE_ID, - "VLLM_TRACE_FUNCTION": - str(envs.VLLM_TRACE_FUNCTION), - **({ - "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND - } if envs.VLLM_ATTENTION_BACKEND is not None else {}) - },) for (node_id, _) in worker_node_and_gpu_ids] - - self._env_vars_for_all_workers = ( - all_args_to_update_environment_variables) - - self._run_workers("update_environment_variables", - all_args=self._get_env_vars_to_be_updated()) - - if len(node_gpus) == 1: - # in single node case, we don't need to get the IP address. - # the loopback address is sufficient - # NOTE: a node may have several IP addresses, one for each - # network interface. `get_ip()` might return any of them, - # while they might not work for communication inside the node - # if the network setup is complicated. Using the loopback address - # solves this issue, as it always works for communication inside - # the node. - driver_ip = "127.0.0.1" - distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) - - # Initialize the actual workers inside worker wrapper. - init_worker_all_kwargs = [ - self._get_worker_kwargs( - # local_rank=node_workers[node_id].index(rank), - local_rank=0, - rank=rank, - distributed_init_method=distributed_init_method, - ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) - ] - self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) - - self._run_workers("init_device") - self._run_workers("load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers) - - if self.use_ray_spmd_worker: - for pp_rank in range(self.parallel_config.pipeline_parallel_size): - self.pp_tp_workers.append([]) - for tp_rank in range( - self.parallel_config.tensor_parallel_size): - # PP=2, TP=4 - # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] - rank = (pp_rank * self.parallel_config.tensor_parallel_size - ) + tp_rank - assert len(self.pp_tp_workers[pp_rank]) == tp_rank - assert pp_rank < len(self.pp_tp_workers) - self.pp_tp_workers[pp_rank].append(self.workers[rank]) - - # This is the list of workers that are rank 0 of each TP group EXCEPT - # global rank 0. These are the workers that will broadcast to the - # rest of the workers. - self.tp_driver_workers: List[RayWorkerWrapper] = [] - # This is the list of workers that are not drivers and not the first - # worker in a TP group. These are the workers that will be - # broadcasted to. - self.non_driver_workers: List[RayWorkerWrapper] = [] - - # Enforce rank order for correct rank to return final output. - for index, worker in enumerate(self.workers): - # The driver worker is rank 0 and not in self.workers. - rank = index + 1 - if rank % self.parallel_config.tensor_parallel_size == 0: - self.tp_driver_workers.append(worker) + **ray_remote_kwargs): + if (self.parallel_config.tensor_parallel_size == 1 + and self.parallel_config.pipeline_parallel_size == 1): + # For single GPU case, we use a ray worker with constrained memory. + num_gpus = self.cache_config.gpu_memory_utilization else: - self.non_driver_workers.append(worker) - + # Otherwise, the ray workers are allocated with a full GPU. + num_gpus = 1 + + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. 
+ self.workers: List[RayWorkerWrapper] = [] + + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) + + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) + + # Create the workers. + driver_ip = get_ip() + workers = [] + workers = get_vllm_actors() + # if self.use_ray_spmd_worker: + # workers = vllm_workers + # else: + # for bundle_id, bundle in enumerate(placement_group.bundle_specs): + # if not bundle.get("GPU", 0): + # continue + # scheduling_strategy = PlacementGroupSchedulingStrategy( + # placement_group=placement_group, + # placement_group_capture_child_tasks=True, + # placement_group_bundle_index=bundle_id, + # ) + + # worker = ray.remote( + # num_cpus=0, + # num_gpus=num_gpus, + # scheduling_strategy=scheduling_strategy, + # **ray_remote_kwargs, + # )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) + # workers.append(worker) + + worker_ip_refs = [ + worker.get_node_ip.remote() # type: ignore[attr-defined] + for worker in workers + ] + worker_ips = ray.get(worker_ip_refs) + + if not self.use_ray_spmd_worker: + for i in range(len(workers)): + worker = workers[i] + worker_ip = worker_ips[i] + if self.driver_dummy_worker is None and worker_ip == driver_ip: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + vllm_config=self.vllm_config) + workers.pop(i) + worker_ips.pop(i) + self.workers = workers + break + else: + self.workers = workers + + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) + if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node.") + + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + worker_to_ip = dict(zip(self.workers, worker_ips)) + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = worker_to_ip[worker] + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. 
+ continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore + + node_workers = defaultdict(list) # node id -> list of worker ranks + node_gpus = defaultdict(list) # node id -> list of gpu ids + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + # `gpu_ids` can be a list of strings or integers. + # convert them to integers for consistency. + # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), + # string sorting is not sufficient. + # see https://github.com/vllm-project/vllm/issues/5590 + gpu_ids = [int(x) for x in gpu_ids] + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + all_ips = set(worker_ips + [driver_ip]) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP`" + " environment variable, make sure it is unique for" + " each node.") + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [({ + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])), + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + **({ + "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND + } if envs.VLLM_ATTENTION_BACKEND is not None else {}) + }, ) for (node_id, _) in worker_node_and_gpu_ids] + + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) + + self._run_workers("update_environment_variables", + all_args=self._get_env_vars_to_be_updated()) + + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=0,#node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. 
These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[RayWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[RayWorkerWrapper] = [] + + # Enforce rank order for correct rank to return final output. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. + rank = index + 1 + if rank % self.parallel_config.tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) RayGPUExecutor._init_workers_ray = _init_workers_ray diff --git a/chatlearn/models/vllm/hooks/worker_base.py b/chatlearn/models/vllm/hooks/worker_base.py index 658cfc22..2e889f0d 100644 --- a/chatlearn/models/vllm/hooks/worker_base.py +++ b/chatlearn/models/vllm/hooks/worker_base.py @@ -28,6 +28,7 @@ def execute_method(self, method, *args, **kwargs): target = self.worker else: target = self + #print(f"debug target: {target} method: {method}") executor = getattr(target, method) return executor(*args, **kwargs) except Exception as e: diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index 072d9c24..248f97a5 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -14,6 +14,7 @@ # ============================================================================== """VLLM module v2""" +import asyncio import gc import inspect import os @@ -22,9 +23,11 @@ from transformers import AutoTokenizer from vllm import SamplingParams from vllm.config import LoadFormat +from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import RayWorkerWrapper - +from vllm.usage.usage_lib import UsageContext + from chatlearn.utils.global_vars import set_vllm_actors from chatlearn.utils.vllm_import_helper import parallel_state from chatlearn.utils.vllm_import_helper import get_pipeline_model_parallel_rank @@ -49,8 +52,10 @@ def __init__(self, *args, **kwargs): assert common_methods == {'__init__'}, \ f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}" self.local_rank = 0 - if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: - RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called + if 'vllm_actor_type' in kwargs and 'worker' == kwargs['vllm_actor_type']: + vllm_config = self.init_engine_args() + RayWorkerWrapper.__init__(self, vllm_config=vllm_config) # pylint: disable=non-parent-init-called + os.environ['VLLM_HOST_IP'] = self.get_address() self.tokenizer = None @@ -75,10 +80,55 @@ def add_extra_args(self, parser): help='Timeout minutes for torch.distributed.') return parser + def init_engine_args(self): + dtype = self.model_args.get("dtype", "bfloat16") + if self.model_args.get("fp16", False): + dtype = "float16" + + load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY) + if load_format == LoadFormat.DUMMY: + self.model_args["need_load_ckpt"] = self.src_parameter_model is None + model_loader_extra_config = self.model_args + else: + model_loader_extra_config = None + + if self.model_args.get("apply_replica_id_to_seed", True): + seed = self.model_args.get("seed", 0) + self.replica_id + else: + seed = self.model_args.get("seed", 0) + + self.engine_args = AsyncEngineArgs( + model=self.model_args['tokenizer'], + 
tokenizer=self.model_args['tokenizer'], + max_seq_len_to_capture=self.model_args.get("seq_length"), + seed=seed, + # load model: 'dummy' for megatron ckpt or mock weight; others for hf ckpt. + load_format=load_format, + model_loader_extra_config=model_loader_extra_config, + # parallelism strategy + tensor_parallel_size=self.module_args.tensor_model_parallel_size, + pipeline_parallel_size=self.module_args.pipeline_model_parallel_size, + dtype=dtype, + # scheduling strategy + max_num_seqs=self.module_args.generation_batch_size, + max_num_batched_tokens = self.model_args.get("max_num_batched_tokens", None), + num_scheduler_steps=self.model_args.get("num_scheduler_steps", 1), + gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90), + # logger + disable_log_requests=self.model_args.get("disable_log_requests", True), + disable_log_stats=self.model_args.get("disable_log_stats", True), + trust_remote_code=True, + enforce_eager=self.model_args.get("enforce_eager", False), + disable_custom_all_reduce=True, + distributed_executor_backend="ray") + return self.engine_args.create_engine_config(usage_context=UsageContext.ENGINE_CONTEXT) + + def init(self): """ :meta private: """ + return parallel_state.set_custom_all_reduce(False) initialize_vllm(extra_args_provider=self.add_extra_args, ignore_unknown_args=True, @@ -98,6 +148,10 @@ def setup_vllm(self, workers): os.environ['VLLM_HOST_IP'] = self.get_address() set_vllm_actors(workers) + if self.apply_async: + self.engine = AsyncLLMEngine.from_engine_args(self.engine_args) + return + dtype = self.model_args.get("dtype", "bfloat16") if self.model_args.get("fp16", False): dtype = "float16" @@ -256,9 +310,7 @@ def _convert_v1_inputs(self, prompts, prompt_token_ids): return inputs - def generate_vllm(self, query, is_eval, is_first_run=True): - if is_first_run: # using for multi-round generate - self.reinit_cache_engine() + def preprocess_inputs(self, query, is_eval): prompt_key = self.model_args.get("vllm_prompt_key", "prompt") input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids") @@ -286,11 +338,101 @@ def generate_vllm(self, query, is_eval, is_first_run=True): parsed_prompts.append(item) sampling_params.append(sampling_param) + return parsed_prompts, sampling_params + + def run_vllm(self, parsed_prompts, sampling_params): + # breakpoint() outputs = self.llm.generate( parsed_prompts, sampling_params, - use_tqdm=True, + use_tqdm=True + ) + return outputs + + async def run_vllm_async(self, parsed_prompts, sampling_params): + from vllm.utils import merge_async_iterators + + # generators = [] + # start = time.perf_counter() + # for i, (prompt, sp) in enumerate(zip(parsed_prompts, sampling_params)): + # generator = self.llm.llm_engine.engine.generate(prompt, sp, request_id=f"test{i}") + # generators.append(generator) + # all_gens = merge_async_iterators(*generators) + # async for i, res in all_gens: + # pass + # async with build_async_engine_client_from_engine_args( + # self.engine_args, disable_frontend_multiprocessing=False) as llm: + # async self.engine as llm: + + # Add the requests to the engine. 
+ prompts = parsed_prompts + # sampling_params: List[SamplingParams] = [] + # for prompt, _, output_len in requests: + # prompts.append(prompt) + # sampling_params.append( + # SamplingParams( + # n=n, + # temperature=1.0, + # top_p=1.0, + # ignore_eos=True, + # max_tokens=output_len, + # )) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + generator = self.llm.generate(prompt, sp, request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + print(f"debug all_gens: {all_gens}", flush=True) + end = time.perf_counter() + return await all_gens + + async def run_query(self, prompt, sampling_param): + request_id = uuid4() + + outputs = self.engine.generate( + prompt, + params, + request_id ) + async for output in outputs: + final_output = output + responses = [] + for output in final_output.outputs: + responses.append(output) + return responses + + + async def process(self, parsed_prompts, sampling_params): + tasks = [asyncio.create_task( + self.run_query(prompt, sampling_param)) + for prompt, sampling_param in zip(parsed_prompts, sampling_params)] + results = [] + for task in asyncio.as_completed(tasks): + result = await task + results.append(result) + return results + + + + def generate_vllm(self, query, is_eval, is_first_run=True): + if is_first_run: # using for multi-round generate + self.reinit_cache_engine() + parsed_prompts, sampling_params = self.preprocess_inputs(query, is_eval) + + if self.apply_async: + outputs = asyncio.run(self.process(parsed_prompts, sampling_params)) + + # outputs = uvloop.run(self.run_vllm_async(parsed_prompts, sampling_params)) + # if outputs is None: + # print(f"debug outputs: {outputs}") + # else: + # print(f"debug outputs: {outputs[0]}") + else: + outputs = self.run_vllm(parsed_prompts, sampling_params) return outputs def is_last_rank(self): diff --git a/chatlearn/runtime/dist_actor.py b/chatlearn/runtime/dist_actor.py index b82c0283..5ab4e810 100644 --- a/chatlearn/runtime/dist_actor.py +++ b/chatlearn/runtime/dist_actor.py @@ -234,11 +234,14 @@ def __init__(self, *args, **kwargs): self.vllm_engine = None def create_actor(self, num_gpus, placement_group, group_index): + # kwargs = { + # "worker_module_name": "vllm.worker.worker", + # "worker_class_name": "Worker", + # "worker_class_fn": None, + # "trust_remote_code": True, + # } kwargs = { - "worker_module_name": "vllm.worker.worker", - "worker_class_name": "Worker", - "worker_class_fn": None, - "trust_remote_code": True, + "vllm_actor_type" : "worker" } self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs) diff --git a/chatlearn/utils/constant.py b/chatlearn/utils/constant.py index 4ee59907..c1dde2fe 100644 --- a/chatlearn/utils/constant.py +++ b/chatlearn/utils/constant.py @@ -38,7 +38,8 @@ class VLLMVersion(str, Enum): """support versions of vLLM.""" v_0_3_0 = "0.3.0" v_0_5_1 = "0.5.1" - v_0_6_3 = "0.6.3" + # v_0_6_3 = "0.6.3" + v_0_6_3 = "0.6.6" class QwenVersion(float, Enum): diff --git a/chatlearn/utils/vllm_import_helper.py b/chatlearn/utils/vllm_import_helper.py index 71864503..ec4ccfc5 100644 --- a/chatlearn/utils/vllm_import_helper.py +++ b/chatlearn/utils/vllm_import_helper.py @@ -47,7 +47,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.detokenizer import Detokenizer -elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: +elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_6_3]: # imports 
for vllm-063 from vllm.core.interfaces import BlockSpaceManager from vllm.distributed import parallel_state @@ -56,7 +56,8 @@ from vllm.distributed.parallel_state import initialize_model_parallel from vllm.distributed.utils import get_pp_indices from vllm.engine.async_llm_engine import _AsyncLLMEngine as LLMEngine - from vllm.engine.llm_engine import _load_generation_config_dict + # if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + # from vllm.engine.llm_engine import _load_generation_config_dict from vllm.engine.llm_engine import SchedulerContext, SchedulerOutputState from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker diff --git a/examples/megatron/models/vllm_policy_inference.py b/examples/megatron/models/vllm_policy_inference.py index 4c2402de..42b255d1 100644 --- a/examples/megatron/models/vllm_policy_inference.py +++ b/examples/megatron/models/vllm_policy_inference.py @@ -138,6 +138,7 @@ def decode_internal(self, batched_outputs): prompt_sizes = torch.tensor([len(q) for q in no_padded_query_ids], device=all_tokens.device) loss_mask = get_loss_mask(all_tokens, self.tokenizer.tokenizer.eos_token_id, prompt_sizes) loss_mask = loss_mask.to("cpu") + print(f"str_outputs: {len(str_outputs)} {str_outputs}") return {"all_tokens": all_tokens, "str_outputs": str_outputs, "str_prompts": str_prompts, "no_padded_query_ids": no_padded_query_ids, "logprobs": logprobs, "loss_mask": loss_mask} From e35ae1c820e9b845e9d6899d65e0bde0878437f8 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Wed, 22 Jan 2025 17:59:41 +0800 Subject: [PATCH 02/19] remove async. --- chatlearn/models/vllm_module_v2.py | 85 +----------------------------- 1 file changed, 1 insertion(+), 84 deletions(-) diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index 248f97a5..18d9e428 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -14,7 +14,6 @@ # ============================================================================== """VLLM module v2""" -import asyncio import gc import inspect import os @@ -148,10 +147,6 @@ def setup_vllm(self, workers): os.environ['VLLM_HOST_IP'] = self.get_address() set_vllm_actors(workers) - if self.apply_async: - self.engine = AsyncLLMEngine.from_engine_args(self.engine_args) - return - dtype = self.model_args.get("dtype", "bfloat16") if self.model_args.get("fp16", False): dtype = "float16" @@ -349,90 +344,12 @@ def run_vllm(self, parsed_prompts, sampling_params): ) return outputs - async def run_vllm_async(self, parsed_prompts, sampling_params): - from vllm.utils import merge_async_iterators - - # generators = [] - # start = time.perf_counter() - # for i, (prompt, sp) in enumerate(zip(parsed_prompts, sampling_params)): - # generator = self.llm.llm_engine.engine.generate(prompt, sp, request_id=f"test{i}") - # generators.append(generator) - # all_gens = merge_async_iterators(*generators) - # async for i, res in all_gens: - # pass - # async with build_async_engine_client_from_engine_args( - # self.engine_args, disable_frontend_multiprocessing=False) as llm: - # async self.engine as llm: - - # Add the requests to the engine. 
- prompts = parsed_prompts - # sampling_params: List[SamplingParams] = [] - # for prompt, _, output_len in requests: - # prompts.append(prompt) - # sampling_params.append( - # SamplingParams( - # n=n, - # temperature=1.0, - # top_p=1.0, - # ignore_eos=True, - # max_tokens=output_len, - # )) - - generators = [] - start = time.perf_counter() - for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): - generator = self.llm.generate(prompt, sp, request_id=f"test{i}") - generators.append(generator) - all_gens = merge_async_iterators(*generators) - async for i, res in all_gens: - pass - print(f"debug all_gens: {all_gens}", flush=True) - end = time.perf_counter() - return await all_gens - - async def run_query(self, prompt, sampling_param): - request_id = uuid4() - - outputs = self.engine.generate( - prompt, - params, - request_id - ) - async for output in outputs: - final_output = output - responses = [] - for output in final_output.outputs: - responses.append(output) - return responses - - - async def process(self, parsed_prompts, sampling_params): - tasks = [asyncio.create_task( - self.run_query(prompt, sampling_param)) - for prompt, sampling_param in zip(parsed_prompts, sampling_params)] - results = [] - for task in asyncio.as_completed(tasks): - result = await task - results.append(result) - return results - - - def generate_vllm(self, query, is_eval, is_first_run=True): if is_first_run: # using for multi-round generate self.reinit_cache_engine() parsed_prompts, sampling_params = self.preprocess_inputs(query, is_eval) - if self.apply_async: - outputs = asyncio.run(self.process(parsed_prompts, sampling_params)) - - # outputs = uvloop.run(self.run_vllm_async(parsed_prompts, sampling_params)) - # if outputs is None: - # print(f"debug outputs: {outputs}") - # else: - # print(f"debug outputs: {outputs[0]}") - else: - outputs = self.run_vllm(parsed_prompts, sampling_params) + outputs = self.run_vllm(parsed_prompts, sampling_params) return outputs def is_last_rank(self): From 020cc4f5daa4c049926b136f3c99c7251b077ee4 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Wed, 22 Jan 2025 20:29:21 +0800 Subject: [PATCH 03/19] fix error monitor. --- chatlearn/models/vllm/hooks/worker_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/chatlearn/models/vllm/hooks/worker_base.py b/chatlearn/models/vllm/hooks/worker_base.py index 2e889f0d..54100726 100644 --- a/chatlearn/models/vllm/hooks/worker_base.py +++ b/chatlearn/models/vllm/hooks/worker_base.py @@ -19,6 +19,8 @@ from vllm.worker.worker_base import logger +del worker_base.WorkerWrapperBase.__getattr__ + def execute_method(self, method, *args, **kwargs): try: if self.worker is None: From d64eff4f8b438e133f67898605005076efbe947d Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Thu, 23 Jan 2025 14:17:08 +0800 Subject: [PATCH 04/19] fix e2e. 
--- chatlearn/models/vllm/hooks/loader.py | 2 +- chatlearn/models/vllm_module_v2.py | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/chatlearn/models/vllm/hooks/loader.py b/chatlearn/models/vllm/hooks/loader.py index f7274e9b..17af8e37 100644 --- a/chatlearn/models/vllm/hooks/loader.py +++ b/chatlearn/models/vllm/hooks/loader.py @@ -92,7 +92,7 @@ def load_model(self, vllm_config: VllmConfig):# -> nn.Module: qwen2_moe.Qwen2MoeForCausalLM.load_weights = load_weights llama.LlamaForCausalLM.load_state_dict = load_state_dict llama.LlamaForCausalLM.load_weights = load_weights - model.load_weights(self.load_config.model_loader_extra_config) + #model.load_weights(self.load_config.model_loader_extra_config) else: # For accurate performance evaluation, we assign # random values to the weights. diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index 18d9e428..a720b28a 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -127,7 +127,6 @@ def init(self): """ :meta private: """ - return parallel_state.set_custom_all_reduce(False) initialize_vllm(extra_args_provider=self.add_extra_args, ignore_unknown_args=True, @@ -402,6 +401,15 @@ def pipeline_parallel_rank(self): """ return get_pipeline_model_parallel_rank() + def tensor_and_expert_model_parallel_size(self): + """ + get tensor_and_expert_model_parallel_size + :meta private: + """ + # vLLM not supported to enable expert parallel size + # thus: tensor_and_expert_model_parallel_size = tensor_parallel_size + return parallel_state.get_tensor_model_parallel_world_size() + def model_setup_for_workers(self): self.llm.llm_engine.model_executor._run_workers("model_setup") From 8dfd40667eae4657df49f05337fcb29da0a6432c Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Thu, 23 Jan 2025 14:58:36 +0800 Subject: [PATCH 05/19] flow prompt. --- .../models/vllm/hooks/input_preprocess.py | 35 +------------------ 1 file changed, 1 insertion(+), 34 deletions(-) diff --git a/chatlearn/models/vllm/hooks/input_preprocess.py b/chatlearn/models/vllm/hooks/input_preprocess.py index 2feba476..95f5b7f2 100644 --- a/chatlearn/models/vllm/hooks/input_preprocess.py +++ b/chatlearn/models/vllm/hooks/input_preprocess.py @@ -67,44 +67,11 @@ def _prompt_to_llm_inputs( ) return token_inputs( + prompt=tokens_content["prompt"], prompt_token_ids=prompt_token_ids, token_type_ids=token_type_ids, multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) - def _prompt_to_llm_inputs_( - self, - prompt, - request_id, - lora_request=None): - ''' - Extract the components of any single encoder or decoder input prompt. - - Arguments: - - * request_id - * prompt: single encoder or decoder input prompt - * lora_request: this is only valid for decoder prompts - - Returns: - - * prompt - * prompt_token_ids - * multi_modal_data - * mm_processor_kwargs (request-level input processor/mapper overrides) - ''' - parsed = parse_singleton_prompt(prompt) - - assert parsed["type"] == "tokens", \ - f"you must pass prompt_token_ids when add request to scheduler. 
while prompt {prompt}" - - prompt_text = parsed["content"]["prompt"] - prompt_token_ids = parsed["content"]["prompt_token_ids"] - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - - return (prompt_text, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) - preprocess.InputPreprocessor._prompt_to_llm_inputs = _prompt_to_llm_inputs From 4149f43f79dd3fcd4e382890481c796a8a66f0f5 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Thu, 23 Jan 2025 17:19:43 +0800 Subject: [PATCH 06/19] be compatible with vllm-0.6.3 --- chatlearn/models/vllm/hooks/__init__.py | 39 ++- .../models/vllm/hooks/input_preprocess.py | 77 ------ .../models/vllm/hooks/vllm_0_3_0/__init__.py | 21 ++ .../vllm/hooks/{ => vllm_0_3_0}/sampler.py | 0 .../models/vllm/hooks/vllm_0_5_1/__init__.py | 23 ++ .../vllm/hooks/vllm_0_5_1/llm_engine.py | 30 +++ .../{ => vllm_0_5_1}/logits_processor.py | 0 .../vllm/hooks/{ => vllm_0_5_1}/worker.py | 0 .../models/vllm/hooks/vllm_0_6_3/__init__.py | 29 +++ .../vllm/hooks/vllm_0_6_3/async_llm_engine.py | 54 +++++ .../{ => vllm_0_6_3}/format_device_name.py | 0 .../vllm/hooks/vllm_0_6_3/input_preprocess.py | 58 +++++ chatlearn/models/vllm/hooks/vllm_0_6_3/llm.py | 87 +++++++ .../vllm/hooks/vllm_0_6_3/llm_engine.py | 30 +++ .../models/vllm/hooks/vllm_0_6_3/loader.py | 114 +++++++++ .../vllm/hooks/vllm_0_6_3/logits_processor.py | 42 ++++ .../vllm/hooks/vllm_0_6_3/ray_gpu_executor.py | 222 ++++++++++++++++++ .../vllm/hooks/vllm_0_6_3/worker_base.py | 43 ++++ .../models/vllm/hooks/vllm_0_6_6/__init__.py | 27 +++ .../{ => vllm_0_6_6}/async_llm_engine.py | 5 +- .../vllm/hooks/vllm_0_6_6/input_preprocess.py | 74 ++++++ .../models/vllm/hooks/{ => vllm_0_6_6}/llm.py | 0 .../vllm/hooks/{ => vllm_0_6_6}/llm_engine.py | 0 .../vllm/hooks/{ => vllm_0_6_6}/loader.py | 0 .../{ => vllm_0_6_6}/ray_gpu_executor.py | 0 .../hooks/{ => vllm_0_6_6}/worker_base.py | 0 chatlearn/models/vllm_module_v2.py | 20 +- chatlearn/runtime/dist_actor.py | 22 +- chatlearn/utils/constant.py | 4 +- chatlearn/utils/vllm_import_helper.py | 16 +- chatlearn/utils/vllm_utils.py | 6 +- 31 files changed, 908 insertions(+), 135 deletions(-) delete mode 100644 chatlearn/models/vllm/hooks/input_preprocess.py create mode 100644 chatlearn/models/vllm/hooks/vllm_0_3_0/__init__.py rename chatlearn/models/vllm/hooks/{ => vllm_0_3_0}/sampler.py (100%) create mode 100644 chatlearn/models/vllm/hooks/vllm_0_5_1/__init__.py create mode 100644 chatlearn/models/vllm/hooks/vllm_0_5_1/llm_engine.py rename chatlearn/models/vllm/hooks/{ => vllm_0_5_1}/logits_processor.py (100%) rename chatlearn/models/vllm/hooks/{ => vllm_0_5_1}/worker.py (100%) create mode 100644 chatlearn/models/vllm/hooks/vllm_0_6_3/__init__.py create mode 100644 chatlearn/models/vllm/hooks/vllm_0_6_3/async_llm_engine.py rename chatlearn/models/vllm/hooks/{ => vllm_0_6_3}/format_device_name.py (100%) create mode 100644 chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py create mode 100644 chatlearn/models/vllm/hooks/vllm_0_6_3/llm.py create mode 100644 chatlearn/models/vllm/hooks/vllm_0_6_3/llm_engine.py create mode 100644 chatlearn/models/vllm/hooks/vllm_0_6_3/loader.py create mode 100644 chatlearn/models/vllm/hooks/vllm_0_6_3/logits_processor.py create mode 100644 chatlearn/models/vllm/hooks/vllm_0_6_3/ray_gpu_executor.py create mode 100644 chatlearn/models/vllm/hooks/vllm_0_6_3/worker_base.py create mode 100644 chatlearn/models/vllm/hooks/vllm_0_6_6/__init__.py rename 
chatlearn/models/vllm/hooks/{ => vllm_0_6_6}/async_llm_engine.py (92%) create mode 100644 chatlearn/models/vllm/hooks/vllm_0_6_6/input_preprocess.py rename chatlearn/models/vllm/hooks/{ => vllm_0_6_6}/llm.py (100%) rename chatlearn/models/vllm/hooks/{ => vllm_0_6_6}/llm_engine.py (100%) rename chatlearn/models/vllm/hooks/{ => vllm_0_6_6}/loader.py (100%) rename chatlearn/models/vllm/hooks/{ => vllm_0_6_6}/ray_gpu_executor.py (100%) rename chatlearn/models/vllm/hooks/{ => vllm_0_6_6}/worker_base.py (100%) diff --git a/chatlearn/models/vllm/hooks/__init__.py b/chatlearn/models/vllm/hooks/__init__.py index 40e156eb..026852f3 100644 --- a/chatlearn/models/vllm/hooks/__init__.py +++ b/chatlearn/models/vllm/hooks/__init__.py @@ -19,27 +19,18 @@ from .. import is_vllm_v2 -if is_vllm_v2(): - if importlib.util.find_spec("vllm"): - from . import ray_gpu_executor - from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion - if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: - from chatlearn.models.vllm.hooks import input_preprocess - from chatlearn.models.vllm.hooks import async_llm_engine - from chatlearn.models.vllm.hooks import llm - from chatlearn.models.vllm.hooks import llm_engine - from chatlearn.models.vllm.hooks import loader - from chatlearn.models.vllm.hooks import worker_base -else: - if importlib.util.find_spec("vllm"): - import vllm - from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion # pylint: disable=ungrouped-imports - if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0: - from chatlearn.models.vllm.hooks import sampler - elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: - from chatlearn.models.vllm.hooks import llm_engine, logits_processor - if CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1: - from chatlearn.models.vllm.hooks import worker - else: - from chatlearn.models.vllm.hooks import input_preprocess - from chatlearn.models.vllm.hooks import format_device_name +if importlib.util.find_spec("vllm"): + + from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion + + if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0: + from chatlearn.models.vllm.hooks.vllm_0_3_0 import * + elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1: + from chatlearn.models.vllm.hooks.vllm_0_5_1 import * + elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + from chatlearn.models.vllm.hooks.vllm_0_6_3 import * + elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_6: + from .vllm_0_6_6 import * + else: + raise RuntimeError( + f"vLLM version expected in {list(member.value for member in VLLMVersion)}, while {CURRENT_VLLM_VERSION}.") diff --git a/chatlearn/models/vllm/hooks/input_preprocess.py b/chatlearn/models/vllm/hooks/input_preprocess.py deleted file mode 100644 index 95f5b7f2..00000000 --- a/chatlearn/models/vllm/hooks/input_preprocess.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Hooks of vllm-0.6.3 input preprocess to pass prompt text.""" - - -import inspect - -# pylint: disable=unused-import,unused-argument -from vllm.inputs import preprocess -from vllm.inputs.data import token_inputs - -source = inspect.getsource(preprocess.InputPreprocessor._prompt_to_llm_inputs) -if 'parsed = parse_singleton_prompt(prompt)' in source: - from vllm.inputs.parse import parse_singleton_prompt - - - def _prompt_to_llm_inputs( - self, - prompt, - request_id: str, - lora_request=None, - ): - """ - Extract the singleton inputs from a prompt. - - Arguments: - - * request_id - * prompt: single encoder or decoder input prompt - * lora_request: this is only valid for decoder prompts - - Returns: - - * :class:`SingletonInputs` instance - """ - parsed = parse_singleton_prompt(prompt) - - assert parsed["type"] == "tokens", \ - f"you must pass prompt_token_ids when add request to scheduler. while prompt {prompt}" - - if parsed["type"] == "tokens": - tokens_content = parsed["content"] - - prompt_token_ids = tokens_content["prompt_token_ids"] - token_type_ids = tokens_content.get("token_type_ids") - multi_modal_data = tokens_content.get("multi_modal_data") - mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") - - if multi_modal_data is not None and self._can_process_multimodal(): - return self._process_multimodal( - prompt_token_ids, - multi_modal_data, - mm_processor_kwargs, - lora_request=lora_request, - ) - - return token_inputs( - prompt=tokens_content["prompt"], - prompt_token_ids=prompt_token_ids, - token_type_ids=token_type_ids, - multi_modal_data=multi_modal_data, - mm_processor_kwargs=mm_processor_kwargs, - ) - - preprocess.InputPreprocessor._prompt_to_llm_inputs = _prompt_to_llm_inputs diff --git a/chatlearn/models/vllm/hooks/vllm_0_3_0/__init__.py b/chatlearn/models/vllm/hooks/vllm_0_3_0/__init__.py new file mode 100644 index 00000000..5cc844e2 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_3_0/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additional hooks of vllm-0.3.0.""" + +from ... import is_vllm_v2 + +assert not is_vllm_v2(), "vLLM-0.3.0 only supports vLLM Module v1. Set env `ENABLE_VLLM_V2=False`." + +from . import sampler diff --git a/chatlearn/models/vllm/hooks/sampler.py b/chatlearn/models/vllm/hooks/vllm_0_3_0/sampler.py similarity index 100% rename from chatlearn/models/vllm/hooks/sampler.py rename to chatlearn/models/vllm/hooks/vllm_0_3_0/sampler.py diff --git a/chatlearn/models/vllm/hooks/vllm_0_5_1/__init__.py b/chatlearn/models/vllm/hooks/vllm_0_5_1/__init__.py new file mode 100644 index 00000000..e9691dee --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_5_1/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additional hooks of vllm-0.5.1.""" + +from ... import is_vllm_v2 + +assert not is_vllm_v2(), "vLLM-0.5.1 only supports vLLM Module v1. Set env `ENABLE_VLLM_V2=False`." + +from . import llm_engine +from . import logits_processor +from . import worker diff --git a/chatlearn/models/vllm/hooks/vllm_0_5_1/llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_5_1/llm_engine.py new file mode 100644 index 00000000..bff1d188 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_5_1/llm_engine.py @@ -0,0 +1,30 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.5.1 llm_engine remove __reduce__ function.""" + +import inspect + +# pylint: disable=unused-import,wildcard-import,unused-argument +from vllm.engine import llm_engine + + +source = inspect.getsource(llm_engine.LLMEngine.__reduce__) +if 'RuntimeError' in source: + def __reduce__(self): + # This is to ensure that the LLMEngine can be referenced in + # the closure used to initialize Ray worker actors + pass + + del llm_engine.LLMEngine.__reduce__ diff --git a/chatlearn/models/vllm/hooks/logits_processor.py b/chatlearn/models/vllm/hooks/vllm_0_5_1/logits_processor.py similarity index 100% rename from chatlearn/models/vllm/hooks/logits_processor.py rename to chatlearn/models/vllm/hooks/vllm_0_5_1/logits_processor.py diff --git a/chatlearn/models/vllm/hooks/worker.py b/chatlearn/models/vllm/hooks/vllm_0_5_1/worker.py similarity index 100% rename from chatlearn/models/vllm/hooks/worker.py rename to chatlearn/models/vllm/hooks/vllm_0_5_1/worker.py diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/__init__.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/__init__.py new file mode 100644 index 00000000..efd99405 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additional hooks of vllm-0.6.3.""" + +from ... import is_vllm_v2 +from . import format_device_name +from . import input_preprocess + +if is_vllm_v2(): + from . import async_llm_engine + from . import llm + from . import loader + from . import ray_gpu_executor + from . import worker_base +else: + from . import llm_engine + from . import logits_processor diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/async_llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/async_llm_engine.py new file mode 100644 index 00000000..77f2ed70 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/async_llm_engine.py @@ -0,0 +1,54 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""del init_ray_cluster in AsyncLLMEngine.""" + +from typing import Dict, Optional + +# pylint: disable=unused-import,wildcard-import,unused-argument,not-callable +from vllm.config import EngineConfig +from vllm.engine import async_llm_engine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.metrics_types import StatLoggerBase +from vllm.usage.usage_lib import UsageContext + +@classmethod +def from_engine_args( + cls, + engine_args: AsyncEngineArgs, + engine_config: Optional[EngineConfig] = None, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, +) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. + if engine_config is None: + engine_config = engine_args.create_engine_config() + + executor_class = cls._get_executor_cls(engine_config) + + # Create the async LLM engine. 
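+    # Note: stock vLLM initializes the Ray cluster (initialize_ray_cluster)
+    # for Ray-based executors before this point; that call is intentionally
+    # omitted here because ChatLearn creates and places the Ray worker actors
+    # itself (see the ray_gpu_executor hook).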
+ engine = cls( + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + start_engine_loop=start_engine_loop, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine + +async_llm_engine.AsyncLLMEngine.from_engine_args = from_engine_args diff --git a/chatlearn/models/vllm/hooks/format_device_name.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/format_device_name.py similarity index 100% rename from chatlearn/models/vllm/hooks/format_device_name.py rename to chatlearn/models/vllm/hooks/vllm_0_6_3/format_device_name.py diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py new file mode 100644 index 00000000..8015972f --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py @@ -0,0 +1,58 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.3 input preprocess to pass prompt text.""" + + +import inspect + +# pylint: disable=unused-import,unused-argument +from vllm.inputs import preprocess +from vllm.inputs.parse import parse_singleton_prompt + +def extract_prompt_components( + self, + prompt, + request_id, + lora_request=None): + ''' + Extract the components of any single encoder or decoder input prompt. + + Arguments: + + * request_id + * prompt: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + + Returns: + + * prompt + * prompt_token_ids + * multi_modal_data + * mm_processor_kwargs (request-level input processor/mapper overrides) + ''' + parsed = parse_singleton_prompt(prompt) + + assert parsed["type"] == "tokens", \ + f"you must pass prompt_token_ids when add request to scheduler. while prompt {prompt}" + + prompt_text = parsed["content"]["prompt"] + prompt_token_ids = parsed["content"]["prompt_token_ids"] + multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") + + return (prompt_text, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) + +preprocess.InputPreprocessor._extract_prompt_components = extract_prompt_components diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/llm.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/llm.py new file mode 100644 index 00000000..8824e940 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/llm.py @@ -0,0 +1,87 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.3 llm init with AsyncLLMEngine and AsyncEngineArgs.""" + +from typing import Any, Dict, Optional + +# pylint: disable=unused-import,wildcard-import,unused-argument +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints import llm +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Counter + +def init( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, +) -> None: + ''' + LLM constructor. + + Note: if enforce_eager is unset (enforce_eager is None) + it defaults to False. + ''' + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, + mm_processor_kwargs=mm_processor_kwargs, + **kwargs, + ) + + self.llm_engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS).engine + self.request_counter = Counter() + +llm.LLM.__init__ = init diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/llm_engine.py new file mode 100644 index 00000000..bff1d188 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/llm_engine.py @@ -0,0 +1,30 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Hooks of vllm-0.6.3 llm_engine to remove the __reduce__ function."""
+
+import inspect
+
+# pylint: disable=unused-import,wildcard-import,unused-argument
+from vllm.engine import llm_engine
+
+
+source = inspect.getsource(llm_engine.LLMEngine.__reduce__)
+if 'RuntimeError' in source:
+    def __reduce__(self):
+        # This is to ensure that the LLMEngine can be referenced in
+        # the closure used to initialize Ray worker actors
+        pass
+
+    del llm_engine.LLMEngine.__reduce__
diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/loader.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/loader.py
new file mode 100644
index 00000000..5a81f452
--- /dev/null
+++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/loader.py
@@ -0,0 +1,114 @@
+# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Hooks of vllm-0.6.3 loader to load ckpt of megatron format.""" + + +import torch + +# pylint: disable=unused-import,wildcard-import,unused-argument +from vllm.model_executor.model_loader import loader +from vllm.model_executor.model_loader.loader import device_loading_context, _initialize_model +from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.models import llama +from vllm.model_executor.models import qwen2, qwen2_moe + +from chatlearn.utils.vllm_import_helper import LlamaForCausalLM +from chatlearn.utils.vllm_import_helper import QWenLMHeadModel +from chatlearn.utils.vllm_import_helper import Qwen2ForCausalLM +from chatlearn.utils.vllm_import_helper import Qwen2MoeForCausalLM +from chatlearn.utils.vllm_import_helper import get_model_architecture +from chatlearn.utils.utils import get_use_legacy_models + +from chatlearn.utils.vllm_utils import ( + convert_llama_state_dict_from_megatron_to_vllm, + convert_llama_state_dict_from_mcore_to_vllm, + convert_qwen_state_dict_from_megatron_to_vllm, + load_checkpoint +) + +def load_weights(self, model_args): + torch.distributed.barrier() + self.model_args = model_args + load_checkpoint(self, None, None, model_args=model_args) + torch.distributed.barrier() + +def load_state_dict(self, state_dict, strict=True, assign=False): + qwen_version = None + if isinstance(self, LlamaForCausalLM): + use_legacy_models = get_use_legacy_models(self.model_args) + if use_legacy_models: + convert_state_dict_internal = convert_llama_state_dict_from_megatron_to_vllm + else: + convert_state_dict_internal = convert_llama_state_dict_from_mcore_to_vllm + elif isinstance(self, QWenLMHeadModel): + qwen_version = 1.0 + convert_state_dict_internal = convert_qwen_state_dict_from_megatron_to_vllm + elif isinstance(self, Qwen2ForCausalLM) or (Qwen2MoeForCausalLM is not None and isinstance(self, Qwen2MoeForCausalLM)): + qwen_version = 2.0 + convert_state_dict_internal = convert_qwen_state_dict_from_megatron_to_vllm + else: + raise RuntimeError(f"Unsupported model for vllm backend. \ + support [LlamaForCausalLM, QWenLMHeadModel, Qwen2ForCausalLM, Qwen2MoeForCausalLM] only, while {self}") + + state_dict = convert_state_dict_internal(self.model_args, self.config, qwen_version=qwen_version) + super(type(self), self).load_state_dict(state_dict, strict=strict) + + +def init(self, load_config): + # remove 'Model loader extra config' assert. 
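+    # ChatLearn passes Megatron checkpoint options (e.g. "load" and
+    # "need_load_ckpt", consumed in load_model below) through
+    # model_loader_extra_config, which the stock DummyModelLoader would
+    # reject, so only the config assignment is kept here.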
+ self.load_config = load_config + +loader.DummyModelLoader.__init__ = init + +# add ckpt loading of megatron format +def load_model(self, *, model_config, + device_config, + lora_config, + parallel_config, + scheduler_config, + cache_config): + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(model_config, self.load_config, + lora_config, cache_config, + scheduler_config) + if self.load_config.model_loader_extra_config.get("need_load_ckpt", True) and \ + self.load_config.model_loader_extra_config["load"] is not None: + qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict + qwen2.Qwen2ForCausalLM.load_weights = load_weights + qwen2_moe.Qwen2MoeForCausalLM.load_state_dict = load_state_dict + qwen2_moe.Qwen2MoeForCausalLM.load_weights = load_weights + llama.LlamaForCausalLM.load_state_dict = load_state_dict + llama.LlamaForCausalLM.load_weights = load_weights + model.load_weights(self.load_config.model_loader_extra_config) + else: + # For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context( + module, torch.device(device_config.device)): + quant_method.process_weights_after_loading(module) + return model.eval() +loader.DummyModelLoader.load_model = load_model diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/logits_processor.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/logits_processor.py new file mode 100644 index 00000000..d713bff6 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/logits_processor.py @@ -0,0 +1,42 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.5.1 logits_processor to allgather logits of all ranks.""" + +import inspect + +# pylint: disable=wildcard-import,ungrouped-imports +from vllm.model_executor.layers import logits_processor + + +source = inspect.getsource(logits_processor.LogitsProcessor._get_logits) +if 'tensor_model_parallel_gather' in source: + import torch + from typing import Optional + from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding + def _get_logits(self, hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. 
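+        # Hooked behavior: instead of gathering logits only onto the TP driver
+        # rank (stock tensor_model_parallel_gather), the all_gather below
+        # leaves the full vocab logits on every tensor-parallel rank.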
+ logits = lm_head.linear_method.apply(lm_head, + hidden_states, + bias=embedding_bias) + from vllm.distributed.communication_op import tensor_model_parallel_all_gather # pylint: disable=import-outside-toplevel + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + + logits_processor.LogitsProcessor._get_logits = _get_logits diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/ray_gpu_executor.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/ray_gpu_executor.py new file mode 100644 index 00000000..f7269930 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/ray_gpu_executor.py @@ -0,0 +1,222 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hook _init_workers_ray""" + +from collections import defaultdict +from typing import Dict, List, Optional + +from vllm import envs +from vllm.executor.ray_gpu_executor import RayGPUExecutor +from vllm.executor.ray_utils import RayWorkerWrapper, ray +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, + get_ip, get_open_port, get_vllm_instance_id) + +from chatlearn.utils.global_vars import get_vllm_actors + +logger = init_logger(__name__) + + +# modified based on https://github.com/vllm-project/vllm/blob/6aa6020f9bd4c1e414c10f7bd3a7c2555f1950b2/vllm/executor/ray_gpu_executor.py#L109 +def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): # pylint: disable=unused-argument + + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerWrapper] = [] + + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) + + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) + + # Create the workers. 
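+    # ChatLearn-specific: rather than spawning fresh Ray actors per placement
+    # group bundle as stock vLLM does, reuse the vLLM worker actors that
+    # ChatLearn already created (fetched via get_vllm_actors()).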
+ driver_ip = get_ip() + # driver_actor_id = ray.get_runtime_context().get_actor_id() + vllm_workers = get_vllm_actors() + worker_wrapper_kwargs = self._get_worker_wrapper_args() + if self.use_ray_spmd_worker: + self.workers = vllm_workers + else: + for worker in vllm_workers: + # we cannot call remote func of actor if the actor is its self + worker_ip = ray.get(worker.get_node_ip.remote()) + # if worker._actor_id.hex() == driver_actor_id and self.driver_dummy_worker is None: + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + **worker_wrapper_kwargs) + else: + # Else, added to the list of workers. + self.workers.append(worker) + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) + if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node.") + worker_ips = [ + ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] + for worker in self.workers + ] + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = ray.get(worker.get_node_ip.remote()) + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", + use_dummy_driver=True) + # worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids") + + node_workers = defaultdict(list) # node id -> list of worker ranks + node_gpus = defaultdict(list) # node id -> list of gpu ids + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + # `gpu_ids` can be a list of strings or integers. + # convert them to integers for consistency. + # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), + # string sorting is not sufficient. + # see https://github.com/vllm-project/vllm/issues/5590 + gpu_ids = [int(x) for x in gpu_ids] + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + all_ips = set(worker_ips + [driver_ip]) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP` or " + "`HOST_IP` environment variable, make sure it is unique for" + " each node.") + + VLLM_INSTANCE_ID = get_vllm_instance_id() + + # Set environment variables for the driver and workers. 
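+    # Build one env-var dict per worker: the node's full GPU list for
+    # CUDA_VISIBLE_DEVICES, the shared VLLM_INSTANCE_ID, and tracing /
+    # attention-backend settings taken from vllm.envs.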
+ all_args_to_update_environment_variables = [({ + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])), + "VLLM_INSTANCE_ID": + VLLM_INSTANCE_ID, + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + **({ + "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND + } if envs.VLLM_ATTENTION_BACKEND is not None else {}) + },) for (node_id, _) in worker_node_and_gpu_ids] + + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) + + self._run_workers("update_environment_variables", + all_args=self._get_env_vars_to_be_updated()) + + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + init_worker_all_kwargs = [ + self._get_worker_kwargs( + # local_rank=node_workers[node_id].index(rank), + local_rank=0, + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[RayWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[RayWorkerWrapper] = [] + + # Enforce rank order for correct rank to return final output. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. + rank = index + 1 + if rank % self.parallel_config.tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + + +RayGPUExecutor._init_workers_ray = _init_workers_ray diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/worker_base.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/worker_base.py new file mode 100644 index 00000000..658cfc22 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/worker_base.py @@ -0,0 +1,43 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.3 worker_base to update execute_method.""" + +# pylint: disable=unused-import,wildcard-import +from vllm.worker import worker_base +from vllm.worker.worker_base import logger + + +def execute_method(self, method, *args, **kwargs): + try: + if self.worker is None: + target = self + else: + if hasattr(self.worker, method): + target = self.worker + else: + target = self + executor = getattr(target, method) + return executor(*args, **kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method}. " + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e + +worker_base.WorkerWrapperBase.execute_method = execute_method diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/__init__.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/__init__.py new file mode 100644 index 00000000..8e15c60c --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additional hooks of vllm-0.6.6.""" + +from ... import is_vllm_v2 + +assert is_vllm_v2(), "vLLM-0.6.6 only supports vLLM Module v2." + +from . import async_llm_engine +from . import input_preprocess +from . import llm +from . import llm_engine +from . import loader +from . import ray_gpu_executor +from . import worker_base diff --git a/chatlearn/models/vllm/hooks/async_llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/async_llm_engine.py similarity index 92% rename from chatlearn/models/vllm/hooks/async_llm_engine.py rename to chatlearn/models/vllm/hooks/vllm_0_6_6/async_llm_engine.py index c02699e9..d8dd7386 100644 --- a/chatlearn/models/vllm/hooks/async_llm_engine.py +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/async_llm_engine.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Hooks of vllm-0.6.3 del init_ray_cluster in AsyncLLMEngine.""" +"""del init_ray_cluster in AsyncLLMEngine.""" from typing import Dict, Optional @@ -39,9 +39,6 @@ def from_engine_args( executor_class = cls._get_executor_cls(engine_config) - # if executor_class.uses_ray: - # initialize_ray_cluster(engine_config.parallel_config) - # Create the async LLM engine. engine = cls( vllm_config=engine_config, diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/input_preprocess.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/input_preprocess.py new file mode 100644 index 00000000..c03a1456 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/input_preprocess.py @@ -0,0 +1,74 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.6 input preprocess to pass prompt text.""" + + +import inspect + +# pylint: disable=unused-import,unused-argument +from vllm.inputs import preprocess +from vllm.inputs.data import token_inputs +from vllm.inputs.parse import parse_singleton_prompt + + +def _prompt_to_llm_inputs( + self, + prompt, + request_id: str, + lora_request=None, +): + """ + Extract the singleton inputs from a prompt. + + Arguments: + + * request_id + * prompt: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + + Returns: + + * :class:`SingletonInputs` instance + """ + parsed = parse_singleton_prompt(prompt) + + assert parsed["type"] == "tokens", \ + f"you must pass prompt_token_ids when add request to scheduler. 
while prompt {prompt}" + + if parsed["type"] == "tokens": + tokens_content = parsed["content"] + + prompt_token_ids = tokens_content["prompt_token_ids"] + token_type_ids = tokens_content.get("token_type_ids") + multi_modal_data = tokens_content.get("multi_modal_data") + mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return self._process_multimodal( + prompt_token_ids, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + + return token_inputs( + prompt=tokens_content["prompt"], + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + +preprocess.InputPreprocessor._prompt_to_llm_inputs = _prompt_to_llm_inputs diff --git a/chatlearn/models/vllm/hooks/llm.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/llm.py similarity index 100% rename from chatlearn/models/vllm/hooks/llm.py rename to chatlearn/models/vllm/hooks/vllm_0_6_6/llm.py diff --git a/chatlearn/models/vllm/hooks/llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/llm_engine.py similarity index 100% rename from chatlearn/models/vllm/hooks/llm_engine.py rename to chatlearn/models/vllm/hooks/vllm_0_6_6/llm_engine.py diff --git a/chatlearn/models/vllm/hooks/loader.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/loader.py similarity index 100% rename from chatlearn/models/vllm/hooks/loader.py rename to chatlearn/models/vllm/hooks/vllm_0_6_6/loader.py diff --git a/chatlearn/models/vllm/hooks/ray_gpu_executor.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py similarity index 100% rename from chatlearn/models/vllm/hooks/ray_gpu_executor.py rename to chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py diff --git a/chatlearn/models/vllm/hooks/worker_base.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/worker_base.py similarity index 100% rename from chatlearn/models/vllm/hooks/worker_base.py rename to chatlearn/models/vllm/hooks/vllm_0_6_6/worker_base.py diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index a720b28a..84596218 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -22,11 +22,10 @@ from transformers import AutoTokenizer from vllm import SamplingParams from vllm.config import LoadFormat -from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import RayWorkerWrapper -from vllm.usage.usage_lib import UsageContext +from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion from chatlearn.utils.global_vars import set_vllm_actors from chatlearn.utils.vllm_import_helper import parallel_state from chatlearn.utils.vllm_import_helper import get_pipeline_model_parallel_rank @@ -51,9 +50,13 @@ def __init__(self, *args, **kwargs): assert common_methods == {'__init__'}, \ f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}" self.local_rank = 0 - if 'vllm_actor_type' in kwargs and 'worker' == kwargs['vllm_actor_type']: - vllm_config = self.init_engine_args() - RayWorkerWrapper.__init__(self, vllm_config=vllm_config) # pylint: disable=non-parent-init-called + if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: + RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called + else: + if 'vllm_actor_type' in kwargs and 'worker' == 
kwargs['vllm_actor_type']: + vllm_config = self.init_engine_args() + RayWorkerWrapper.__init__(self, vllm_config=vllm_config) # pylint: disable=non-parent-init-called os.environ['VLLM_HOST_IP'] = self.get_address() @@ -95,7 +98,10 @@ def init_engine_args(self): seed = self.model_args.get("seed", 0) + self.replica_id else: seed = self.model_args.get("seed", 0) - + + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.usage.usage_lib import UsageContext + self.engine_args = AsyncEngineArgs( model=self.model_args['tokenizer'], tokenizer=self.model_args['tokenizer'], @@ -122,7 +128,6 @@ def init_engine_args(self): distributed_executor_backend="ray") return self.engine_args.create_engine_config(usage_context=UsageContext.ENGINE_CONTEXT) - def init(self): """ :meta private: @@ -335,7 +340,6 @@ def preprocess_inputs(self, query, is_eval): return parsed_prompts, sampling_params def run_vllm(self, parsed_prompts, sampling_params): - # breakpoint() outputs = self.llm.generate( parsed_prompts, sampling_params, diff --git a/chatlearn/runtime/dist_actor.py b/chatlearn/runtime/dist_actor.py index 5ab4e810..8e7e9ee8 100644 --- a/chatlearn/runtime/dist_actor.py +++ b/chatlearn/runtime/dist_actor.py @@ -234,15 +234,19 @@ def __init__(self, *args, **kwargs): self.vllm_engine = None def create_actor(self, num_gpus, placement_group, group_index): - # kwargs = { - # "worker_module_name": "vllm.worker.worker", - # "worker_class_name": "Worker", - # "worker_class_fn": None, - # "trust_remote_code": True, - # } - kwargs = { - "vllm_actor_type" : "worker" - } + from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion + + if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + kwargs = { + "worker_module_name": "vllm.worker.worker", + "worker_class_name": "Worker", + "worker_class_fn": None, + "trust_remote_code": True, + } + else: + kwargs = { + "vllm_actor_type" : "worker" + } self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs) def create_engine_actor(self, num_gpus, placement_group, group_index): diff --git a/chatlearn/utils/constant.py b/chatlearn/utils/constant.py index c1dde2fe..2bae4dc7 100644 --- a/chatlearn/utils/constant.py +++ b/chatlearn/utils/constant.py @@ -38,8 +38,8 @@ class VLLMVersion(str, Enum): """support versions of vLLM.""" v_0_3_0 = "0.3.0" v_0_5_1 = "0.5.1" - # v_0_6_3 = "0.6.3" - v_0_6_3 = "0.6.6" + v_0_6_3 = "0.6.3" + v_0_6_6 = "0.6.6" class QwenVersion(float, Enum): diff --git a/chatlearn/utils/vllm_import_helper.py b/chatlearn/utils/vllm_import_helper.py index ec4ccfc5..d618bf7a 100644 --- a/chatlearn/utils/vllm_import_helper.py +++ b/chatlearn/utils/vllm_import_helper.py @@ -47,8 +47,8 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.detokenizer import Detokenizer -elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_6_3]: - # imports for vllm-063 +elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: + # imports for vllm-063/-66 from vllm.core.interfaces import BlockSpaceManager from vllm.distributed import parallel_state from vllm.distributed.communication_op import tensor_model_parallel_all_gather @@ -56,8 +56,8 @@ from vllm.distributed.parallel_state import initialize_model_parallel from vllm.distributed.utils import get_pp_indices from vllm.engine.async_llm_engine import _AsyncLLMEngine as LLMEngine - # if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: - # from vllm.engine.llm_engine import _load_generation_config_dict + if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + from 
vllm.engine.llm_engine import _load_generation_config_dict from vllm.engine.llm_engine import SchedulerContext, SchedulerOutputState from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker @@ -102,7 +102,7 @@ def get_block_manager_cls(version): if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0: return BlockSpaceManager - elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: + elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: return BlockSpaceManager.get_block_space_manager_class(version) @@ -111,7 +111,7 @@ def get_model_architecture(config): from vllm.model_executor.model_loader import _get_model_architecture as get_model_architecture_v1 return get_model_architecture_v1(config) - elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: + elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: from vllm.model_executor.model_loader.utils import get_model_architecture as get_model_architecture_v2 return get_model_architecture_v2(config)[0] @@ -120,7 +120,7 @@ def get_pipeline_model_parallel_rank(): if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0: return parallel_state.get_pipeline_model_parallel_rank() - elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: + elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: return parallel_state.get_pp_group().rank_in_group @@ -128,5 +128,5 @@ def get_pipeline_model_parallel_world_size(): if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0: return parallel_state.get_pipeline_model_parallel_world_size() - elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: + elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: return parallel_state.get_pp_group().world_size diff --git a/chatlearn/utils/vllm_utils.py b/chatlearn/utils/vllm_utils.py index d6622d0c..1bedee15 100644 --- a/chatlearn/utils/vllm_utils.py +++ b/chatlearn/utils/vllm_utils.py @@ -608,7 +608,7 @@ def _init_distributed_environment(args): world_size=args.world_size, rank=args.rank, timeout=timedelta(minutes=args.distributed_timeout_minutes)) - if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: + if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: _WORLD = None if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) @@ -874,7 +874,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version # Transformer Layers print("Converting transformer layers") - if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: start_layer_idx, _ = get_pp_indices( hf_config.num_hidden_layers, pp_rank, @@ -1239,7 +1239,7 @@ def convert_qwen_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version= # Transformer Layers print("Converting transformer layers") - if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: start_layer_idx, _ = get_pp_indices( hf_config.num_hidden_layers, pp_rank, From efbef5686d748ac9f27fb9938bac868f4747deb9 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Thu, 23 Jan 2025 17:43:45 +0800 Subject: [PATCH 07/19] fix pylint. 
--- .../vllm/hooks/vllm_0_6_3/input_preprocess.py | 3 - chatlearn/models/vllm/hooks/vllm_0_6_6/llm.py | 102 ++-- .../vllm/hooks/vllm_0_6_6/llm_engine.py | 7 +- .../models/vllm/hooks/vllm_0_6_6/loader.py | 64 +-- .../vllm/hooks/vllm_0_6_6/ray_gpu_executor.py | 434 +++++++++--------- chatlearn/models/vllm_module_v2.py | 6 +- chatlearn/runtime/dist_actor.py | 3 +- 7 files changed, 306 insertions(+), 313 deletions(-) diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py index 8015972f..dbb80d69 100644 --- a/chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py @@ -14,9 +14,6 @@ # ============================================================================== """Hooks of vllm-0.6.3 input preprocess to pass prompt text.""" - -import inspect - # pylint: disable=unused-import,unused-argument from vllm.inputs import preprocess from vllm.inputs.parse import parse_singleton_prompt diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/llm.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/llm.py index a613d608..1028bf0c 100644 --- a/chatlearn/models/vllm/hooks/vllm_0_6_6/llm.py +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/llm.py @@ -23,8 +23,7 @@ from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter -def init( - self, +def init(self, model: str, tokenizer: Optional[str] = None, tokenizer_mode: str = "auto", @@ -50,62 +49,61 @@ def init( task: TaskOption = "auto", override_pooler_config: Optional[PoolerConfig] = None, compilation_config: Optional[Union[int, Dict[str, Any]]] = None, - **kwargs, - ) -> None: - ''' - LLM constructor. + **kwargs,) -> None: + ''' + LLM constructor. - Note: if enforce_eager is unset (enforce_eager is None) - it defaults to False. - ''' + Note: if enforce_eager is unset (enforce_eager is None) + it defaults to False. 
+ ''' - if "disable_log_stats" not in kwargs: - kwargs["disable_log_stats"] = True + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True - if compilation_config is not None: - if isinstance(compilation_config, (int, dict)): - compilation_config_instance = CompilationConfig.from_cli( - str(compilation_config)) - else: - compilation_config_instance = compilation_config + if compilation_config is not None: + if isinstance(compilation_config, (int, dict)): + compilation_config_instance = CompilationConfig.from_cli( + str(compilation_config)) else: - compilation_config_instance = None + compilation_config_instance = compilation_config + else: + compilation_config_instance = None - engine_args = AsyncEngineArgs( - model=model, - task=task, - tokenizer=tokenizer, - tokenizer_mode=tokenizer_mode, - skip_tokenizer_init=skip_tokenizer_init, - trust_remote_code=trust_remote_code, - allowed_local_media_path=allowed_local_media_path, - tensor_parallel_size=tensor_parallel_size, - dtype=dtype, - quantization=quantization, - revision=revision, - tokenizer_revision=tokenizer_revision, - seed=seed, - gpu_memory_utilization=gpu_memory_utilization, - swap_space=swap_space, - cpu_offload_gb=cpu_offload_gb, - enforce_eager=enforce_eager, - max_seq_len_to_capture=max_seq_len_to_capture, - disable_custom_all_reduce=disable_custom_all_reduce, - disable_async_output_proc=disable_async_output_proc, - hf_overrides=hf_overrides, - mm_processor_kwargs=mm_processor_kwargs, - override_pooler_config=override_pooler_config, - compilation_config=compilation_config_instance, - **kwargs, - ) - # Logic to switch between engines is done at runtime instead of import - # to avoid import order issues - self.engine_class = self.get_engine_class() + engine_args = AsyncEngineArgs( + model=model, + task=task, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + allowed_local_media_path=allowed_local_media_path, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, + hf_overrides=hf_overrides, + mm_processor_kwargs=mm_processor_kwargs, + override_pooler_config=override_pooler_config, + compilation_config=compilation_config_instance, + **kwargs, + ) + # Logic to switch between engines is done at runtime instead of import + # to avoid import order issues + self.engine_class = self.get_engine_class() - # TODO(rob): enable mp by default (issue with fork vs spawn) - self.llm_engine = self.engine_class.from_engine_args( - engine_args, usage_context=UsageContext.LLM_CLASS) + # TODO(rob): enable mp by default (issue with fork vs spawn) + self.llm_engine = self.engine_class.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS) - self.request_counter = Counter() + self.request_counter = Counter() llm.LLM.__init__ = init diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/llm_engine.py index ff2c3b4d..ea24cc86 100644 --- a/chatlearn/models/vllm/hooks/vllm_0_6_6/llm_engine.py +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/llm_engine.py @@ -20,6 +20,7 @@ # pylint: 
disable=unused-import,wildcard-import,unused-argument from vllm.engine import llm_engine from vllm.engine.metrics_types import StatLoggerBase +from vllm.executor.ray_gpu_executor import RayGPUExecutor from vllm.usage.usage_lib import UsageContext @@ -43,11 +44,9 @@ def from_engine_args( """Creates an LLM engine from the engine arguments.""" # Create the engine configs. engine_config = engine_args.create_engine_config(usage_context) - # executor_class = cls._get_executor_cls(engine_config) - from vllm.executor.ray_gpu_executor import RayGPUExecutor executor_class = RayGPUExecutor # Create the LLM engine. - engine = cls( + engine = cls( # pylint: disable=not-callable vllm_config=engine_config, executor_class=executor_class, log_stats=not engine_args.disable_log_stats, @@ -56,4 +55,4 @@ def from_engine_args( ) return engine -llm_engine.LLMEngine.from_engine_args = from_engine_args \ No newline at end of file +llm_engine.LLMEngine.from_engine_args = from_engine_args diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/loader.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/loader.py index 17af8e37..6d46c572 100644 --- a/chatlearn/models/vllm/hooks/vllm_0_6_6/loader.py +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/loader.py @@ -77,39 +77,39 @@ def init(self, load_config): # add ckpt loading of megatron format def load_model(self, vllm_config: VllmConfig):# -> nn.Module: - device_config = vllm_config.device_config - model_config = vllm_config.model_config - with set_default_torch_dtype(model_config.dtype): - with torch.device(device_config.device): - model = _initialize_model(vllm_config=vllm_config) - # NOTE(woosuk): For accurate performance evaluation, we assign + device_config = vllm_config.device_config + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(vllm_config=vllm_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + if self.load_config.model_loader_extra_config.get("need_load_ckpt", True) and \ + self.load_config.model_loader_extra_config["load"] is not None: + qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict + qwen2.Qwen2ForCausalLM.load_weights = load_weights + qwen2_moe.Qwen2MoeForCausalLM.load_state_dict = load_state_dict + qwen2_moe.Qwen2MoeForCausalLM.load_weights = load_weights + llama.LlamaForCausalLM.load_state_dict = load_state_dict + llama.LlamaForCausalLM.load_weights = load_weights + model.load_weights(self.load_config.model_loader_extra_config) + else: + # For accurate performance evaluation, we assign # random values to the weights. - if self.load_config.model_loader_extra_config.get("need_load_ckpt", True) and \ - self.load_config.model_loader_extra_config["load"] is not None: - qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict - qwen2.Qwen2ForCausalLM.load_weights = load_weights - qwen2_moe.Qwen2MoeForCausalLM.load_state_dict = load_state_dict - qwen2_moe.Qwen2MoeForCausalLM.load_weights = load_weights - llama.LlamaForCausalLM.load_state_dict = load_state_dict - llama.LlamaForCausalLM.load_weights = load_weights - #model.load_weights(self.load_config.model_loader_extra_config) - else: - # For accurate performance evaluation, we assign - # random values to the weights. 
- initialize_dummy_weights(model) - - for _, module in model.named_modules(): - quant_method = getattr(module, "quant_method", None) - if quant_method is not None: - # When quant methods need to process weights after loading - # (for repacking, quantizing, etc), they expect parameters - # to be on the global target device. This scope is for the - # case where cpu offloading is used, where we will move the - # parameters onto device for processing and back off after. - with device_loading_context( - module, torch.device(device_config.device)): - quant_method.process_weights_after_loading(module) - return model.eval() + initialize_dummy_weights(model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context( + module, torch.device(device_config.device)): + quant_method.process_weights_after_loading(module) + return model.eval() loader.DummyModelLoader.load_model = load_model diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py index 78f02fd2..c27e27c4 100644 --- a/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py @@ -31,223 +31,223 @@ # modified based on https://github.com/vllm-project/vllm/blob/6aa6020f9bd4c1e414c10f7bd3a7c2555f1950b2/vllm/executor/ray_gpu_executor.py#L109 def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): - if (self.parallel_config.tensor_parallel_size == 1 - and self.parallel_config.pipeline_parallel_size == 1): - # For single GPU case, we use a ray worker with constrained memory. - num_gpus = self.cache_config.gpu_memory_utilization + **ray_remote_kwargs): # pylint: disable=unused-argument + if (self.parallel_config.tensor_parallel_size == 1 + and self.parallel_config.pipeline_parallel_size == 1): + # For single GPU case, we use a ray worker with constrained memory. + num_gpus = self.cache_config.gpu_memory_utilization + else: + # Otherwise, the ray workers are allocated with a full GPU. + num_gpus = 1 + + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerWrapper] = [] + + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) + + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) + + # Create the workers. 
+ driver_ip = get_ip() + workers = [] + workers = get_vllm_actors() + # if self.use_ray_spmd_worker: + # workers = vllm_workers + # else: + # for bundle_id, bundle in enumerate(placement_group.bundle_specs): + # if not bundle.get("GPU", 0): + # continue + # scheduling_strategy = PlacementGroupSchedulingStrategy( + # placement_group=placement_group, + # placement_group_capture_child_tasks=True, + # placement_group_bundle_index=bundle_id, + # ) + + # worker = ray.remote( + # num_cpus=0, + # num_gpus=num_gpus, + # scheduling_strategy=scheduling_strategy, + # **ray_remote_kwargs, + # )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) + # workers.append(worker) + + worker_ip_refs = [ + worker.get_node_ip.remote() # type: ignore[attr-defined] + for worker in workers + ] + worker_ips = ray.get(worker_ip_refs) + + if not self.use_ray_spmd_worker: + for i in range(len(workers)): # pylint: disable=consider-using-enumerate + worker = workers[i] + worker_ip = worker_ips[i] + if self.driver_dummy_worker is None and worker_ip == driver_ip: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + vllm_config=self.vllm_config) + workers.pop(i) + worker_ips.pop(i) + self.workers = workers + break + else: + self.workers = workers + + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) + if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node.") + + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + worker_to_ip = dict(zip(self.workers, worker_ips)) + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = worker_to_ip[worker] + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. + continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore + + node_workers = defaultdict(list) # node id -> list of worker ranks + node_gpus = defaultdict(list) # node id -> list of gpu ids + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + # `gpu_ids` can be a list of strings or integers. + # convert them to integers for consistency. + # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), + # string sorting is not sufficient. 
+ # see https://github.com/vllm-project/vllm/issues/5590 + gpu_ids = [int(x) for x in gpu_ids] + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + all_ips = set(worker_ips + [driver_ip]) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP`" + " environment variable, make sure it is unique for" + " each node.") + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [({ + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])), + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + **({ + "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND + } if envs.VLLM_ATTENTION_BACKEND is not None else {}) + }, ) for (node_id, _) in worker_node_and_gpu_ids] + + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) + + self._run_workers("update_environment_variables", + all_args=self._get_env_vars_to_be_updated()) + + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=0,#node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[RayWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[RayWorkerWrapper] = [] + + # Enforce rank order for correct rank to return final output. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. 
+ rank = index + 1 + if rank % self.parallel_config.tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) else: - # Otherwise, the ray workers are allocated with a full GPU. - num_gpus = 1 - - # The driver dummy worker does not actually use any resources. - # It holds the resource for the driver worker. - self.driver_dummy_worker: Optional[RayWorkerWrapper] = None - # The remaining workers are the actual ray actors. - self.workers: List[RayWorkerWrapper] = [] - - # Used in ray compiled DAG: indexed first by PP rank, - # and then TP rank. In other words, the inner list is - # the TP group of workers for a PP rank. - self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] - - if self.parallel_config.ray_workers_use_nsight: - ray_remote_kwargs = self._configure_ray_workers_use_nsight( - ray_remote_kwargs) - - logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) - - # Create the workers. - driver_ip = get_ip() - workers = [] - workers = get_vllm_actors() - # if self.use_ray_spmd_worker: - # workers = vllm_workers - # else: - # for bundle_id, bundle in enumerate(placement_group.bundle_specs): - # if not bundle.get("GPU", 0): - # continue - # scheduling_strategy = PlacementGroupSchedulingStrategy( - # placement_group=placement_group, - # placement_group_capture_child_tasks=True, - # placement_group_bundle_index=bundle_id, - # ) - - # worker = ray.remote( - # num_cpus=0, - # num_gpus=num_gpus, - # scheduling_strategy=scheduling_strategy, - # **ray_remote_kwargs, - # )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) - # workers.append(worker) - - worker_ip_refs = [ - worker.get_node_ip.remote() # type: ignore[attr-defined] - for worker in workers - ] - worker_ips = ray.get(worker_ip_refs) - - if not self.use_ray_spmd_worker: - for i in range(len(workers)): - worker = workers[i] - worker_ip = worker_ips[i] - if self.driver_dummy_worker is None and worker_ip == driver_ip: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - self.driver_worker = RayWorkerWrapper( - vllm_config=self.vllm_config) - workers.pop(i) - worker_ips.pop(i) - self.workers = workers - break - else: - self.workers = workers - - logger.debug("workers: %s", self.workers) - logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) - if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: - raise ValueError( - "Ray does not allocate any GPUs on the driver node. Consider " - "adjusting the Ray placement group or running the driver on a " - "GPU node.") - - ip_counts: Dict[str, int] = {} - for ip in worker_ips: - ip_counts[ip] = ip_counts.get(ip, 0) + 1 - - worker_to_ip = dict(zip(self.workers, worker_ips)) - - def sort_by_driver_then_worker_ip(worker): - """ - Sort the workers based on 3 properties: - 1. If the worker is on the same node as the driver (vllm engine), - it should be placed first. - 2. Then, if the worker is on a node with fewer workers, it should - be placed first. - 3. Finally, if the work is on a node with smaller IP address, it - should be placed first. - """ - ip = worker_to_ip[worker] - return (ip != driver_ip, ip_counts[ip], ip) - - # After sorting, the workers on the same node will be - # close to each other, and the workers on the driver - # node will be placed first. - self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) - - # Get the set of GPU IDs used on each node. 
- worker_node_and_gpu_ids = [] - for worker in [self.driver_dummy_worker] + self.workers: - if worker is None: - # driver_dummy_worker can be None when using ray spmd worker. - continue - worker_node_and_gpu_ids.append( - ray.get(worker.get_node_and_gpu_ids.remote()) \ - ) # type: ignore - - node_workers = defaultdict(list) # node id -> list of worker ranks - node_gpus = defaultdict(list) # node id -> list of gpu ids - - for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): - node_workers[node_id].append(i) - # `gpu_ids` can be a list of strings or integers. - # convert them to integers for consistency. - # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), - # string sorting is not sufficient. - # see https://github.com/vllm-project/vllm/issues/5590 - gpu_ids = [int(x) for x in gpu_ids] - node_gpus[node_id].extend(gpu_ids) - for node_id, gpu_ids in node_gpus.items(): - node_gpus[node_id] = sorted(gpu_ids) - - all_ips = set(worker_ips + [driver_ip]) - n_ips = len(all_ips) - n_nodes = len(node_workers) - - if n_nodes != n_ips: - raise RuntimeError( - f"Every node should have a unique IP address. Got {n_nodes}" - f" nodes with node ids {list(node_workers.keys())} and " - f"{n_ips} unique IP addresses {all_ips}. Please check your" - " network configuration. If you set `VLLM_HOST_IP`" - " environment variable, make sure it is unique for" - " each node.") - - # Set environment variables for the driver and workers. - all_args_to_update_environment_variables = [({ - "CUDA_VISIBLE_DEVICES": - ",".join(map(str, node_gpus[node_id])), - "VLLM_TRACE_FUNCTION": - str(envs.VLLM_TRACE_FUNCTION), - **({ - "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND - } if envs.VLLM_ATTENTION_BACKEND is not None else {}) - }, ) for (node_id, _) in worker_node_and_gpu_ids] - - self._env_vars_for_all_workers = ( - all_args_to_update_environment_variables) - - self._run_workers("update_environment_variables", - all_args=self._get_env_vars_to_be_updated()) - - if len(node_gpus) == 1: - # in single node case, we don't need to get the IP address. - # the loopback address is sufficient - # NOTE: a node may have several IP addresses, one for each - # network interface. `get_ip()` might return any of them, - # while they might not work for communication inside the node - # if the network setup is complicated. Using the loopback address - # solves this issue, as it always works for communication inside - # the node. - driver_ip = "127.0.0.1" - distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) - - # Initialize the actual workers inside worker wrapper. - init_worker_all_kwargs = [ - self._get_worker_kwargs( - local_rank=0,#node_workers[node_id].index(rank), - rank=rank, - distributed_init_method=distributed_init_method, - ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) - ] - self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) - - self._run_workers("init_device") - self._run_workers("load_model", - max_concurrent_workers=self.parallel_config. 
- max_parallel_loading_workers) - - if self.use_ray_spmd_worker: - for pp_rank in range(self.parallel_config.pipeline_parallel_size): - self.pp_tp_workers.append([]) - for tp_rank in range( - self.parallel_config.tensor_parallel_size): - # PP=2, TP=4 - # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] - rank = (pp_rank * self.parallel_config.tensor_parallel_size - ) + tp_rank - assert len(self.pp_tp_workers[pp_rank]) == tp_rank - assert pp_rank < len(self.pp_tp_workers) - self.pp_tp_workers[pp_rank].append(self.workers[rank]) - - # This is the list of workers that are rank 0 of each TP group EXCEPT - # global rank 0. These are the workers that will broadcast to the - # rest of the workers. - self.tp_driver_workers: List[RayWorkerWrapper] = [] - # This is the list of workers that are not drivers and not the first - # worker in a TP group. These are the workers that will be - # broadcasted to. - self.non_driver_workers: List[RayWorkerWrapper] = [] - - # Enforce rank order for correct rank to return final output. - for index, worker in enumerate(self.workers): - # The driver worker is rank 0 and not in self.workers. - rank = index + 1 - if rank % self.parallel_config.tensor_parallel_size == 0: - self.tp_driver_workers.append(worker) - else: - self.non_driver_workers.append(worker) + self.non_driver_workers.append(worker) RayGPUExecutor._init_workers_ray = _init_workers_ray diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index 84596218..568d0ed7 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -24,7 +24,7 @@ from vllm.config import LoadFormat from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import RayWorkerWrapper - + from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion from chatlearn.utils.global_vars import set_vllm_actors from chatlearn.utils.vllm_import_helper import parallel_state @@ -99,8 +99,8 @@ def init_engine_args(self): else: seed = self.model_args.get("seed", 0) - from vllm.engine.arg_utils import AsyncEngineArgs - from vllm.usage.usage_lib import UsageContext + from vllm.engine.arg_utils import AsyncEngineArgs # pylint: disable=import-outside-toplevel + from vllm.usage.usage_lib import UsageContext # pylint: disable=import-outside-toplevel self.engine_args = AsyncEngineArgs( model=self.model_args['tokenizer'], diff --git a/chatlearn/runtime/dist_actor.py b/chatlearn/runtime/dist_actor.py index 8e7e9ee8..6819d879 100644 --- a/chatlearn/runtime/dist_actor.py +++ b/chatlearn/runtime/dist_actor.py @@ -30,6 +30,7 @@ if vllm_exist: from chatlearn.models.vllm_module import VLLMModule from chatlearn.models.vllm_module_v2 import VLLMModuleV2 + from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion RAY_REMOTE = "remote" @@ -234,8 +235,6 @@ def __init__(self, *args, **kwargs): self.vllm_engine = None def create_actor(self, num_gpus, placement_group, group_index): - from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion - if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: kwargs = { "worker_module_name": "vllm.worker.worker", From f62c943febd7c353dd30dbe310228842a5e36c84 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Thu, 23 Jan 2025 17:47:25 +0800 Subject: [PATCH 08/19] fix pylint. 
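
The modules touched by this cleanup are all hook files, and they share one mechanism: define a free function with the same signature as a vLLM method, then assign it onto the vLLM class at import time so every later call goes through the replacement. A minimal, self-contained sketch of that pattern (the Engine class below is a stand-in for illustration, not a vLLM type):

    class Engine:                      # stand-in for a third-party class
        @classmethod
        def from_config(cls, config):
            return cls()

    @classmethod
    def patched_from_config(cls, config):
        # same signature as the original; extra setup would go here
        print("patched constructor path")
        return cls()

    # assigned once at import time, the same way the hook modules above do it
    Engine.from_config = patched_from_config
    Engine.from_config({})             # prints "patched constructor path"

Because the assignment happens when the hooks package is imported, it has to run before any engine object is constructed.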
--- .../vllm/hooks/vllm_0_6_6/async_llm_engine.py | 47 +++++++++---------- .../vllm/hooks/vllm_0_6_6/input_preprocess.py | 3 -- .../vllm/hooks/vllm_0_6_6/ray_gpu_executor.py | 2 +- 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/async_llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/async_llm_engine.py index d8dd7386..513a068c 100644 --- a/chatlearn/models/vllm/hooks/vllm_0_6_6/async_llm_engine.py +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/async_llm_engine.py @@ -24,31 +24,30 @@ from vllm.usage.usage_lib import UsageContext @classmethod -def from_engine_args( - cls, - engine_args: AsyncEngineArgs, - engine_config: Optional[VllmConfig] = None, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, - ) -> "AsyncLLMEngine": - """Creates an async LLM engine from the engine arguments.""" - # Create the engine configs. - if engine_config is None: - engine_config = engine_args.create_engine_config(usage_context) +def from_engine_args(cls, + engine_args: AsyncEngineArgs, + engine_config: Optional[VllmConfig] = None, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. + if engine_config is None: + engine_config = engine_args.create_engine_config(usage_context) - executor_class = cls._get_executor_cls(engine_config) + executor_class = cls._get_executor_cls(engine_config) - # Create the async LLM engine. - engine = cls( - vllm_config=engine_config, - executor_class=executor_class, - log_requests=not engine_args.disable_log_requests, - log_stats=not engine_args.disable_log_stats, - start_engine_loop=start_engine_loop, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - return engine + # Create the async LLM engine. 
+ engine = cls( + vllm_config=engine_config, + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + start_engine_loop=start_engine_loop, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine async_llm_engine.AsyncLLMEngine.from_engine_args = from_engine_args diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/input_preprocess.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/input_preprocess.py index c03a1456..a5d17acb 100644 --- a/chatlearn/models/vllm/hooks/vllm_0_6_6/input_preprocess.py +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/input_preprocess.py @@ -14,9 +14,6 @@ # ============================================================================== """Hooks of vllm-0.6.6 input preprocess to pass prompt text.""" - -import inspect - # pylint: disable=unused-import,unused-argument from vllm.inputs import preprocess from vllm.inputs.data import token_inputs diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py index c27e27c4..19efb2f2 100644 --- a/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py @@ -31,7 +31,7 @@ # modified based on https://github.com/vllm-project/vllm/blob/6aa6020f9bd4c1e414c10f7bd3a7c2555f1950b2/vllm/executor/ray_gpu_executor.py#L109 def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): # pylint: disable=unused-argument + **ray_remote_kwargs): # pylint: disable=unused-argument,unused-variable if (self.parallel_config.tensor_parallel_size == 1 and self.parallel_config.pipeline_parallel_size == 1): # For single GPU case, we use a ray worker with constrained memory. From ddb2181f7bd974901f99a9ce0803a80cac37ea09 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Thu, 23 Jan 2025 20:11:38 +0800 Subject: [PATCH 09/19] support replicated kv in param sync --- chatlearn/synchronizer/parameter_sync.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index 2267d80f..7141f629 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -1220,8 +1220,8 @@ def build_rank_mapping_for_ep(self, add_recv_actor_fn=None): if self._debug and (src_dp_ranks[0] is None or dst_dp_ranks is None): return - assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, \ - f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}" + #assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, \ + # f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}" if self.src_model.colocate_with(self.dst_model) and self.num_src_tensor_parallel % 2 == 1: replica_rank_iter = cycle(reversed(src_dp_ranks)) else: From 74b0e8e40ee0b75022a1884c8267fb36ee719e9e Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Fri, 24 Jan 2025 21:56:07 +0800 Subject: [PATCH 10/19] yes . 
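
The core of this change is the hybrid-expert-parallel (HEP) regrouping: each rank starts with its own routed experts at full FFN width, slices them into hep_size pieces along the width dimension, and exchanges the slices with torch.distributed.all_to_all so that every rank ends up with all experts restricted to its own width slice. A stripped-down sketch of that exchange, leaving out the w1/w3 interleaving and the w2 transpose, and assuming it is launched via torchrun with one GPU per rank and a width divisible by the world size (sizes below are toy values, not the real model's):

    import torch
    import torch.distributed as dist

    # sketch only: launch with `torchrun --nproc_per_node=<hep_size> regroup_demo.py`
    dist.init_process_group(backend="nccl")
    hep_size = dist.get_world_size()
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())

    local_experts, ffn_width, hidden = 2, 8, 4
    # each rank starts with its own routed experts at full FFN width
    param = torch.full((local_experts, ffn_width, hidden), float(rank), device="cuda")

    # slice the FFN dimension into hep_size pieces, one destined for each peer
    send_list = [t.contiguous() for t in param.chunk(hep_size, dim=1)]
    recv_list = [torch.empty_like(t) for t in send_list]
    dist.all_to_all(recv_list, send_list)

    # stacking along the expert dimension leaves every rank with all experts,
    # each reduced to its own 1/hep_size slice of the FFN width
    regrouped = torch.cat(recv_list, dim=0)
    print(rank, tuple(regrouped.shape))  # (local_experts * hep_size, ffn_width // hep_size, hidden)
    dist.destroy_process_group()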
--- chatlearn/models/base_module.py | 25 ++++- chatlearn/models/vllm_module_v2.py | 6 ++ chatlearn/synchronizer/megatron_vllm.py | 38 +++++-- chatlearn/synchronizer/parameter_sync.py | 131 +++++++++++++++++------ chatlearn/utils/arguments.py | 2 +- 5 files changed, 157 insertions(+), 45 deletions(-) diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index 8e571e1c..e910c00f 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -751,12 +751,26 @@ def send_recv_parameter(self, rank, group_name, func, pipe_stage=0): def alltoall_routed_expert_parameter(self, pipe_stage=0): assert self._synchronizer is not None + import torch + + comm_group = self.tensor_and_expert_parallel_group() + rank = torch.distributed.get_rank(group=comm_group) + world_size = torch.distributed.get_world_size(group=comm_group) + # with open(f"/workspace/code/cmd/moelite_scripts/{self.name}_{rank}_{world_size}.txt", "a+") as file: + # file.write(f"debug alltoall rank: {rank} in comm group {id(comm_group)}, num_params: {len(self._parameters_to_sync[pipe_stage])}" + "\n") + # breakpoint() for name, param in self._parameters_to_sync[pipe_stage]: param, state = self._synchronizer.alltoall_routed_experts( name, param, - self.tensor_and_expert_parallel_group() + self.tensor_and_expert_parallel_group(), + self.name, + rank, + world_size ) + + # self._logger.info(f"debug {name} {param.shape} state: {state}") + # state = True if state: self._expert_sync_buffer.pop(name, "Not Found.") self._expert_sync_buffer[name] = param @@ -829,6 +843,7 @@ def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, gr parameters_to_sync = self._parameters_to_recv[to_rank] else: parameters_to_sync = self._parameters_to_send + self._logger.info(f"stage2 need to sync params: {len(parameters_to_sync[0])}") else: del self._sync_buffer self._sync_buffer = defaultdict(list) @@ -883,14 +898,17 @@ def tensor_generator(): yield param_data, buffer_num bucket_generator = bucket_tensors_two_stage_generator( - tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb, + tensor_generator, bucket_size_mb=1024,#self.runtime_args.coalesced_buffer_mb, stage2=stage2, tensor_changed=tensor_changed and not stage2 ) dense_bucket_num = 0 sparse_bucket_num = 0 + count = 0 for bucket_or_tensor, is_dense in bucket_generator: if is_dense: index = 0 if stage2 else (to_rank % self.tp_num_mapping) + if stage2: + self._logger.info(f"stage2 bucket: {len(bucket_or_tensor)} count: {count}") all_buffers = coalesced_comm_dense_two_stage( bucket_or_tensor, col.broadcast, rank, extra_args=(src_rank, group_name), tensor_changed=tensor_changed, @@ -903,6 +921,9 @@ def tensor_generator(): del value self._sync_buffer[key] += cpu_value del all_buffers + count += len(bucket_or_tensor) + if stage2: + self._logger.info(f"finished stage2 bucket_or_tensor: {len(bucket_or_tensor)} count: {count}") dense_bucket_num += 1 else: col.broadcast(bucket_or_tensor, src_rank, group_name) diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index 568d0ed7..fbc54069 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -405,6 +405,12 @@ def pipeline_parallel_rank(self): """ return get_pipeline_model_parallel_rank() + def tensor_model_parallel_size(self): + return self.tensor_and_expert_model_parallel_size() + + def expert_model_parallel_size(self): + return 1 + def tensor_and_expert_model_parallel_size(self): """ get 
tensor_and_expert_model_parallel_size diff --git a/chatlearn/synchronizer/megatron_vllm.py b/chatlearn/synchronizer/megatron_vllm.py index ae334851..991d538b 100644 --- a/chatlearn/synchronizer/megatron_vllm.py +++ b/chatlearn/synchronizer/megatron_vllm.py @@ -231,7 +231,7 @@ def allgather_routed_experts(self, name, params_to_sync, group_name, tp_rank): # "Please export `QWEN_VERSION` as `qwen_moe_v1`." ) - def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group): + def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group, model_name, rank, world_size): """ This function is applicable for synchronizing parameters from QWen with HEP enabled to vLLM. In HEP, routed experts are split into a total number of EP size * TP size. @@ -240,6 +240,9 @@ def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group): if self.sync_map._to_alltoall_routed_experts_dict is None: return params_to_sync, False + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"1 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + to_alltoall_routed_experts_dict = self.sync_map._to_alltoall_routed_experts_dict layer_re = to_alltoall_routed_experts_dict["layer_re"] to_regroup_modules_list = to_alltoall_routed_experts_dict["modules"] @@ -248,6 +251,9 @@ def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group): if m is None: return params_to_sync, False + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"2 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + op_name = m.group(2) if op_name in to_regroup_modules_list: tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"] @@ -261,7 +267,7 @@ def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group): # regroup among difference tp slices param = params_to_sync.view((moe_num_experts, -1, hidden_size)) param = param.reshape((local_num_experts * 2, -1, hidden_size)) - params = list(param.chunk(tp_size, dim=1)) + params = list(param.chunk(hep_size, dim=1)) # reorder w1 and w3 params_list = [] while params: @@ -276,36 +282,54 @@ def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group): del params_to_sync output = [ torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device) - for i in range(tp_size) + for i in range(hep_size) ] + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"3.1 debug alltoall rank: {rank}/{world_size} vs {torch.distributed.get_rank(group=comm_group)}/{torch.distributed.get_world_size(group=comm_group)} in comm group {id(comm_group)} {name}" + "\n") + torch.distributed.all_to_all(output, params_list, group=comm_group) + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"3.2 debug alltoall rank: {rank}/{world_size} vs {torch.distributed.get_rank(group=comm_group)}/{torch.distributed.get_world_size(group=comm_group)} in comm group {id(comm_group)} {name}" + "\n") + del params_list params_to_sync = torch.cat(output, dim=0).contiguous() del output else: # w2_weight param = params_to_sync.view((local_num_experts, -1, hidden_size)) - params = list(param.chunk(tp_size, dim=1)) + params = list(param.chunk(hep_size, dim=1)) params_list = 
[ele.contiguous() for ele in params] del param del params del params_to_sync output = [ torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device) - for i in range(tp_size) + for i in range(hep_size) ] + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"4.1 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + torch.distributed.all_to_all(output, params_list, group=comm_group) + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"4.2 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + del params_list params_to_sync = torch.cat(output, dim=0).transpose(1, 2).contiguous() del output + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"f.1 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + return params_to_sync, True else: + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"f.2 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + return params_to_sync, False - def alltoall_routed_experts(self, name, params_to_sync, comm_group): # pylint: disable=unused-argument + def alltoall_routed_experts(self, name, params_to_sync, comm_group, model_name="default", rank=0, world_size=1): # pylint: disable=unused-argument megatron_version = get_megatron_version() if megatron_version == MegatronVersion.V4: - return self.alltoall_routed_experts_from_hep(name, params_to_sync, comm_group) + return self.alltoall_routed_experts_from_hep(name, params_to_sync, comm_group, model_name, rank, world_size) else: raise NotImplementedError( "ChatLearn does not support all-to-all routed experts for Megatron-LM, but supports QWen with HEP enabled. 
" diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index 7141f629..17d084ac 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -14,6 +14,7 @@ # ============================================================================== """Sync parameters""" +import os import concurrent.futures import traceback from collections import defaultdict @@ -85,9 +86,14 @@ def __init__(self, src_model, dst_model, group_name, frequency, error_signal): [ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER, ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL], \ f"Only support 'allgather' or 'alltoall' for routed expert regrouping, while {self._comm_type_to_regroup_routed_experts}" if self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL: - if self.num_dst_tensor_parallel != self.num_src_tensor_parallel: - logger.warning("Only support ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL when src tp eqs dst tp, use 'allgather' instead.") + print(f"debug self.num_dst_tensor_parallel: {self.num_dst_tensor_parallel}") + print(f"debug self.num_dst_expert_parallel: {self.num_dst_expert_parallel}") + print(f"debug self.num_src_tensor_parallel: {self.num_src_tensor_parallel}") + print(f"debug self.num_src_expert_parallel: {self.num_src_expert_parallel}") + if self.num_dst_tensor_parallel * self.num_dst_expert_parallel != self.num_src_tensor_parallel * self.num_src_expert_parallel: + logger.info("Only support ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL when src tp eqs dst tp, use 'allgather' instead.") self._comm_type_to_regroup_routed_experts = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER + logger.info(f"Set ROUTED_EXPERT_REGROUPING_COMM_TYPE = {self._comm_type_to_regroup_routed_experts}.") self.sorted_send_actors = None self.sorted_send_actors_stage2 = None self.actor2synchronizer = {} @@ -303,8 +309,8 @@ def build_rank_mapping(self, add_recv_actor_fn=None): if self._debug and (src_dp_ranks[0] is None or dst_dp_ranks is None): return - assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, \ - f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}" + #assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, \ + # f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}" if self.src_model.colocate_with(self.dst_model) and self.num_src_tensor_parallel % 2 == 1: replica_rank_iter = cycle(reversed(src_dp_ranks)) else: @@ -594,7 +600,9 @@ def sync_broadcast_two_stage(self, actors, group_name, requires_grad=None, stage send_actor = actors[0] for rank, recv_actor in enumerate(actors[1:]): if stage2: - self.set_sync_param_names_stage2(send_actor, recv_actor, self.actor2rank[recv_actor], requires_grad, filter_fn, param_group) + s_names, d_names = self.set_sync_param_names_stage2(send_actor, recv_actor, self.actor2rank[recv_actor], requires_grad, filter_fn, param_group) + # for s_name, d_name in zip(s_names, d_names): + logger.info(f"stage 2 sync: {len(s_names)} -> {len(d_names)}") else: self.set_sync_param_names(send_actor, recv_actor, requires_grad, filter_fn, param_group) pipe_stage = self.get_actor_pipe_rank(send_actor) @@ -661,6 +669,7 @@ def sync_alltoall(self, actors, requires_grad=None, filter_fn=None): self.set_sync_param_names(actor, actor, requires_grad, filter_fn, param_group="routed", should_map_name=False) pipe_stage = self.get_actor_pipe_rank(actors[0]) refs = [] + logger.info(f"debug 
alltoall among {[self.actor2rank[actor] for actor in actors]}") for actor in actors: ref = actor.alltoall_routed_expert_parameter.remote(pipe_stage) refs.append(ref) @@ -852,10 +861,13 @@ def sort_send_actors(self, send_recv_actor_mappings, sorted_send_actors): def sync_broadcast_multi_threads( self, sorted_send_actors, send_recv_actor_mappings, max_workers=1, requires_grad=None, group_name=None, stage2=False, filter_fn=None, param_group="default"): - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] + log_debug = bool(int(os.getenv("log_debug", "0"))) + + log_debug = False + if log_debug: for send_actor in sorted_send_actors: recv_actors = send_recv_actor_mappings[send_actor] + logger.info(f"sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]}") if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: if stage2: for idx, recv_actor in enumerate(recv_actors): @@ -863,31 +875,64 @@ def sync_broadcast_multi_threads( actor_groups, finalized_group_name = self.create_broadcast_group( send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group ) - futures.append(executor.submit( - self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group - )) + logger.info(f"stage2 sending from {self.actor2rank[send_actor]} to {[self.actor2rank[recv_actor]]}") + + self.sync_broadcast_two_stage(actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group) + logger.info(f"stage2 sending from {self.actor2rank[send_actor]} to {[self.actor2rank[recv_actor]]} finished.") + else: actor_groups, finalized_group_name = self.create_broadcast_group( send_actor, recv_actors, group_name=group_name, param_group=param_group ) - futures.append(executor.submit( - self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group - )) - else: - raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") - for _future in concurrent.futures.as_completed(futures): - try: - _future.result() - except Exception as e: - traceback.print_exc() - raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from - concurrent.futures.wait(futures) + # logger.info(f"stage 1 sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]}") + self.sync_broadcast_two_stage(actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group) + # logger.info(f"stage 1 sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]} finished.") + + else: + if stage2: + message = "second_stage" + else: + message = "first_stage" + max_workers = len(sorted_send_actors) + logger.info(f"Use {max_workers} workers for {message} broadcasting.") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for send_actor in sorted_send_actors: + recv_actors = send_recv_actor_mappings[send_actor] + logger.info(f"sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]}") + if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: + if stage2: + for idx, recv_actor in enumerate(recv_actors): + group_name_with_idx = f"{group_name}_{idx}" + actor_groups, finalized_group_name = self.create_broadcast_group( + send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group + ) + futures.append(executor.submit( + 
self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group + )) + else: + actor_groups, finalized_group_name = self.create_broadcast_group( + send_actor, recv_actors, group_name=group_name, param_group=param_group + ) + futures.append(executor.submit( + self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group + )) + else: + raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + traceback.print_exc() + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) def sync_allgather_multi_threads( self, send_actors, max_workers=1, requires_grad=None, group_name=None, filter_fn=None ): send_actors_to_allgather_routed_experts = send_actors[0] + logger.info(f"Use {max_workers} workers for allgather multiprocessing.") with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] for allgather_actors in send_actors_to_allgather_routed_experts: @@ -906,18 +951,26 @@ def sync_alltoall_multi_threads( self, send_actors, max_workers=1, requires_grad=None, filter_fn=None ): send_actors_to_alltoall_routed_experts = send_actors[0] - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] + if False: for actor_groups in send_actors_to_alltoall_routed_experts: - futures.append(executor.submit( - self.sync_alltoall, actor_groups, requires_grad, filter_fn=filter_fn - )) - for _future in concurrent.futures.as_completed(futures): - try: - _future.result() - except Exception as e: - raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from - concurrent.futures.wait(futures) + logger.info(f"1 debug alltoall among {[self.actor2rank[actor] for actor in actor_groups]}") + self.sync_alltoall(actor_groups, requires_grad, filter_fn=filter_fn) + logger.info(f"2 debug alltoall among {[self.actor2rank[actor] for actor in actor_groups]}") + else: + logger.info(f"Use {max_workers} workers for alltoall multiprocessing.") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for actor_groups in send_actors_to_alltoall_routed_experts: + logger.info(f"1 debug alltoall among {[self.actor2rank[actor] for actor in actor_groups]}") + futures.append(executor.submit( + self.sync_alltoall, actor_groups, requires_grad, filter_fn=filter_fn + )) + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) def check_and_setup_collective_group(self): if not self._is_collective_group_created: @@ -1033,6 +1086,7 @@ def _multi_thread_sync_for_tp_num_mapping_gt_1( sorted_send_actors_stage1 = list(actor_mappings_stage1.keys()) max_workers = self._calculate_max_workers(sorted_send_actors_stage1, actor_mappings_stage1) group_name = self.group_name + "_inter_comm" + logger.info(f"1 debug param sync:start to sync {group_name}") self.sync_broadcast_multi_threads( sorted_send_actors_stage1, actor_mappings_stage1, max_workers, requires_grad, group_name=group_name, stage2=False, filter_fn=filter_fn, param_group=param_group @@ -1041,9 +1095,12 @@ def _multi_thread_sync_for_tp_num_mapping_gt_1( 
sorted_send_actors_stage2 = list(actor_mappings_stage2.keys()) max_workers = self._calculate_max_workers(sorted_send_actors_stage2, actor_mappings_stage2) group_name = self.group_name + "_intra_comm" + logger.info(f"2 debug param sync:start to sync {group_name} param_group: {param_group}") + os.environ["log_debug"] = "1" self.sync_broadcast_multi_threads( sorted_send_actors_stage2, actor_mappings_stage2, max_workers, requires_grad, group_name=group_name, stage2=True, filter_fn=filter_fn, param_group=param_group) + os.environ["log_debug"] = "0" def _multi_thread_sync_for_tp_num_mapping_eq_1( self, send_actors_list:List, actor_mappings_list:List, @@ -1220,8 +1277,8 @@ def build_rank_mapping_for_ep(self, add_recv_actor_fn=None): if self._debug and (src_dp_ranks[0] is None or dst_dp_ranks is None): return - #assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, \ - # f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}" + assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, \ + f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}" if self.src_model.colocate_with(self.dst_model) and self.num_src_tensor_parallel % 2 == 1: replica_rank_iter = cycle(reversed(src_dp_ranks)) else: @@ -1408,6 +1465,7 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): requires_grad=requires_grad, filter_fn=self.routed_experts_filter) + logger.info(f"debug param sync: complete to alltoall router experts.") # sync everything to inference model if self.tp_num_mapping == 1: send_actors_list = [self.sorted_send_actors] @@ -1420,6 +1478,8 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): param_group="default" ) elif self.tp_num_mapping > 1: + logger.info(f"debug param sync: start to sync other weights.") + send_actors_list = [self.sorted_send_actors, self.sorted_send_actors_stage2] actor_mappings_list = [self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2] self._multi_thread_sync_for_tp_num_mapping_gt_1( @@ -1429,6 +1489,7 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): filter_fn=None, param_group="default" ) + logger.info(f"debug param sync: complete to sync other weights.") else: raise NotImplementedError( f"ChatLearn does not support synchronizing from larger tp size ({self.num_src_tensor_parallel})" diff --git a/chatlearn/utils/arguments.py b/chatlearn/utils/arguments.py index c4630cf8..a8f3257a 100644 --- a/chatlearn/utils/arguments.py +++ b/chatlearn/utils/arguments.py @@ -303,7 +303,7 @@ class RuntimeConfig(BaseConfig): #: profiler dir profiler_dir: str = None #: coalesce_buffer size in mb - coalesced_buffer_mb: int = 100 + coalesced_buffer_mb: int = 1024 #: concurrent parameter sync concurrent_comm: bool = True #: parameter sync communication type, broadcast/p2p From 88d86c863dc98fd605052870b88529cb28d573f9 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Sat, 25 Jan 2025 18:17:35 +0800 Subject: [PATCH 11/19] param sync with multi-threads. 
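
The second-stage sync now drives one broadcast per (send, recv) pair from its own worker thread and surfaces any failure through the returned future. A stripped-down sketch of that threading pattern, where sync_pair stands in for sync_broadcast_two_stage and the rank pairs are chosen only for illustration:

    import concurrent.futures
    from concurrent.futures import ThreadPoolExecutor

    def sync_pair(send_rank, recv_rank):
        # stand-in for one (send, recv) broadcast
        return f"{send_rank}->{recv_rank}"

    pairs = [(0, 4), (1, 5), (2, 6), (3, 7)]        # illustrative rank pairs
    with ThreadPoolExecutor(max_workers=len(pairs)) as executor:
        futures = [executor.submit(sync_pair, s, r) for s, r in pairs]
        for future in concurrent.futures.as_completed(futures):
            # result() re-raises any exception from the worker thread,
            # so a failed broadcast is not silently dropped
            print(future.result())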
--- chatlearn/models/base_module.py | 17 ++- chatlearn/synchronizer/parameter_sync.py | 140 ++++++++++++++++++----- 2 files changed, 120 insertions(+), 37 deletions(-) diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index e910c00f..3a085146 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -843,7 +843,7 @@ def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, gr parameters_to_sync = self._parameters_to_recv[to_rank] else: parameters_to_sync = self._parameters_to_send - self._logger.info(f"stage2 need to sync params: {len(parameters_to_sync[0])}") + # self._logger.info(f"stage2 need to sync params: {len(parameters_to_sync[0])}") else: del self._sync_buffer self._sync_buffer = defaultdict(list) @@ -898,7 +898,7 @@ def tensor_generator(): yield param_data, buffer_num bucket_generator = bucket_tensors_two_stage_generator( - tensor_generator, bucket_size_mb=1024,#self.runtime_args.coalesced_buffer_mb, + tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb, stage2=stage2, tensor_changed=tensor_changed and not stage2 ) dense_bucket_num = 0 @@ -907,8 +907,8 @@ def tensor_generator(): for bucket_or_tensor, is_dense in bucket_generator: if is_dense: index = 0 if stage2 else (to_rank % self.tp_num_mapping) - if stage2: - self._logger.info(f"stage2 bucket: {len(bucket_or_tensor)} count: {count}") + # if stage2: + # self._logger.info(f"stage2 bucket: {len(bucket_or_tensor)} count: {count}") all_buffers = coalesced_comm_dense_two_stage( bucket_or_tensor, col.broadcast, rank, extra_args=(src_rank, group_name), tensor_changed=tensor_changed, @@ -922,13 +922,18 @@ def tensor_generator(): self._sync_buffer[key] += cpu_value del all_buffers count += len(bucket_or_tensor) - if stage2: - self._logger.info(f"finished stage2 bucket_or_tensor: {len(bucket_or_tensor)} count: {count}") + # if stage2: + # self._logger.info(f"finished stage2 bucket_or_tensor: {len(bucket_or_tensor)} count: {count}") dense_bucket_num += 1 else: col.broadcast(bucket_or_tensor, src_rank, group_name) sparse_bucket_num += 1 + if stage2: + self._logger.info(f"debug finished stage2 comm") + else: + self._logger.info(f"debug finished stage1 comm") + debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, sparse_bucket {sparse_bucket_num}", self._logger) self.empty_cache() diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index 17d084ac..c9eb989b 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -15,11 +15,12 @@ """Sync parameters""" import os +import time import concurrent.futures import traceback from collections import defaultdict from concurrent.futures import ThreadPoolExecutor -from itertools import cycle +from itertools import cycle, permutations, combinations from typing import List, Dict import torch @@ -858,6 +859,48 @@ def sort_send_actors(self, send_recv_actor_mappings, sorted_send_actors): assert len(send_recv_actor_mappings) == len(sorted_send_actors) return sorted_send_actors + def sync_broadcast_second_stage(self, group_name, thread_group, send_recv_actor_mappings, requires_grad=None, filter_fn=None, param_group="default"): + actor_groups_to_sync = [] + combs = combinations(thread_group, 2) + for comb in combs: + perms = permutations(comb) + for perm in perms: + actor_groups_to_sync.append(perm) + actor_groups_to_sync_list = [] + for actor_group in actor_groups_to_sync: + if not actor_groups_to_sync_list: + 
actor_groups_to_sync_list.append([actor_group]) + for actor_groups in actor_groups_to_sync_list: + # print(f"debug actor_groups: {actor_groups}") + for group in actor_groups: + if group[0] in actor_group or group[1] in actor_group: + continue + actor_groups.append(actor_group) + for actor_group_list in actor_groups_to_sync_list: + log_rank = [] + for actor_group in actor_group_list: + send_actor, recv_actor = actor_group + log_rank.append((self.actor2rank[send_actor], self.actor2rank[recv_actor])) + logger.info(f"debug actor_group_list: {len(actor_group_list)} {log_rank}") + + with ThreadPoolExecutor(max_workers=len(actor_group_list)) as executor: + futures = [] + for idx, actor_group in enumerate(actor_group_list): + send_actor, recv_actor = actor_group + group_name_with_idx = f"{group_name}_{idx}" + actor_groups, finalized_group_name = self.create_broadcast_group( + send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group + ) + futures.append(executor.submit( + self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, True, filter_fn, param_group)) + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + traceback.print_exc() + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) + def sync_broadcast_multi_threads( self, sorted_send_actors, send_recv_actor_mappings, max_workers=1, requires_grad=None, group_name=None, stage2=False, filter_fn=None, param_group="default"): @@ -887,45 +930,64 @@ def sync_broadcast_multi_threads( # logger.info(f"stage 1 sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]}") self.sync_broadcast_two_stage(actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group) # logger.info(f"stage 1 sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]} finished.") - else: if stage2: - message = "second_stage" - else: - message = "first_stage" - max_workers = len(sorted_send_actors) - logger.info(f"Use {max_workers} workers for {message} broadcasting.") - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] + thread_group_dict = {} + thread_group_list = [] for send_actor in sorted_send_actors: + if send_actor in thread_group_dict: + continue recv_actors = send_recv_actor_mappings[send_actor] - logger.info(f"sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]}") - if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: - if stage2: - for idx, recv_actor in enumerate(recv_actors): - group_name_with_idx = f"{group_name}_{idx}" - actor_groups, finalized_group_name = self.create_broadcast_group( - send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group - ) - futures.append(executor.submit( - self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group - )) + thread_group_list.append([send_actor] + recv_actors) + for actor in thread_group_list[-1]: + thread_group_dict[actor] = True + del thread_group_dict + max_workers = len(thread_group_list) + logger.info(f"Use {max_workers} workers for second_stage broadcasting.") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for group_idx, thread_group in enumerate(thread_group_list): + if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: + 
self.sync_broadcast_second_stage( + f"{group_name}_{group_idx}", + thread_group, + send_recv_actor_mappings, + requires_grad, + filter_fn, + param_group + ) else: + raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + traceback.print_exc() + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) + else: + logger.info(f"Use {max_workers} workers for first_stage broadcasting.") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for send_actor in sorted_send_actors: + recv_actors = send_recv_actor_mappings[send_actor] + logger.info(f"sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]}") + if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: actor_groups, finalized_group_name = self.create_broadcast_group( send_actor, recv_actors, group_name=group_name, param_group=param_group ) futures.append(executor.submit( self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group )) - else: - raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") - for _future in concurrent.futures.as_completed(futures): - try: - _future.result() - except Exception as e: - traceback.print_exc() - raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from - concurrent.futures.wait(futures) + else: + raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + traceback.print_exc() + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) def sync_allgather_multi_threads( self, send_actors, max_workers=1, requires_grad=None, @@ -957,6 +1019,7 @@ def sync_alltoall_multi_threads( self.sync_alltoall(actor_groups, requires_grad, filter_fn=filter_fn) logger.info(f"2 debug alltoall among {[self.actor2rank[actor] for actor in actor_groups]}") else: + max_workers = len(send_actors_to_alltoall_routed_experts) logger.info(f"Use {max_workers} workers for alltoall multiprocessing.") with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] @@ -1050,7 +1113,7 @@ def validate_sync_results_parallel(self, actor_mappings_list:List, requires_grad def _calculate_max_workers(self, sorted_send_actors, actor_mappings=None): max_workers = get_args().runtime_args.param_sync_max_workers if max_workers is None: - max_workers = max(self.src_model.total_gpu // 8, 1) + max_workers = max(self.src_model.total_gpu // self.num_src_pipeline_stage, 1) if max_workers == -1: if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: max_workers = len(sorted_send_actors) @@ -1083,6 +1146,7 @@ def _multi_thread_sync_for_tp_num_mapping_gt_1( actor_mappings_stage2 = actor_mappings[1] # stage 1 + s_time = time.time() sorted_send_actors_stage1 = list(actor_mappings_stage1.keys()) max_workers = self._calculate_max_workers(sorted_send_actors_stage1, actor_mappings_stage1) group_name = self.group_name + "_inter_comm" @@ -1091,7 +1155,10 @@ def _multi_thread_sync_for_tp_num_mapping_gt_1( sorted_send_actors_stage1, actor_mappings_stage1, max_workers, requires_grad, 
group_name=group_name, stage2=False, filter_fn=filter_fn, param_group=param_group ) + e_time = time.time() + logger.info(f"debug param sync cost stage1: {e_time - s_time}") # stage 2 + s_time = time.time() sorted_send_actors_stage2 = list(actor_mappings_stage2.keys()) max_workers = self._calculate_max_workers(sorted_send_actors_stage2, actor_mappings_stage2) group_name = self.group_name + "_intra_comm" @@ -1101,6 +1168,8 @@ def _multi_thread_sync_for_tp_num_mapping_gt_1( sorted_send_actors_stage2, actor_mappings_stage2, max_workers, requires_grad, group_name=group_name, stage2=True, filter_fn=filter_fn, param_group=param_group) os.environ["log_debug"] = "0" + e_time = time.time() + logger.info(f"debug param sync cost stage2: {e_time - s_time}") def _multi_thread_sync_for_tp_num_mapping_eq_1( self, send_actors_list:List, actor_mappings_list:List, @@ -1448,6 +1517,7 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): if self.concurrent_comm: assert self.dst_model.use_vllm_backend + s_time = time.time() max_workers = self._calculate_max_workers(self.send_actors_to_regroup_routed_experts) if self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER: # allgather routed experts only @@ -1464,10 +1534,13 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): max_workers=max_workers, requires_grad=requires_grad, filter_fn=self.routed_experts_filter) + e_time = time.time() - logger.info(f"debug param sync: complete to alltoall router experts.") + logger.info(f"debug param sync: complete to alltoall router experts. ") + logger.info(f"debug param sync cost alltoall {e_time-s_time}") # sync everything to inference model if self.tp_num_mapping == 1: + s_time = time.time() send_actors_list = [self.sorted_send_actors] actor_mappings_list = [self.send_recv_actor_mappings] self._multi_thread_sync_for_tp_num_mapping_eq_1( @@ -1477,7 +1550,11 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): filter_fn=None, param_group="default" ) + e_time = time.time() + logger.info(f"debug param sync: complete to alltoall router experts") + elif self.tp_num_mapping > 1: + s_time = time.time() logger.info(f"debug param sync: start to sync other weights.") send_actors_list = [self.sorted_send_actors, self.sorted_send_actors_stage2] @@ -1489,6 +1566,7 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): filter_fn=None, param_group="default" ) + e_time = time.time() logger.info(f"debug param sync: complete to sync other weights.") else: raise NotImplementedError( From ec77aa22771cf49000fa1bcc52be1e545e13f7c9 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Sat, 25 Jan 2025 22:33:45 +0800 Subject: [PATCH 12/19] fix multi-thread. 
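
This rewrites the stage-2 broadcast path: every (send_actor, recv_actor) pair
is collected into a flat list, the pairs are greedily packed into rounds in
which no sender and no receiver repeats, and each round is then broadcast with
one thread per pair while the rounds themselves run one after another. A
minimal standalone sketch of the greedy grouping, with actor handles replaced
by plain integers (group_pairs and the example ranks are illustrative, not
ChatLearn APIs):

    from typing import List, Tuple

    def group_pairs(pairs: List[Tuple[int, int]]) -> List[List[Tuple[int, int]]]:
        """Greedily pack (send, recv) pairs into rounds so that no sender and
        no receiver repeats within a round; conflicting pairs start a new round."""
        rounds: List[List[Tuple[int, int]]] = []
        for send, recv in pairs:
            placed = False
            for rnd in rounds:
                if all(send != s and recv != r for s, r in rnd):
                    rnd.append((send, recv))
                    placed = True
                    break
            if not placed:
                rounds.append([(send, recv)])
        return rounds

    # Two senders fanning out to two receivers each:
    # group_pairs([(0, 4), (0, 5), (1, 6), (1, 7)])
    #   -> [[(0, 4), (1, 6)], [(0, 5), (1, 7)]]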
--- chatlearn/synchronizer/parameter_sync.py | 128 ++++++++++------------- 1 file changed, 56 insertions(+), 72 deletions(-) diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index c9eb989b..453b1946 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -603,7 +603,7 @@ def sync_broadcast_two_stage(self, actors, group_name, requires_grad=None, stage if stage2: s_names, d_names = self.set_sync_param_names_stage2(send_actor, recv_actor, self.actor2rank[recv_actor], requires_grad, filter_fn, param_group) # for s_name, d_name in zip(s_names, d_names): - logger.info(f"stage 2 sync: {len(s_names)} -> {len(d_names)}") + # logger.info(f"stage 2 sync: {len(s_names)} -> {len(d_names)}") else: self.set_sync_param_names(send_actor, recv_actor, requires_grad, filter_fn, param_group) pipe_stage = self.get_actor_pipe_rank(send_actor) @@ -859,47 +859,27 @@ def sort_send_actors(self, send_recv_actor_mappings, sorted_send_actors): assert len(send_recv_actor_mappings) == len(sorted_send_actors) return sorted_send_actors - def sync_broadcast_second_stage(self, group_name, thread_group, send_recv_actor_mappings, requires_grad=None, filter_fn=None, param_group="default"): - actor_groups_to_sync = [] - combs = combinations(thread_group, 2) - for comb in combs: - perms = permutations(comb) - for perm in perms: - actor_groups_to_sync.append(perm) - actor_groups_to_sync_list = [] - for actor_group in actor_groups_to_sync: - if not actor_groups_to_sync_list: - actor_groups_to_sync_list.append([actor_group]) - for actor_groups in actor_groups_to_sync_list: - # print(f"debug actor_groups: {actor_groups}") - for group in actor_groups: - if group[0] in actor_group or group[1] in actor_group: - continue - actor_groups.append(actor_group) - for actor_group_list in actor_groups_to_sync_list: - log_rank = [] - for actor_group in actor_group_list: + def sync_broadcast_second_stage(self, group_name, thread_group, requires_grad=None, filter_fn=None, param_group="default"): + max_workers = len(thread_group) + logger.info(f"debug thread_group {group_name}: {[(self.actor2rank[ele[0]], self.actor2rank[ele[1]]) for ele in thread_group]}") + logger.info(f"Use {max_workers} workers for second_stage broadcasting.") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for idx, actor_group in enumerate(thread_group): send_actor, recv_actor = actor_group - log_rank.append((self.actor2rank[send_actor], self.actor2rank[recv_actor])) - logger.info(f"debug actor_group_list: {len(actor_group_list)} {log_rank}") - - with ThreadPoolExecutor(max_workers=len(actor_group_list)) as executor: - futures = [] - for idx, actor_group in enumerate(actor_group_list): - send_actor, recv_actor = actor_group - group_name_with_idx = f"{group_name}_{idx}" - actor_groups, finalized_group_name = self.create_broadcast_group( - send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group - ) - futures.append(executor.submit( - self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, True, filter_fn, param_group)) - for _future in concurrent.futures.as_completed(futures): - try: - _future.result() - except Exception as e: - traceback.print_exc() - raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from - concurrent.futures.wait(futures) + group_name_with_idx = f"{group_name}_{idx}" + actor_groups, finalized_group_name = 
self.create_broadcast_group( + send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group + ) + futures.append(executor.submit( + self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, True, filter_fn, param_group)) + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + traceback.print_exc() + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) def sync_broadcast_multi_threads( self, sorted_send_actors, send_recv_actor_mappings, max_workers=1, requires_grad=None, @@ -932,40 +912,44 @@ def sync_broadcast_multi_threads( # logger.info(f"stage 1 sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]} finished.") else: if stage2: - thread_group_dict = {} - thread_group_list = [] + thread_group = [] for send_actor in sorted_send_actors: - if send_actor in thread_group_dict: - continue recv_actors = send_recv_actor_mappings[send_actor] - thread_group_list.append([send_actor] + recv_actors) - for actor in thread_group_list[-1]: - thread_group_dict[actor] = True - del thread_group_dict - max_workers = len(thread_group_list) - logger.info(f"Use {max_workers} workers for second_stage broadcasting.") - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for group_idx, thread_group in enumerate(thread_group_list): - if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: - self.sync_broadcast_second_stage( - f"{group_name}_{group_idx}", - thread_group, - send_recv_actor_mappings, - requires_grad, - filter_fn, - param_group - ) - else: - raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") - for _future in concurrent.futures.as_completed(futures): - try: - _future.result() - except Exception as e: - traceback.print_exc() - raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from - concurrent.futures.wait(futures) + for recv_actor in recv_actors: + thread_group.append((send_actor, recv_actor)) + actor_groups_to_sync = [] + for group in thread_group: + new_actor_group_flag = True + for idx, actor_groups in enumerate(actor_groups_to_sync): + in_actor_group = False + for jdx, actor_group in enumerate(actor_groups): + if group[0] == actor_group[0] or group[1] == actor_group[1]: + in_actor_group = True + if not in_actor_group: + new_actor_group_flag = False + actor_groups_to_sync[idx].append(group) + break + if new_actor_group_flag or not actor_groups_to_sync: + actor_groups_to_sync.append([group]) + log_list = [] + for thread_group in actor_groups_to_sync: + log_list.append([(self.actor2rank[ele[0]], self.actor2rank[ele[1]]) for ele in thread_group]) + + logger.info(f"debug actor_groups_to_sync {group_name}: {log_list}") + + for group_idx, actor_groups in enumerate(actor_groups_to_sync): + if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: + self.sync_broadcast_second_stage( + f"{group_name}_{group_idx}", + actor_groups, + requires_grad, + filter_fn, + param_group + ) + else: + raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") else: + max_workers = len(sorted_send_actors) logger.info(f"Use {max_workers} workers for first_stage broadcasting.") with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] From cce344b07a954c4714975a40023f7d7b376f7a82 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: 
Sun, 26 Jan 2025 14:00:17 +0800 Subject: [PATCH 13/19] fix --- chatlearn/synchronizer/parameter_sync.py | 33 ++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index 453b1946..5309f632 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -859,10 +859,11 @@ def sort_send_actors(self, send_recv_actor_mappings, sorted_send_actors): assert len(send_recv_actor_mappings) == len(sorted_send_actors) return sorted_send_actors - def sync_broadcast_second_stage(self, group_name, thread_group, requires_grad=None, filter_fn=None, param_group="default"): + def sync_broadcast_second_stage_internal(self, group_name, thread_group, requires_grad=None, filter_fn=None, param_group="default"): max_workers = len(thread_group) + max_workers = 8 logger.info(f"debug thread_group {group_name}: {[(self.actor2rank[ele[0]], self.actor2rank[ele[1]]) for ele in thread_group]}") - logger.info(f"Use {max_workers} workers for second_stage broadcasting.") + logger.info(f"Use {max_workers} workers for second_stage_internal broadcasting.") with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] for idx, actor_group in enumerate(thread_group): @@ -881,6 +882,34 @@ def sync_broadcast_second_stage(self, group_name, thread_group, requires_grad=No raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from concurrent.futures.wait(futures) + + def sync_broadcast_second_stage(self, group_name, thread_groups, requires_grad=None, filter_fn=None, param_group="default"): + + tp_size = self.num_dst_tensor_parallel + num_thread_groups = len(thread_groups) // tp_size + new_thread_groups = [thread_groups[tp_size*i:tp_size*(i+1)] for i in range(num_thread_groups)] + + max_workers = len(new_thread_groups) + + logger.info(f"Use {max_workers} workers for second_stage broadcasting.") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for idx, thread_group in enumerate(new_thread_groups): + # send_actor, recv_actor = actor_group + group_name_with_idx = f"{group_name}_{idx}" + # actor_groups, finalized_group_name = self.create_broadcast_group( + # send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group + # ) + futures.append(executor.submit( + self.sync_broadcast_two_stage_internal, group_name_with_idx, thread_group, requires_grad, filter_fn, param_group)) + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + traceback.print_exc() + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) + def sync_broadcast_multi_threads( self, sorted_send_actors, send_recv_actor_mappings, max_workers=1, requires_grad=None, group_name=None, stage2=False, filter_fn=None, param_group="default"): From 46c081d41df92afabe63d9a0a652de591462d8d8 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Sun, 26 Jan 2025 14:05:51 +0800 Subject: [PATCH 14/19] fix func name. 
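
The previous commit split second-stage broadcasting into an outer pass over
dst-tp-sized chunks of (send, recv) pairs and an inner per-pair pass, but the
outer pass submitted the wrong method name (sync_broadcast_two_stage_internal);
this commit points it at sync_broadcast_second_stage_internal. A rough
standalone sketch of the resulting two-level fan-out (broadcast_pair,
second_stage_internal, second_stage and the handling of leftover pairs are
simplified placeholders, not the ChatLearn implementation):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def broadcast_pair(pair):
        # stand-in for broadcasting parameters over a single (send, recv) pair
        send, recv = pair
        print(f"broadcast {send} -> {recv}")

    def second_stage_internal(pairs):
        # inner level: one thread per pair, capped at 8 workers
        with ThreadPoolExecutor(max_workers=8) as executor:
            futures = [executor.submit(broadcast_pair, p) for p in pairs]
            for f in as_completed(futures):
                f.result()  # re-raise any worker exception

    def second_stage(pairs, tp_size):
        # outer level: one thread per tp_size-sized chunk of pairs
        chunks = [pairs[i:i + tp_size] for i in range(0, len(pairs), tp_size)] or [pairs]
        with ThreadPoolExecutor(max_workers=max(1, len(chunks))) as executor:
            futures = [executor.submit(second_stage_internal, c) for c in chunks]
            for f in as_completed(futures):
                f.result()

    # second_stage([(0, 4), (0, 5), (1, 6), (1, 7)], tp_size=2)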
--- chatlearn/synchronizer/parameter_sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index 5309f632..e682ba61 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -901,7 +901,7 @@ def sync_broadcast_second_stage(self, group_name, thread_groups, requires_grad=N # send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group # ) futures.append(executor.submit( - self.sync_broadcast_two_stage_internal, group_name_with_idx, thread_group, requires_grad, filter_fn, param_group)) + self.sync_broadcast_second_stage_internal, group_name_with_idx, thread_group, requires_grad, filter_fn, param_group)) for _future in concurrent.futures.as_completed(futures): try: _future.result() From 6e8a7d37bac307de297a0eb40bfa82a86f5e631d Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Sun, 26 Jan 2025 17:23:15 +0800 Subject: [PATCH 15/19] speed up setup. --- chatlearn/runtime/dist_actor.py | 11 ++++++----- chatlearn/runtime/environment.py | 5 +++-- chatlearn/schedule/model_manager.py | 15 ++++++++------- chatlearn/synchronizer/parameter_sync.py | 2 +- 4 files changed, 18 insertions(+), 15 deletions(-) diff --git a/chatlearn/runtime/dist_actor.py b/chatlearn/runtime/dist_actor.py index 6819d879..ee2dceb6 100644 --- a/chatlearn/runtime/dist_actor.py +++ b/chatlearn/runtime/dist_actor.py @@ -307,7 +307,7 @@ def __init__(self): self.replicas = [] self.name = None self.rank_to_actors = {} - self.register_serial_func() + # self.register_serial_func() self.register_func() self._is_colocate = False self._colocate_models = [] @@ -362,10 +362,11 @@ def get_actor(self, rank): if rank in dist_actor.rank_to_actors: return dist_actor.rank_to_actors[rank] - def register_serial_func(self): - for func_name in ["init"]: - dist_call = partial(self.call_replica_serial_func, func_name) - setattr(self, func_name, dist_call) + def init(self): + refs = [] + for dist_actor in self.replicas: + refs.append(dist_actor.init()) + future.get(refs) def register_func(self): for func_name in ["model_setup", diff --git a/chatlearn/runtime/environment.py b/chatlearn/runtime/environment.py index a0c95ffd..67496a01 100644 --- a/chatlearn/runtime/environment.py +++ b/chatlearn/runtime/environment.py @@ -75,9 +75,10 @@ def setup(self): self._padding_config.update(config) if isinstance(model.model, VLLMModuleV2): + refs = [] for replica in model_node.model.replicas: - ret = replica.vllm_engine.setup_vllm.remote(replica.all_actors) - future.wait(ret) + refs.append(replica.vllm_engine.setup_vllm.remote(replica.all_actors)) + future.wait(refs) @property def sample_per_episode(self): diff --git a/chatlearn/schedule/model_manager.py b/chatlearn/schedule/model_manager.py index bc488e4f..fe775916 100644 --- a/chatlearn/schedule/model_manager.py +++ b/chatlearn/schedule/model_manager.py @@ -175,18 +175,19 @@ def sync_parameters(self, episode_offset=0, requires_grad=None, validate=False): sync_group: ParameterSyncGroup = sync_group src_model, dst_model = sync_group.src_model, sync_group.dst_model + onload_refs = [] refs = src_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False) - future.wait(refs) + onload_refs.append(refs) refs = dst_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False) - future.wait(refs) + onload_refs.append(refs) + future.wait(onload_refs) 
sync_group.sync(requires_grad, validate) - refs = src_model.offload() - future.wait(refs) - refs = dst_model.offload() - future.wait(refs) - + offload_refs = [] + offload_refs.append(src_model.offload()) + offload_refs.append(dst_model.offload()) + future.wait(offload_refs) def set_func_decorator(self, model): if is_decorated(model.name): return diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index e682ba61..ad037db2 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -861,7 +861,7 @@ def sort_send_actors(self, send_recv_actor_mappings, sorted_send_actors): def sync_broadcast_second_stage_internal(self, group_name, thread_group, requires_grad=None, filter_fn=None, param_group="default"): max_workers = len(thread_group) - max_workers = 8 + max_workers = min(8, max_workers) logger.info(f"debug thread_group {group_name}: {[(self.actor2rank[ele[0]], self.actor2rank[ele[1]]) for ele in thread_group]}") logger.info(f"Use {max_workers} workers for second_stage_internal broadcasting.") with ThreadPoolExecutor(max_workers=max_workers) as executor: From e0ab9383f2c34d7e0a9c2781cd627a46e1ae703b Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Sun, 26 Jan 2025 17:55:49 +0800 Subject: [PATCH 16/19] fix --- chatlearn/synchronizer/parameter_sync.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index ad037db2..280d0d83 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -890,6 +890,7 @@ def sync_broadcast_second_stage(self, group_name, thread_groups, requires_grad=N new_thread_groups = [thread_groups[tp_size*i:tp_size*(i+1)] for i in range(num_thread_groups)] max_workers = len(new_thread_groups) + max_workers = max(max_workers, 1) logger.info(f"Use {max_workers} workers for second_stage broadcasting.") with ThreadPoolExecutor(max_workers=max_workers) as executor: From 4afa743ed675d08c3c24e0d398a7dd5546892b4f Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Sun, 26 Jan 2025 21:22:30 +0800 Subject: [PATCH 17/19] fix tp4ep2pp1. 
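
Besides reverting the setup-speedup changes to dist_actor, environment and
model_manager, this commit makes two fixes in parameter_sync.py for the
tp4/ep2/pp1 layout: when there are fewer (send, recv) pairs than the
destination tp size, all pairs now fall back into a single chunk instead of an
empty chunk list, and the stage-2 conflict check now treats any shared actor
as a conflict rather than only a repeated sender or receiver in the same
position. A small illustration of the predicate change (the helper names and
letter ranks are illustrative only):

    def conflicts_positional(pair, group):
        # old check: same sender or same receiver in the same slot
        return any(pair[0] == p[0] or pair[1] == p[1] for p in group)

    def conflicts_membership(pair, group):
        # new check: any shared actor, regardless of send/recv role
        return any(pair[0] in p or pair[1] in p for p in group)

    # ("A", "B") against [("C", "A")]:
    # conflicts_positional -> False  (A sends in one pair, receives in the other)
    # conflicts_membership -> True   (actor A appears in both pairs)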
--- chatlearn/runtime/decorator.py | 1 + chatlearn/runtime/dist_actor.py | 11 +++++------ chatlearn/runtime/environment.py | 5 ++--- chatlearn/schedule/model_manager.py | 15 +++++++-------- chatlearn/synchronizer/parameter_sync.py | 5 +++-- 5 files changed, 18 insertions(+), 19 deletions(-) diff --git a/chatlearn/runtime/decorator.py b/chatlearn/runtime/decorator.py index 07ceff75..7034126c 100644 --- a/chatlearn/runtime/decorator.py +++ b/chatlearn/runtime/decorator.py @@ -168,6 +168,7 @@ def get_kwarg(key): # for model with TP/PP, only return the results from last rank if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \ or isinstance(self, VLLMModuleV2): + # print(f"replica_id_{self.replica_id} results: {len(results)} {results} input_data: {input_data}") final_results = concat_along_batch(results) else: if 'iteration' in inspect.signature(func).parameters: diff --git a/chatlearn/runtime/dist_actor.py b/chatlearn/runtime/dist_actor.py index ee2dceb6..6819d879 100644 --- a/chatlearn/runtime/dist_actor.py +++ b/chatlearn/runtime/dist_actor.py @@ -307,7 +307,7 @@ def __init__(self): self.replicas = [] self.name = None self.rank_to_actors = {} - # self.register_serial_func() + self.register_serial_func() self.register_func() self._is_colocate = False self._colocate_models = [] @@ -362,11 +362,10 @@ def get_actor(self, rank): if rank in dist_actor.rank_to_actors: return dist_actor.rank_to_actors[rank] - def init(self): - refs = [] - for dist_actor in self.replicas: - refs.append(dist_actor.init()) - future.get(refs) + def register_serial_func(self): + for func_name in ["init"]: + dist_call = partial(self.call_replica_serial_func, func_name) + setattr(self, func_name, dist_call) def register_func(self): for func_name in ["model_setup", diff --git a/chatlearn/runtime/environment.py b/chatlearn/runtime/environment.py index 67496a01..a0c95ffd 100644 --- a/chatlearn/runtime/environment.py +++ b/chatlearn/runtime/environment.py @@ -75,10 +75,9 @@ def setup(self): self._padding_config.update(config) if isinstance(model.model, VLLMModuleV2): - refs = [] for replica in model_node.model.replicas: - refs.append(replica.vllm_engine.setup_vllm.remote(replica.all_actors)) - future.wait(refs) + ret = replica.vllm_engine.setup_vllm.remote(replica.all_actors) + future.wait(ret) @property def sample_per_episode(self): diff --git a/chatlearn/schedule/model_manager.py b/chatlearn/schedule/model_manager.py index fe775916..bc488e4f 100644 --- a/chatlearn/schedule/model_manager.py +++ b/chatlearn/schedule/model_manager.py @@ -175,19 +175,18 @@ def sync_parameters(self, episode_offset=0, requires_grad=None, validate=False): sync_group: ParameterSyncGroup = sync_group src_model, dst_model = sync_group.src_model, sync_group.dst_model - onload_refs = [] refs = src_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False) - onload_refs.append(refs) + future.wait(refs) refs = dst_model.onload(to_build_grad_buffers=False, to_onload_main_weights=False, to_onload_optimizer_states=False) - onload_refs.append(refs) - future.wait(onload_refs) + future.wait(refs) sync_group.sync(requires_grad, validate) - offload_refs = [] - offload_refs.append(src_model.offload()) - offload_refs.append(dst_model.offload()) - future.wait(offload_refs) + refs = src_model.offload() + future.wait(refs) + refs = dst_model.offload() + future.wait(refs) + def set_func_decorator(self, model): if is_decorated(model.name): return diff --git 
a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index 280d0d83..649c99b5 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -889,8 +889,9 @@ def sync_broadcast_second_stage(self, group_name, thread_groups, requires_grad=N num_thread_groups = len(thread_groups) // tp_size new_thread_groups = [thread_groups[tp_size*i:tp_size*(i+1)] for i in range(num_thread_groups)] + if not new_thread_groups: + new_thread_groups = [thread_groups] max_workers = len(new_thread_groups) - max_workers = max(max_workers, 1) logger.info(f"Use {max_workers} workers for second_stage broadcasting.") with ThreadPoolExecutor(max_workers=max_workers) as executor: @@ -953,7 +954,7 @@ def sync_broadcast_multi_threads( for idx, actor_groups in enumerate(actor_groups_to_sync): in_actor_group = False for jdx, actor_group in enumerate(actor_groups): - if group[0] == actor_group[0] or group[1] == actor_group[1]: + if group[0] in actor_group or group[1] in actor_group: in_actor_group = True if not in_actor_group: new_actor_group_flag = False From c3a45fc86b14ebde07aaa84900ff05124bffe21f Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Thu, 30 Jan 2025 23:20:13 +0800 Subject: [PATCH 18/19] fix broadcast for moe experts. --- chatlearn/models/base_module.py | 20 ++++++ chatlearn/synchronizer/megatron_vllm.py | 80 +++++++++++++++++++----- chatlearn/synchronizer/parameter_sync.py | 46 +++++++++++--- 3 files changed, 121 insertions(+), 25 deletions(-) diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index 3a085146..d4feee06 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -18,6 +18,7 @@ from itertools import cycle import math import os +import torch import ray import ray.util.collective as col @@ -933,6 +934,25 @@ def tensor_generator(): self._logger.info(f"debug finished stage2 comm") else: self._logger.info(f"debug finished stage1 comm") + + check_rank = self.tensor_parallel_rank() + if self.tensor_parallel_rank() == check_rank and stage2:# and check_rank not in [0, 1, 2, 3]: + if not isinstance(self.model, list): + model = [self.model] + else: + model = self.model + for item in model[0].named_parameters(): + name, param = item + if "layers.0" in name: + print(f"debug output param {name} {param.shape}") + offset = 4 + num_prints = param.shape[0] // offset + with open(f"/workspace/code/cmd/moelite_scripts/new/tp2ep4pp1_{check_rank}_{name}.txt", "a+") as file: + for i in range(num_prints): + start = offset * i + end = start + offset + tensor_to_print = param[start:end] + file.write(name + f"_{i}:" + str(tensor_to_print.cpu()) + "\n") debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, sparse_bucket {sparse_bucket_num}", self._logger) diff --git a/chatlearn/synchronizer/megatron_vllm.py b/chatlearn/synchronizer/megatron_vllm.py index 991d538b..b06b6ca7 100644 --- a/chatlearn/synchronizer/megatron_vllm.py +++ b/chatlearn/synchronizer/megatron_vllm.py @@ -33,6 +33,7 @@ class MegatronVllmSync(BaseSync): def __init__(self, src_model, dst_model): super().__init__(src_model, dst_model) self.src_module_args = src_model.module_args + self.dst_module_args = dst_model.module_args self.is_parameter_changed = True @abstractmethod @@ -260,12 +261,22 @@ def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group, mod ep_size = self.src_module_args.args_dict["moe_expert_model_parallel_size"] hep_size = tp_size * ep_size moe_num_experts = 
self.src_module_args.args_dict["moe_num_experts"] + local_num_experts = moe_num_experts // hep_size hidden_size = self.src_module_args.args_dict["hidden_size"] if "dense_h_to_4h" in op_name: # w13_weight # regroup among difference tp slices param = params_to_sync.view((moe_num_experts, -1, hidden_size)) + # if "layers.0" in name: + # offset = 4 + # num_prints = param.shape[0] // offset + # with open(f"/workspace/code/cmd/moelite_scripts/chatlearn_w13_{rank}_before_tp4ep2.txt", "a+") as file: + # for i in range(num_prints): + # start = offset * i + # end = start + offset + # tensor_to_print = param[start:end] + # file.write(name + f"_{i}:" + str(tensor_to_print.cpu()) + "\n") param = param.reshape((local_num_experts * 2, -1, hidden_size)) params = list(param.chunk(hep_size, dim=1)) # reorder w1 and w3 @@ -284,15 +295,30 @@ def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group, mod torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device) for i in range(hep_size) ] - # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: - # file.write(f"3.1 debug alltoall rank: {rank}/{world_size} vs {torch.distributed.get_rank(group=comm_group)}/{torch.distributed.get_world_size(group=comm_group)} in comm group {id(comm_group)} {name}" + "\n") - + # if rank == 0 and "layers.0" in name: + # offset = 4 + # for idx, ele in enumerate(params_list): + # num_prints = ele.shape[0] // offset + # with open(f"/workspace/code/cmd/moelite_scripts/chatlearn_w13_slice_{rank}_tp4ep2.txt", "a+") as file: + # for i in range(num_prints): + # start = offset * i + # end = start + offset + # tensor_to_print = ele[start:end] + # file.write(name + f"_{i}:" + str(tensor_to_print.cpu()) + "\n") + torch.distributed.all_to_all(output, params_list, group=comm_group) - # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: - # file.write(f"3.2 debug alltoall rank: {rank}/{world_size} vs {torch.distributed.get_rank(group=comm_group)}/{torch.distributed.get_world_size(group=comm_group)} in comm group {id(comm_group)} {name}" + "\n") - del params_list params_to_sync = torch.cat(output, dim=0).contiguous() + # if "layers.0" in name: + # offset = 4 + # num_prints = params_to_sync.shape[0] // offset + # with open(f"/workspace/code/cmd/moelite_scripts/chatlearn_w13_{rank}_after_tp4ep2.txt", "a+") as file: + # for i in range(num_prints): + # start = offset * i + # end = start + offset + # tensor_to_print = params_to_sync[start:end] + # file.write(name + f"_{i}:" + str(tensor_to_print.cpu()) + "\n") + del output else: # w2_weight @@ -353,8 +379,9 @@ def regroup_qkv_tp_slices(self, name, param_data, tp_divition): if "attention.query_key_value" in name or \ "self_attention.query_key_value" in name or \ "self_attention.linear_qkv" in name: - tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"] - heads = self.src_module_args.args_dict["num_attention_heads"] // tp_size + src_tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"] + dst_tp_size = self.dst_module_args.args_dict["tensor_model_parallel_size"] + heads = self.src_module_args.args_dict["num_attention_heads"] // src_tp_size hidden_size_per_head = self.src_module_args.args_dict["hidden_size"] // self.src_module_args.args_dict["num_attention_heads"] param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:] @@ -372,23 +399,42 @@ def regroup_qkv_tp_slices(self, name, param_data, 
tp_divition): param_data = torch.concat(param_data_list, dim=0).view(param_data_shape) del param_data_list else: - _num_query_groups = self.src_module_args.args_dict["num_query_groups"]//tp_size \ - if self.src_module_args.args_dict["group_query_attention"] else heads - if to_fix_qkv_ordering_dict is not None or _num_query_groups == 1: + if self.src_module_args.args_dict["group_query_attention"]: + num_query_groups = self.src_module_args.args_dict["num_query_groups"] + assert num_query_groups == self.dst_module_args.args_dict["num_query_groups"], ( + f"num_query_groups of src model ({num_query_groups}) must be equal to num_query_groups of " + f"dst model ({self.dst_moduel_args.args_dict['num_query_groups']}). Please double-check your config." + ) + src_num_query_groups_per_replica = num_query_groups // src_tp_size + if dst_tp_size >= num_query_groups: + num_dst_kv_head_replicas = dst_tp_size // num_query_groups + else: + num_dst_kv_head_replicas = 1 + else: + src_num_query_groups_per_replica = heads + num_dst_kv_head_replicas = 1 + + if to_fix_qkv_ordering_dict is not None or src_num_query_groups_per_replica == 1: if len(param_data_shape) == 1: - param_data = param_data.view((heads + 2 * _num_query_groups, hidden_size_per_head)) + param_data = param_data.view((heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head)) else: param_data = param_data.view( - (heads + 2 * _num_query_groups, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"])) + (heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"])) param_data_list = [] head_offset = heads // tp_divition for idx in range(tp_divition): q_start = idx * head_offset q_end = q_start + head_offset - k_start = (heads + idx) if _num_query_groups // tp_divition else heads - k_end = k_start + 1 - v_start = k_start + _num_query_groups - v_end = v_start + 1 + if num_dst_kv_head_replicas == 1: + k_start = (heads + idx) if src_num_query_groups_per_replica // tp_divition else heads + k_end = k_start + 1 + v_start = k_start + src_num_query_groups_per_replica + v_end = v_start + 1 + else: + k_start = heads + idx // num_dst_kv_head_replicas + k_end = k_start + 1 + v_start = k_start + src_num_query_groups_per_replica + v_end = v_start + 1 q_proj = param_data[q_start:q_end].contiguous() k_proj = param_data[k_start:k_end].contiguous() diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index 649c99b5..840961b3 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -1555,6 +1555,7 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): logger.info(f"debug param sync cost alltoall {e_time-s_time}") # sync everything to inference model if self.tp_num_mapping == 1: + logger.info(f"debug self.tp_num_mapping: {self.tp_num_mapping}") s_time = time.time() send_actors_list = [self.sorted_send_actors] actor_mappings_list = [self.send_recv_actor_mappings] @@ -1569,17 +1570,46 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): logger.info(f"debug param sync: complete to alltoall router experts") elif self.tp_num_mapping > 1: + logger.info(f"debug self.tp_num_mapping: {self.tp_num_mapping}") s_time = time.time() logger.info(f"debug param sync: start to sync other weights.") - send_actors_list = [self.sorted_send_actors, self.sorted_send_actors_stage2] - actor_mappings_list = [self.send_recv_actor_mappings, 
self.send_recv_actor_mappings_stage2] - self._multi_thread_sync_for_tp_num_mapping_gt_1( - send_actors_list, - actor_mappings_list, - requires_grad=requires_grad, - filter_fn=None, - param_group="default" + # send_actors_list = [self.sorted_send_actors, self.sorted_send_actors_stage2] + # actor_mappings_list = [self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2] + # self._multi_thread_sync_for_tp_num_mapping_gt_1( + # send_actors_list, + # actor_mappings_list, + # requires_grad=requires_grad, + # filter_fn=None, + # param_group="default" + # ) + # First, synchronize routed experts. + self._synchronize_routed_experts(requires_grad=requires_grad, validate=validate) + + self.clear_cache( + sorted_send_actors_list = [ + self.send_actors_to_regroup_routed_experts, + self.sorted_send_actors_for_routed_experts + ], + rank_mapping_list=[ + self.send_recv_actor_mappings_for_routed_experts + ] + ) + + # Then, synchronize parameters except routed experts + self._synchronize_params_except_routed_experts(requires_grad=requires_grad, validate=validate) + + self.reset_synchronizer() + + self.clear_cache( + sorted_send_actors_list = [ + self.sorted_send_actors, + self.sorted_send_actors_stage2, + ], + rank_mapping_list = [ + self.send_recv_actor_mappings, + self.send_recv_actor_mappings_stage2 + ] ) e_time = time.time() logger.info(f"debug param sync: complete to sync other weights.") From 36bfb52ccac6452e20487f11fca08caf52d0f78c Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Fri, 31 Jan 2025 11:04:12 +0800 Subject: [PATCH 19/19] fix broadcast for moe experts. --- chatlearn/models/base_module.py | 2 +- chatlearn/synchronizer/parameter_sync.py | 35 +++++++++++++++++++----- 2 files changed, 29 insertions(+), 8 deletions(-) diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py index d4feee06..a756fa37 100644 --- a/chatlearn/models/base_module.py +++ b/chatlearn/models/base_module.py @@ -936,7 +936,7 @@ def tensor_generator(): self._logger.info(f"debug finished stage1 comm") check_rank = self.tensor_parallel_rank() - if self.tensor_parallel_rank() == check_rank and stage2:# and check_rank not in [0, 1, 2, 3]: + if False:#self.tensor_parallel_rank() == check_rank and stage2:# and check_rank not in [0, 1, 2, 3]: if not isinstance(self.model, list): model = [self.model] else: diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index 840961b3..d922d603 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -337,6 +337,7 @@ def split_ranks_by_tp_and_ep_size(ranks, j = i // pipe_map_interval for src_rank, dst_rank in zip(src_tp_group, dst_replica_ranks_group[j]): add_recv_actor_fn(src_rank, dst_rank) + logger.info(f"20250131 Sending from {src_rank} to {dst_rank}") # pylint: disable=unused-argument def build_rank_mapping_for_ep(self, add_recv_actor_fn=None): @@ -464,8 +465,8 @@ def p2p_pair_grouping(tuples): for tuples in dst_tp_group: p2p_pair_grouping(tuples) - logger.debug(f"comm pair_list : {pair_list}") - logger.debug(f"comm p2p_list : {p2p_list}") + logger.info(f"comm pair_list : {pair_list}") + logger.info(f"comm p2p_list : {p2p_list}") def _clear_sync_send_recv_parameters(self, rank_mappings:List): if len(rank_mappings) == 0: @@ -1195,6 +1196,7 @@ def _multi_thread_sync_for_tp_num_mapping_eq_1( actor_mappings = actor_mappings_list[0] sorted_send_actors = self.sort_send_actors(actor_mappings, send_actors) + logger.info(f"20250131 sorted_send_actors: {sorted_send_actors}") 
max_workers = self._calculate_max_workers(sorted_send_actors, actor_mappings) with ThreadPoolExecutor(max_workers=max_workers) as executor: @@ -1327,8 +1329,18 @@ def setup_rank_mapping(self): else: self.build_rank_mapping_for_ep() elif self.tp_num_mapping > 1: - self.build_rank_mapping_for_ep(add_recv_actor_fn=self.empty_add_recv_actor) # only add all-gather actors - self.build_rank_mapping_two_stage() + if self.hep_num_mapping == 1: + logger.info(f"20250131 build ranking mapping for ep") + self.build_rank_mapping_for_ep(add_recv_actor_fn=self.empty_add_recv_actor) # only add all-gather actors + # self.send_actors_to_regroup_routed_experts = self.send_actors_to_regroup_routed_experts[0] + # self.sorted_send_actors_for_routed_experts = self.send_actors_to_regroup_routed_experts + logger.info(f"20250131 build ranking mapping for routed expert") + self.build_rank_mapping_for_routed_experts() + logger.info(f"20250131 build ranking mapping for params except routed expert") + self.build_rank_mapping_for_params_except_routed_expert() + else: + self.build_rank_mapping_for_ep(add_recv_actor_fn=self.empty_add_recv_actor) # only add all-gather actors + self.build_rank_mapping_two_stage() else: raise NotImplementedError( f"ChatLearn does not support synchronizing from larger tp size ({self.num_src_tensor_parallel})" @@ -1419,7 +1431,7 @@ def split_ranks_by_ep_and_tp_size(ranks, (src_replica_ranks2offset[tuple(src_replica_ranks)] + len_dst_div_src) % len(src_ep_and_tp_group) ) - if self._debug: + if True:#self._debug: def debug_msg_for_actor_mappings(actor_mapping): if actor_mapping is None: return @@ -1431,12 +1443,14 @@ def debug_msg_for_actor_mappings(actor_mapping): debug_msg_for_actor_mappings(self.send_recv_actor_mappings) debug_msg_for_actor_mappings(self.send_recv_actor_mappings_for_routed_experts) + count = 0 for regroup_actors in self.send_actors_to_regroup_routed_experts: + count += 1 cat_str = "_".join(str(self.actor2rank[actor]) for actor in regroup_actors) - logger.debug(f"{self._comm_type_to_regroup_routed_experts} actors: {cat_str}") + logger.info(f"{self._comm_type_to_regroup_routed_experts} actors_{count}: {cat_str}") for k, v_list in self.send_recv_actor_mappings.items(): for v in v_list: - logger.debug(f"send_recv_actor_mappings: {self.actor2rank[k]} -> {self.actor2rank[v]}") + logger.info(f"send_recv_actor_mappings: {self.actor2rank[k]} -> {self.actor2rank[v]}") def add_recv_actor_for_routed_experts(self, src_rank, dst_rank): src_actor = self.src_model.get_actor(src_rank) @@ -1641,8 +1655,15 @@ def _synchronize_routed_experts(self, requires_grad=None, validate=False): send_actors_list : List = [] actor_mappings_list : List = [] if self.concurrent_comm: + send_actors_list = [self.sorted_send_actors_for_routed_experts] actor_mappings_list = [self.send_recv_actor_mappings_for_routed_experts] + + logger.info(f"send_actors_list: {send_actors_list}") + logger.info(f"actor_mappings_list: {actor_mappings_list}") + # import pdb;pdb.set_trace() + # logger.info(f"sync routed experts from {[self.actor2rank[ele] for ele in send_actors_list[0]]}") + self._multi_thread_sync_for_tp_num_mapping_eq_1( send_actors_list, actor_mappings_list,