diff --git a/chatlearn/models/base_module.py b/chatlearn/models/base_module.py
index 8e571e1c..a756fa37 100644
--- a/chatlearn/models/base_module.py
+++ b/chatlearn/models/base_module.py
@@ -18,6 +18,7 @@
 from itertools import cycle
 import math
 import os
+import torch
 import ray
 import ray.util.collective as col
@@ -751,12 +752,26 @@ def send_recv_parameter(self, rank, group_name, func, pipe_stage=0):
     def alltoall_routed_expert_parameter(self, pipe_stage=0):
         assert self._synchronizer is not None
+        comm_group = self.tensor_and_expert_parallel_group()
+        rank = torch.distributed.get_rank(group=comm_group)
+        world_size = torch.distributed.get_world_size(group=comm_group)
         for name, param in self._parameters_to_sync[pipe_stage]:
             param, state = self._synchronizer.alltoall_routed_experts(
                 name,
                 param,
-                self.tensor_and_expert_parallel_group()
+                comm_group,
+                self.name,
+                rank,
+                world_size
             )
             if state:
                 self._expert_sync_buffer.pop(name, "Not Found.")
                 self._expert_sync_buffer[name] = param
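For reference, the expert-synchronization change above amounts to resolving the rank and world size of the tensor-and-expert parallel group once per sync pass and forwarding them, together with the module name, to the synchronizer. A minimal sketch under those assumptions; `synchronizer` and `params` are illustrative stand-ins for ChatLearn internals, not real APIs:

import torch

def sync_routed_experts(synchronizer, params, comm_group, module_name):
    # Resolve communication metadata once, as the patched
    # alltoall_routed_expert_parameter does.
    rank = torch.distributed.get_rank(group=comm_group)
    world_size = torch.distributed.get_world_size(group=comm_group)
    synced = {}
    for name, param in params:
        # The synchronizer also receives the module name, rank and world size
        # so it can regroup routed-expert slices across the whole HEP group.
        new_param, changed = synchronizer.alltoall_routed_experts(
            name, param, comm_group, module_name, rank, world_size)
        if changed:
            synced[name] = new_param
    return synced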
diff --git a/chatlearn/models/vllm/hooks/__init__.py b/chatlearn/models/vllm/hooks/__init__.py index 08856241..026852f3 100644 --- a/chatlearn/models/vllm/hooks/__init__.py +++ b/chatlearn/models/vllm/hooks/__init__.py @@ -19,26 +19,18 @@ from .. import is_vllm_v2 -if is_vllm_v2(): - if importlib.util.find_spec("vllm"): - from . import ray_gpu_executor - from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion - if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: - from chatlearn.models.vllm.hooks import input_preprocess - from chatlearn.models.vllm.hooks import async_llm_engine - from chatlearn.models.vllm.hooks import llm - from chatlearn.models.vllm.hooks import loader - from chatlearn.models.vllm.hooks import worker_base -else: - if importlib.util.find_spec("vllm"): - import vllm - from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion # pylint: disable=ungrouped-imports - if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0: - from chatlearn.models.vllm.hooks import sampler - elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: - from chatlearn.models.vllm.hooks import llm_engine, logits_processor - if CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1: - from chatlearn.models.vllm.hooks import worker - else: - from chatlearn.models.vllm.hooks import input_preprocess - from chatlearn.models.vllm.hooks import format_device_name +if importlib.util.find_spec("vllm"): + + from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion + + if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0: + from chatlearn.models.vllm.hooks.vllm_0_3_0 import * + elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1: + from chatlearn.models.vllm.hooks.vllm_0_5_1 import * + elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + from chatlearn.models.vllm.hooks.vllm_0_6_3 import * + elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_6: + from .vllm_0_6_6 import * + else: + raise RuntimeError( + f"vLLM version expected in {list(member.value for member in VLLMVersion)}, while {CURRENT_VLLM_VERSION}.") diff --git a/chatlearn/models/vllm/hooks/input_preprocess.py b/chatlearn/models/vllm/hooks/input_preprocess.py deleted file mode 100644 index 0f9f4ae0..00000000 --- a/chatlearn/models/vllm/hooks/input_preprocess.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Hooks of vllm-0.6.3 input preprocess to pass prompt text.""" - - -import inspect - -# pylint: disable=unused-import,unused-argument -from vllm.inputs import preprocess - - -source = inspect.getsource(preprocess.InputPreprocessor._extract_prompt_components) -if 'parsed = parse_singleton_prompt(prompt)' in source: - from vllm.inputs.parse import parse_singleton_prompt - - def extract_prompt_components( - self, - prompt, - request_id, - lora_request=None): - ''' - Extract the components of any single encoder or decoder input prompt. 
- - Arguments: - - * request_id - * prompt: single encoder or decoder input prompt - * lora_request: this is only valid for decoder prompts - - Returns: - - * prompt - * prompt_token_ids - * multi_modal_data - * mm_processor_kwargs (request-level input processor/mapper overrides) - ''' - parsed = parse_singleton_prompt(prompt) - - assert parsed["type"] == "tokens", \ - f"you must pass prompt_token_ids when add request to scheduler. while prompt {prompt}" - - prompt_text = parsed["content"]["prompt"] - prompt_token_ids = parsed["content"]["prompt_token_ids"] - multi_modal_data = parsed["content"].get("multi_modal_data") - mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") - - return (prompt_text, prompt_token_ids, multi_modal_data, - mm_processor_kwargs) - - preprocess.InputPreprocessor._extract_prompt_components = extract_prompt_components diff --git a/chatlearn/models/vllm/hooks/vllm_0_3_0/__init__.py b/chatlearn/models/vllm/hooks/vllm_0_3_0/__init__.py new file mode 100644 index 00000000..5cc844e2 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_3_0/__init__.py @@ -0,0 +1,21 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additional hooks of vllm-0.3.0.""" + +from ... import is_vllm_v2 + +assert not is_vllm_v2(), "vLLM-0.3.0 only supports vLLM Module v1. Set env `ENABLE_VLLM_V2=False`." + +from . import sampler diff --git a/chatlearn/models/vllm/hooks/sampler.py b/chatlearn/models/vllm/hooks/vllm_0_3_0/sampler.py similarity index 100% rename from chatlearn/models/vllm/hooks/sampler.py rename to chatlearn/models/vllm/hooks/vllm_0_3_0/sampler.py diff --git a/chatlearn/models/vllm/hooks/vllm_0_5_1/__init__.py b/chatlearn/models/vllm/hooks/vllm_0_5_1/__init__.py new file mode 100644 index 00000000..e9691dee --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_5_1/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additional hooks of vllm-0.5.1.""" + +from ... import is_vllm_v2 + +assert not is_vllm_v2(), "vLLM-0.5.1 only supports vLLM Module v1. Set env `ENABLE_VLLM_V2=False`." + +from . import llm_engine +from . import logits_processor +from . 
import worker diff --git a/chatlearn/models/vllm/hooks/llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_5_1/llm_engine.py similarity index 100% rename from chatlearn/models/vllm/hooks/llm_engine.py rename to chatlearn/models/vllm/hooks/vllm_0_5_1/llm_engine.py diff --git a/chatlearn/models/vllm/hooks/logits_processor.py b/chatlearn/models/vllm/hooks/vllm_0_5_1/logits_processor.py similarity index 100% rename from chatlearn/models/vllm/hooks/logits_processor.py rename to chatlearn/models/vllm/hooks/vllm_0_5_1/logits_processor.py diff --git a/chatlearn/models/vllm/hooks/worker.py b/chatlearn/models/vllm/hooks/vllm_0_5_1/worker.py similarity index 100% rename from chatlearn/models/vllm/hooks/worker.py rename to chatlearn/models/vllm/hooks/vllm_0_5_1/worker.py diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/__init__.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/__init__.py new file mode 100644 index 00000000..efd99405 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additional hooks of vllm-0.6.3.""" + +from ... import is_vllm_v2 +from . import format_device_name +from . import input_preprocess + +if is_vllm_v2(): + from . import async_llm_engine + from . import llm + from . import loader + from . import ray_gpu_executor + from . import worker_base +else: + from . import llm_engine + from . import logits_processor diff --git a/chatlearn/models/vllm/hooks/async_llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/async_llm_engine.py similarity index 96% rename from chatlearn/models/vllm/hooks/async_llm_engine.py rename to chatlearn/models/vllm/hooks/vllm_0_6_3/async_llm_engine.py index 45428241..77f2ed70 100644 --- a/chatlearn/models/vllm/hooks/async_llm_engine.py +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/async_llm_engine.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Hooks of vllm-0.6.3 del init_ray_cluster in AsyncLLMEngine.""" +"""del init_ray_cluster in AsyncLLMEngine.""" from typing import Dict, Optional diff --git a/chatlearn/models/vllm/hooks/format_device_name.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/format_device_name.py similarity index 100% rename from chatlearn/models/vllm/hooks/format_device_name.py rename to chatlearn/models/vllm/hooks/vllm_0_6_3/format_device_name.py diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py new file mode 100644 index 00000000..dbb80d69 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/input_preprocess.py @@ -0,0 +1,55 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.3 input preprocess to pass prompt text.""" + +# pylint: disable=unused-import,unused-argument +from vllm.inputs import preprocess +from vllm.inputs.parse import parse_singleton_prompt + +def extract_prompt_components( + self, + prompt, + request_id, + lora_request=None): + ''' + Extract the components of any single encoder or decoder input prompt. + + Arguments: + + * request_id + * prompt: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + + Returns: + + * prompt + * prompt_token_ids + * multi_modal_data + * mm_processor_kwargs (request-level input processor/mapper overrides) + ''' + parsed = parse_singleton_prompt(prompt) + + assert parsed["type"] == "tokens", \ + f"you must pass prompt_token_ids when add request to scheduler. while prompt {prompt}" + + prompt_text = parsed["content"]["prompt"] + prompt_token_ids = parsed["content"]["prompt_token_ids"] + multi_modal_data = parsed["content"].get("multi_modal_data") + mm_processor_kwargs = parsed["content"].get("mm_processor_kwargs") + + return (prompt_text, prompt_token_ids, multi_modal_data, + mm_processor_kwargs) + +preprocess.InputPreprocessor._extract_prompt_components = extract_prompt_components diff --git a/chatlearn/models/vllm/hooks/llm.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/llm.py similarity index 100% rename from chatlearn/models/vllm/hooks/llm.py rename to chatlearn/models/vllm/hooks/vllm_0_6_3/llm.py diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/llm_engine.py new file mode 100644 index 00000000..bff1d188 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/llm_engine.py @@ -0,0 +1,30 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Hooks of vllm-0.5.1 llm_engine remove __reduce__ function.""" + +import inspect + +# pylint: disable=unused-import,wildcard-import,unused-argument +from vllm.engine import llm_engine + + +source = inspect.getsource(llm_engine.LLMEngine.__reduce__) +if 'RuntimeError' in source: + def __reduce__(self): + # This is to ensure that the LLMEngine can be referenced in + # the closure used to initialize Ray worker actors + pass + + del llm_engine.LLMEngine.__reduce__ diff --git a/chatlearn/models/vllm/hooks/loader.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/loader.py similarity index 99% rename from chatlearn/models/vllm/hooks/loader.py rename to chatlearn/models/vllm/hooks/vllm_0_6_3/loader.py index 147e1b8a..5a81f452 100644 --- a/chatlearn/models/vllm/hooks/loader.py +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/loader.py @@ -73,7 +73,6 @@ def init(self, load_config): loader.DummyModelLoader.__init__ = init - # add ckpt loading of megatron format def load_model(self, *, model_config, device_config, diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_3/logits_processor.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/logits_processor.py new file mode 100644 index 00000000..d713bff6 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_3/logits_processor.py @@ -0,0 +1,42 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.5.1 logits_processor to allgather logits of all ranks.""" + +import inspect + +# pylint: disable=wildcard-import,ungrouped-imports +from vllm.model_executor.layers import logits_processor + + +source = inspect.getsource(logits_processor.LogitsProcessor._get_logits) +if 'tensor_model_parallel_gather' in source: + import torch + from typing import Optional + from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding + def _get_logits(self, hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. + logits = lm_head.linear_method.apply(lm_head, + hidden_states, + bias=embedding_bias) + from vllm.distributed.communication_op import tensor_model_parallel_all_gather # pylint: disable=import-outside-toplevel + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). 
+ if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + + logits_processor.LogitsProcessor._get_logits = _get_logits diff --git a/chatlearn/models/vllm/hooks/ray_gpu_executor.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/ray_gpu_executor.py similarity index 100% rename from chatlearn/models/vllm/hooks/ray_gpu_executor.py rename to chatlearn/models/vllm/hooks/vllm_0_6_3/ray_gpu_executor.py diff --git a/chatlearn/models/vllm/hooks/worker_base.py b/chatlearn/models/vllm/hooks/vllm_0_6_3/worker_base.py similarity index 100% rename from chatlearn/models/vllm/hooks/worker_base.py rename to chatlearn/models/vllm/hooks/vllm_0_6_3/worker_base.py diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/__init__.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/__init__.py new file mode 100644 index 00000000..8e15c60c --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Additional hooks of vllm-0.6.6.""" + +from ... import is_vllm_v2 + +assert is_vllm_v2(), "vLLM-0.6.6 only supports vLLM Module v2." + +from . import async_llm_engine +from . import input_preprocess +from . import llm +from . import llm_engine +from . import loader +from . import ray_gpu_executor +from . import worker_base diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/async_llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/async_llm_engine.py new file mode 100644 index 00000000..513a068c --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/async_llm_engine.py @@ -0,0 +1,53 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""del init_ray_cluster in AsyncLLMEngine.""" + +from typing import Dict, Optional + +# pylint: disable=unused-import,wildcard-import,unused-argument,not-callable +from vllm.config import VllmConfig +from vllm.engine import async_llm_engine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.metrics_types import StatLoggerBase +from vllm.usage.usage_lib import UsageContext + +@classmethod +def from_engine_args(cls, + engine_args: AsyncEngineArgs, + engine_config: Optional[VllmConfig] = None, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, + ) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. + if engine_config is None: + engine_config = engine_args.create_engine_config(usage_context) + + executor_class = cls._get_executor_cls(engine_config) + + # Create the async LLM engine. + engine = cls( + vllm_config=engine_config, + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + start_engine_loop=start_engine_loop, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine + +async_llm_engine.AsyncLLMEngine.from_engine_args = from_engine_args diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/input_preprocess.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/input_preprocess.py new file mode 100644 index 00000000..a5d17acb --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/input_preprocess.py @@ -0,0 +1,71 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.6 input preprocess to pass prompt text.""" + +# pylint: disable=unused-import,unused-argument +from vllm.inputs import preprocess +from vllm.inputs.data import token_inputs +from vllm.inputs.parse import parse_singleton_prompt + + +def _prompt_to_llm_inputs( + self, + prompt, + request_id: str, + lora_request=None, +): + """ + Extract the singleton inputs from a prompt. + + Arguments: + + * request_id + * prompt: single encoder or decoder input prompt + * lora_request: this is only valid for decoder prompts + + Returns: + + * :class:`SingletonInputs` instance + """ + parsed = parse_singleton_prompt(prompt) + + assert parsed["type"] == "tokens", \ + f"you must pass prompt_token_ids when add request to scheduler. 
while prompt {prompt}" + + if parsed["type"] == "tokens": + tokens_content = parsed["content"] + + prompt_token_ids = tokens_content["prompt_token_ids"] + token_type_ids = tokens_content.get("token_type_ids") + multi_modal_data = tokens_content.get("multi_modal_data") + mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") + + if multi_modal_data is not None and self._can_process_multimodal(): + return self._process_multimodal( + prompt_token_ids, + multi_modal_data, + mm_processor_kwargs, + lora_request=lora_request, + ) + + return token_inputs( + prompt=tokens_content["prompt"], + prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=mm_processor_kwargs, + ) + +preprocess.InputPreprocessor._prompt_to_llm_inputs = _prompt_to_llm_inputs diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/llm.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/llm.py new file mode 100644 index 00000000..1028bf0c --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/llm.py @@ -0,0 +1,109 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.3 llm init with AsyncLLMEngine and AsyncEngineArgs.""" + +from typing import Any, Dict, Optional, Union + +# pylint: disable=unused-import,wildcard-import,unused-argument +from vllm.engine.arg_utils import AsyncEngineArgs, HfOverrides, TaskOption, PoolerConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints import llm +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Counter + +def init(self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + allowed_local_media_path: str = "", + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + hf_overrides: Optional[HfOverrides] = None, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + # After positional args are removed, move this right below `model` + task: TaskOption = "auto", + override_pooler_config: Optional[PoolerConfig] = None, + compilation_config: Optional[Union[int, Dict[str, Any]]] = None, + **kwargs,) -> None: + ''' + LLM constructor. + + Note: if enforce_eager is unset (enforce_eager is None) + it defaults to False. 
+ ''' + + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + if compilation_config is not None: + if isinstance(compilation_config, (int, dict)): + compilation_config_instance = CompilationConfig.from_cli( + str(compilation_config)) + else: + compilation_config_instance = compilation_config + else: + compilation_config_instance = None + + engine_args = AsyncEngineArgs( + model=model, + task=task, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + allowed_local_media_path=allowed_local_media_path, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, + hf_overrides=hf_overrides, + mm_processor_kwargs=mm_processor_kwargs, + override_pooler_config=override_pooler_config, + compilation_config=compilation_config_instance, + **kwargs, + ) + # Logic to switch between engines is done at runtime instead of import + # to avoid import order issues + self.engine_class = self.get_engine_class() + + # TODO(rob): enable mp by default (issue with fork vs spawn) + self.llm_engine = self.engine_class.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS) + + self.request_counter = Counter() + +llm.LLM.__init__ = init diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/llm_engine.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/llm_engine.py new file mode 100644 index 00000000..ea24cc86 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/llm_engine.py @@ -0,0 +1,58 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Hooks of vllm-0.5.1 llm_engine remove __reduce__ function.""" + +import inspect +from typing import Dict, Optional + +# pylint: disable=unused-import,wildcard-import,unused-argument +from vllm.engine import llm_engine +from vllm.engine.metrics_types import StatLoggerBase +from vllm.executor.ray_gpu_executor import RayGPUExecutor +from vllm.usage.usage_lib import UsageContext + + +source = inspect.getsource(llm_engine.LLMEngine.__reduce__) +if 'RuntimeError' in source: + def __reduce__(self): + # This is to ensure that the LLMEngine can be referenced in + # the closure used to initialize Ray worker actors + pass + + del llm_engine.LLMEngine.__reduce__ + + +@classmethod +def from_engine_args( + cls, + engine_args, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, +) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_config = engine_args.create_engine_config(usage_context) + executor_class = RayGPUExecutor + # Create the LLM engine. + engine = cls( # pylint: disable=not-callable + vllm_config=engine_config, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + + return engine +llm_engine.LLMEngine.from_engine_args = from_engine_args diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/loader.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/loader.py new file mode 100644 index 00000000..6d46c572 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/loader.py @@ -0,0 +1,115 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Hooks of vllm-0.6.3 loader to load ckpt of megatron format.""" + + +import torch + +# pylint: disable=unused-import,wildcard-import,unused-argument +from vllm.model_executor.model_loader import loader +from vllm.model_executor.model_loader.loader import device_loading_context, _initialize_model +from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.models import llama +from vllm.model_executor.models import qwen2, qwen2_moe +from vllm.config import VllmConfig + +from chatlearn.utils.vllm_import_helper import LlamaForCausalLM +from chatlearn.utils.vllm_import_helper import QWenLMHeadModel +from chatlearn.utils.vllm_import_helper import Qwen2ForCausalLM +from chatlearn.utils.vllm_import_helper import Qwen2MoeForCausalLM +from chatlearn.utils.vllm_import_helper import get_model_architecture +from chatlearn.utils.utils import get_use_legacy_models + +from chatlearn.utils.vllm_utils import ( + convert_llama_state_dict_from_megatron_to_vllm, + convert_llama_state_dict_from_mcore_to_vllm, + convert_qwen_state_dict_from_megatron_to_vllm, + load_checkpoint +) + +def load_weights(self, model_args): + torch.distributed.barrier() + self.model_args = model_args + load_checkpoint(self, None, None, model_args=model_args) + torch.distributed.barrier() + +def load_state_dict(self, state_dict, strict=True, assign=False): + qwen_version = None + if isinstance(self, LlamaForCausalLM): + use_legacy_models = get_use_legacy_models(self.model_args) + if use_legacy_models: + convert_state_dict_internal = convert_llama_state_dict_from_megatron_to_vllm + else: + convert_state_dict_internal = convert_llama_state_dict_from_mcore_to_vllm + elif isinstance(self, QWenLMHeadModel): + qwen_version = 1.0 + convert_state_dict_internal = convert_qwen_state_dict_from_megatron_to_vllm + elif isinstance(self, Qwen2ForCausalLM) or (Qwen2MoeForCausalLM is not None and isinstance(self, Qwen2MoeForCausalLM)): + qwen_version = 2.0 + convert_state_dict_internal = convert_qwen_state_dict_from_megatron_to_vllm + else: + raise RuntimeError(f"Unsupported model for vllm backend. \ + support [LlamaForCausalLM, QWenLMHeadModel, Qwen2ForCausalLM, Qwen2MoeForCausalLM] only, while {self}") + + state_dict = convert_state_dict_internal(self.model_args, self.config, qwen_version=qwen_version) + super(type(self), self).load_state_dict(state_dict, strict=strict) + + +def init(self, load_config): + # remove 'Model loader extra config' assert. + self.load_config = load_config + +loader.DummyModelLoader.__init__ = init + + +# add ckpt loading of megatron format +def load_model(self, vllm_config: VllmConfig):# -> nn.Module: + device_config = vllm_config.device_config + model_config = vllm_config.model_config + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model(vllm_config=vllm_config) + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. 
+ if self.load_config.model_loader_extra_config.get("need_load_ckpt", True) and \ + self.load_config.model_loader_extra_config["load"] is not None: + qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict + qwen2.Qwen2ForCausalLM.load_weights = load_weights + qwen2_moe.Qwen2MoeForCausalLM.load_state_dict = load_state_dict + qwen2_moe.Qwen2MoeForCausalLM.load_weights = load_weights + llama.LlamaForCausalLM.load_state_dict = load_state_dict + llama.LlamaForCausalLM.load_weights = load_weights + model.load_weights(self.load_config.model_loader_extra_config) + else: + # For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context( + module, torch.device(device_config.device)): + quant_method.process_weights_after_loading(module) + return model.eval() + + +loader.DummyModelLoader.load_model = load_model diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py new file mode 100644 index 00000000..19efb2f2 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/ray_gpu_executor.py @@ -0,0 +1,253 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hook _init_workers_ray""" + +from collections import defaultdict +from typing import Dict, List, Optional + +from vllm import envs +from vllm.executor.ray_gpu_executor import RayGPUExecutor +from vllm.executor.ray_utils import RayWorkerWrapper, ray +from vllm.logger import init_logger +from vllm.utils import (get_distributed_init_method, + get_ip, get_open_port) + +from chatlearn.utils.global_vars import get_vllm_actors + +logger = init_logger(__name__) + + +# modified based on https://github.com/vllm-project/vllm/blob/6aa6020f9bd4c1e414c10f7bd3a7c2555f1950b2/vllm/executor/ray_gpu_executor.py#L109 +def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): # pylint: disable=unused-argument,unused-variable + if (self.parallel_config.tensor_parallel_size == 1 + and self.parallel_config.pipeline_parallel_size == 1): + # For single GPU case, we use a ray worker with constrained memory. + num_gpus = self.cache_config.gpu_memory_utilization + else: + # Otherwise, the ray workers are allocated with a full GPU. + num_gpus = 1 + + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. 
+ self.driver_dummy_worker: Optional[RayWorkerWrapper] = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerWrapper] = [] + + # Used in ray compiled DAG: indexed first by PP rank, + # and then TP rank. In other words, the inner list is + # the TP group of workers for a PP rank. + self.pp_tp_workers: List[List[RayWorkerWrapper]] = [] + + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) + + logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker) + + # Create the workers. + driver_ip = get_ip() + workers = [] + workers = get_vllm_actors() + # if self.use_ray_spmd_worker: + # workers = vllm_workers + # else: + # for bundle_id, bundle in enumerate(placement_group.bundle_specs): + # if not bundle.get("GPU", 0): + # continue + # scheduling_strategy = PlacementGroupSchedulingStrategy( + # placement_group=placement_group, + # placement_group_capture_child_tasks=True, + # placement_group_bundle_index=bundle_id, + # ) + + # worker = ray.remote( + # num_cpus=0, + # num_gpus=num_gpus, + # scheduling_strategy=scheduling_strategy, + # **ray_remote_kwargs, + # )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) + # workers.append(worker) + + worker_ip_refs = [ + worker.get_node_ip.remote() # type: ignore[attr-defined] + for worker in workers + ] + worker_ips = ray.get(worker_ip_refs) + + if not self.use_ray_spmd_worker: + for i in range(len(workers)): # pylint: disable=consider-using-enumerate + worker = workers[i] + worker_ip = worker_ips[i] + if self.driver_dummy_worker is None and worker_ip == driver_ip: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + self.driver_worker = RayWorkerWrapper( + vllm_config=self.vllm_config) + workers.pop(i) + worker_ips.pop(i) + self.workers = workers + break + else: + self.workers = workers + + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) + if not self.use_ray_spmd_worker and self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node.") + + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + worker_to_ip = dict(zip(self.workers, worker_ips)) + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = worker_to_ip[worker] + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = [] + for worker in [self.driver_dummy_worker] + self.workers: + if worker is None: + # driver_dummy_worker can be None when using ray spmd worker. 
+ continue + worker_node_and_gpu_ids.append( + ray.get(worker.get_node_and_gpu_ids.remote()) \ + ) # type: ignore + + node_workers = defaultdict(list) # node id -> list of worker ranks + node_gpus = defaultdict(list) # node id -> list of gpu ids + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + # `gpu_ids` can be a list of strings or integers. + # convert them to integers for consistency. + # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), + # string sorting is not sufficient. + # see https://github.com/vllm-project/vllm/issues/5590 + gpu_ids = [int(x) for x in gpu_ids] + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + all_ips = set(worker_ips + [driver_ip]) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP`" + " environment variable, make sure it is unique for" + " each node.") + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [({ + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])), + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + **({ + "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND + } if envs.VLLM_ATTENTION_BACKEND is not None else {}) + }, ) for (node_id, _) in worker_node_and_gpu_ids] + + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) + + self._run_workers("update_environment_variables", + all_args=self._get_env_vars_to_be_updated()) + + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=0,#node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + + self._run_workers("init_device") + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) + + if self.use_ray_spmd_worker: + for pp_rank in range(self.parallel_config.pipeline_parallel_size): + self.pp_tp_workers.append([]) + for tp_rank in range( + self.parallel_config.tensor_parallel_size): + # PP=2, TP=4 + # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]] + rank = (pp_rank * self.parallel_config.tensor_parallel_size + ) + tp_rank + assert len(self.pp_tp_workers[pp_rank]) == tp_rank + assert pp_rank < len(self.pp_tp_workers) + self.pp_tp_workers[pp_rank].append(self.workers[rank]) + + # This is the list of workers that are rank 0 of each TP group EXCEPT + # global rank 0. 
These are the workers that will broadcast to the + # rest of the workers. + self.tp_driver_workers: List[RayWorkerWrapper] = [] + # This is the list of workers that are not drivers and not the first + # worker in a TP group. These are the workers that will be + # broadcasted to. + self.non_driver_workers: List[RayWorkerWrapper] = [] + + # Enforce rank order for correct rank to return final output. + for index, worker in enumerate(self.workers): + # The driver worker is rank 0 and not in self.workers. + rank = index + 1 + if rank % self.parallel_config.tensor_parallel_size == 0: + self.tp_driver_workers.append(worker) + else: + self.non_driver_workers.append(worker) + +RayGPUExecutor._init_workers_ray = _init_workers_ray diff --git a/chatlearn/models/vllm/hooks/vllm_0_6_6/worker_base.py b/chatlearn/models/vllm/hooks/vllm_0_6_6/worker_base.py new file mode 100644 index 00000000..54100726 --- /dev/null +++ b/chatlearn/models/vllm/hooks/vllm_0_6_6/worker_base.py @@ -0,0 +1,46 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.3 worker_base to update execute_method.""" + +# pylint: disable=unused-import,wildcard-import +from vllm.worker import worker_base +from vllm.worker.worker_base import logger + + +del worker_base.WorkerWrapperBase.__getattr__ + +def execute_method(self, method, *args, **kwargs): + try: + if self.worker is None: + target = self + else: + if hasattr(self.worker, method): + target = self.worker + else: + target = self + #print(f"debug target: {target} method: {method}") + executor = getattr(target, method) + return executor(*args, **kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method}. 
" + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e + +worker_base.WorkerWrapperBase.execute_method = execute_method diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index 072d9c24..fbc54069 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -25,6 +25,7 @@ from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import RayWorkerWrapper +from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion from chatlearn.utils.global_vars import set_vllm_actors from chatlearn.utils.vllm_import_helper import parallel_state from chatlearn.utils.vllm_import_helper import get_pipeline_model_parallel_rank @@ -49,8 +50,14 @@ def __init__(self, *args, **kwargs): assert common_methods == {'__init__'}, \ f"Expected only '__init__' as common method for TorchModule and RayWorkerWrapper, but got {common_methods}" self.local_rank = 0 - if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: - RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called + if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: + RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called + else: + if 'vllm_actor_type' in kwargs and 'worker' == kwargs['vllm_actor_type']: + vllm_config = self.init_engine_args() + RayWorkerWrapper.__init__(self, vllm_config=vllm_config) # pylint: disable=non-parent-init-called + os.environ['VLLM_HOST_IP'] = self.get_address() self.tokenizer = None @@ -75,6 +82,52 @@ def add_extra_args(self, parser): help='Timeout minutes for torch.distributed.') return parser + def init_engine_args(self): + dtype = self.model_args.get("dtype", "bfloat16") + if self.model_args.get("fp16", False): + dtype = "float16" + + load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY) + if load_format == LoadFormat.DUMMY: + self.model_args["need_load_ckpt"] = self.src_parameter_model is None + model_loader_extra_config = self.model_args + else: + model_loader_extra_config = None + + if self.model_args.get("apply_replica_id_to_seed", True): + seed = self.model_args.get("seed", 0) + self.replica_id + else: + seed = self.model_args.get("seed", 0) + + from vllm.engine.arg_utils import AsyncEngineArgs # pylint: disable=import-outside-toplevel + from vllm.usage.usage_lib import UsageContext # pylint: disable=import-outside-toplevel + + self.engine_args = AsyncEngineArgs( + model=self.model_args['tokenizer'], + tokenizer=self.model_args['tokenizer'], + max_seq_len_to_capture=self.model_args.get("seq_length"), + seed=seed, + # load model: 'dummy' for megatron ckpt or mock weight; others for hf ckpt. 
+ load_format=load_format, + model_loader_extra_config=model_loader_extra_config, + # parallelism strategy + tensor_parallel_size=self.module_args.tensor_model_parallel_size, + pipeline_parallel_size=self.module_args.pipeline_model_parallel_size, + dtype=dtype, + # scheduling strategy + max_num_seqs=self.module_args.generation_batch_size, + max_num_batched_tokens = self.model_args.get("max_num_batched_tokens", None), + num_scheduler_steps=self.model_args.get("num_scheduler_steps", 1), + gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90), + # logger + disable_log_requests=self.model_args.get("disable_log_requests", True), + disable_log_stats=self.model_args.get("disable_log_stats", True), + trust_remote_code=True, + enforce_eager=self.model_args.get("enforce_eager", False), + disable_custom_all_reduce=True, + distributed_executor_backend="ray") + return self.engine_args.create_engine_config(usage_context=UsageContext.ENGINE_CONTEXT) + def init(self): """ :meta private: @@ -256,9 +309,7 @@ def _convert_v1_inputs(self, prompts, prompt_token_ids): return inputs - def generate_vllm(self, query, is_eval, is_first_run=True): - if is_first_run: # using for multi-round generate - self.reinit_cache_engine() + def preprocess_inputs(self, query, is_eval): prompt_key = self.model_args.get("vllm_prompt_key", "prompt") input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids") @@ -286,13 +337,24 @@ def generate_vllm(self, query, is_eval, is_first_run=True): parsed_prompts.append(item) sampling_params.append(sampling_param) + return parsed_prompts, sampling_params + + def run_vllm(self, parsed_prompts, sampling_params): outputs = self.llm.generate( parsed_prompts, sampling_params, - use_tqdm=True, + use_tqdm=True ) return outputs + def generate_vllm(self, query, is_eval, is_first_run=True): + if is_first_run: # using for multi-round generate + self.reinit_cache_engine() + parsed_prompts, sampling_params = self.preprocess_inputs(query, is_eval) + + outputs = self.run_vllm(parsed_prompts, sampling_params) + return outputs + def is_last_rank(self): return True @@ -343,6 +405,21 @@ def pipeline_parallel_rank(self): """ return get_pipeline_model_parallel_rank() + def tensor_model_parallel_size(self): + return self.tensor_and_expert_model_parallel_size() + + def expert_model_parallel_size(self): + return 1 + + def tensor_and_expert_model_parallel_size(self): + """ + get tensor_and_expert_model_parallel_size + :meta private: + """ + # vLLM not supported to enable expert parallel size + # thus: tensor_and_expert_model_parallel_size = tensor_parallel_size + return parallel_state.get_tensor_model_parallel_world_size() + def model_setup_for_workers(self): self.llm.llm_engine.model_executor._run_workers("model_setup") diff --git a/chatlearn/runtime/decorator.py b/chatlearn/runtime/decorator.py index 07ceff75..7034126c 100644 --- a/chatlearn/runtime/decorator.py +++ b/chatlearn/runtime/decorator.py @@ -168,6 +168,7 @@ def get_kwarg(key): # for model with TP/PP, only return the results from last rank if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \ or isinstance(self, VLLMModuleV2): + # print(f"replica_id_{self.replica_id} results: {len(results)} {results} input_data: {input_data}") final_results = concat_along_batch(results) else: if 'iteration' in inspect.signature(func).parameters: diff --git a/chatlearn/runtime/dist_actor.py b/chatlearn/runtime/dist_actor.py index b82c0283..6819d879 100644 --- a/chatlearn/runtime/dist_actor.py +++ 
b/chatlearn/runtime/dist_actor.py @@ -30,6 +30,7 @@ if vllm_exist: from chatlearn.models.vllm_module import VLLMModule from chatlearn.models.vllm_module_v2 import VLLMModuleV2 + from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion RAY_REMOTE = "remote" @@ -234,12 +235,17 @@ def __init__(self, *args, **kwargs): self.vllm_engine = None def create_actor(self, num_gpus, placement_group, group_index): - kwargs = { - "worker_module_name": "vllm.worker.worker", - "worker_class_name": "Worker", - "worker_class_fn": None, - "trust_remote_code": True, - } + if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + kwargs = { + "worker_module_name": "vllm.worker.worker", + "worker_class_name": "Worker", + "worker_class_fn": None, + "trust_remote_code": True, + } + else: + kwargs = { + "vllm_actor_type" : "worker" + } self._create_actor(self.model.__class__, num_gpus, placement_group, group_index, **kwargs) def create_engine_actor(self, num_gpus, placement_group, group_index): diff --git a/chatlearn/synchronizer/megatron_vllm.py b/chatlearn/synchronizer/megatron_vllm.py index ae334851..b06b6ca7 100644 --- a/chatlearn/synchronizer/megatron_vllm.py +++ b/chatlearn/synchronizer/megatron_vllm.py @@ -33,6 +33,7 @@ class MegatronVllmSync(BaseSync): def __init__(self, src_model, dst_model): super().__init__(src_model, dst_model) self.src_module_args = src_model.module_args + self.dst_module_args = dst_model.module_args self.is_parameter_changed = True @abstractmethod @@ -231,7 +232,7 @@ def allgather_routed_experts(self, name, params_to_sync, group_name, tp_rank): # "Please export `QWEN_VERSION` as `qwen_moe_v1`." ) - def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group): + def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group, model_name, rank, world_size): """ This function is applicable for synchronizing parameters from QWen with HEP enabled to vLLM. In HEP, routed experts are split into a total number of EP size * TP size. 
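To make that layout concrete: under HEP, each rank in the tensor-and-expert group holds a thin slice of every routed expert, and synchronization regroups those slices with a single all_to_all over hep_size = tp_size * ep_size ranks. A minimal sketch, assuming an even split and an initialized torch.distributed group; the function name is illustrative, not a ChatLearn API:

import torch

def regroup_expert_slices(param, hep_size, comm_group):
    # Split the local tensor into hep_size chunks along the sliced dimension,
    # exchange chunks with all_to_all, then concatenate what was received
    # along dim 0 so this rank ends up with full weights for its experts.
    send = [chunk.contiguous() for chunk in param.chunk(hep_size, dim=1)]
    recv = [torch.empty_like(chunk) for chunk in send]
    torch.distributed.all_to_all(recv, send, group=comm_group)
    return torch.cat(recv, dim=0).contiguous()

This mirrors the w2 path in the patched alltoall_routed_experts_from_hep below; the w13 path additionally reorders the interleaved w1/w3 slices before the exchange.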
@@ -240,6 +241,9 @@ def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group): if self.sync_map._to_alltoall_routed_experts_dict is None: return params_to_sync, False + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"1 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + to_alltoall_routed_experts_dict = self.sync_map._to_alltoall_routed_experts_dict layer_re = to_alltoall_routed_experts_dict["layer_re"] to_regroup_modules_list = to_alltoall_routed_experts_dict["modules"] @@ -248,20 +252,33 @@ def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group): if m is None: return params_to_sync, False + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"2 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + op_name = m.group(2) if op_name in to_regroup_modules_list: tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"] ep_size = self.src_module_args.args_dict["moe_expert_model_parallel_size"] hep_size = tp_size * ep_size moe_num_experts = self.src_module_args.args_dict["moe_num_experts"] + local_num_experts = moe_num_experts // hep_size hidden_size = self.src_module_args.args_dict["hidden_size"] if "dense_h_to_4h" in op_name: # w13_weight # regroup among difference tp slices param = params_to_sync.view((moe_num_experts, -1, hidden_size)) + # if "layers.0" in name: + # offset = 4 + # num_prints = param.shape[0] // offset + # with open(f"/workspace/code/cmd/moelite_scripts/chatlearn_w13_{rank}_before_tp4ep2.txt", "a+") as file: + # for i in range(num_prints): + # start = offset * i + # end = start + offset + # tensor_to_print = param[start:end] + # file.write(name + f"_{i}:" + str(tensor_to_print.cpu()) + "\n") param = param.reshape((local_num_experts * 2, -1, hidden_size)) - params = list(param.chunk(tp_size, dim=1)) + params = list(param.chunk(hep_size, dim=1)) # reorder w1 and w3 params_list = [] while params: @@ -276,36 +293,69 @@ def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group): del params_to_sync output = [ torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device) - for i in range(tp_size) + for i in range(hep_size) ] + # if rank == 0 and "layers.0" in name: + # offset = 4 + # for idx, ele in enumerate(params_list): + # num_prints = ele.shape[0] // offset + # with open(f"/workspace/code/cmd/moelite_scripts/chatlearn_w13_slice_{rank}_tp4ep2.txt", "a+") as file: + # for i in range(num_prints): + # start = offset * i + # end = start + offset + # tensor_to_print = ele[start:end] + # file.write(name + f"_{i}:" + str(tensor_to_print.cpu()) + "\n") + torch.distributed.all_to_all(output, params_list, group=comm_group) del params_list params_to_sync = torch.cat(output, dim=0).contiguous() + # if "layers.0" in name: + # offset = 4 + # num_prints = params_to_sync.shape[0] // offset + # with open(f"/workspace/code/cmd/moelite_scripts/chatlearn_w13_{rank}_after_tp4ep2.txt", "a+") as file: + # for i in range(num_prints): + # start = offset * i + # end = start + offset + # tensor_to_print = params_to_sync[start:end] + # file.write(name + f"_{i}:" + str(tensor_to_print.cpu()) + "\n") + del output else: # w2_weight param = params_to_sync.view((local_num_experts, -1, hidden_size)) - params = list(param.chunk(tp_size, dim=1)) + params = list(param.chunk(hep_size, 
dim=1)) params_list = [ele.contiguous() for ele in params] del param del params del params_to_sync output = [ torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device) - for i in range(tp_size) + for i in range(hep_size) ] + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"4.1 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + torch.distributed.all_to_all(output, params_list, group=comm_group) + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"4.2 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + del params_list params_to_sync = torch.cat(output, dim=0).transpose(1, 2).contiguous() del output + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"f.1 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + return params_to_sync, True else: + # with open(f"/workspace/code/cmd/moelite_scripts/ps_{model_name}_{rank}_{world_size}_1.txt", "a+") as file: + # file.write(f"f.2 debug alltoall rank: {rank}/{world_size} in comm group {id(comm_group)} {name}" + "\n") + return params_to_sync, False - def alltoall_routed_experts(self, name, params_to_sync, comm_group): # pylint: disable=unused-argument + def alltoall_routed_experts(self, name, params_to_sync, comm_group, model_name="default", rank=0, world_size=1): # pylint: disable=unused-argument megatron_version = get_megatron_version() if megatron_version == MegatronVersion.V4: - return self.alltoall_routed_experts_from_hep(name, params_to_sync, comm_group) + return self.alltoall_routed_experts_from_hep(name, params_to_sync, comm_group, model_name, rank, world_size) else: raise NotImplementedError( "ChatLearn does not support all-to-all routed experts for Megatron-LM, but supports QWen with HEP enabled. 
" @@ -329,8 +379,9 @@ def regroup_qkv_tp_slices(self, name, param_data, tp_divition): if "attention.query_key_value" in name or \ "self_attention.query_key_value" in name or \ "self_attention.linear_qkv" in name: - tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"] - heads = self.src_module_args.args_dict["num_attention_heads"] // tp_size + src_tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"] + dst_tp_size = self.dst_module_args.args_dict["tensor_model_parallel_size"] + heads = self.src_module_args.args_dict["num_attention_heads"] // src_tp_size hidden_size_per_head = self.src_module_args.args_dict["hidden_size"] // self.src_module_args.args_dict["num_attention_heads"] param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:] @@ -348,23 +399,42 @@ def regroup_qkv_tp_slices(self, name, param_data, tp_divition): param_data = torch.concat(param_data_list, dim=0).view(param_data_shape) del param_data_list else: - _num_query_groups = self.src_module_args.args_dict["num_query_groups"]//tp_size \ - if self.src_module_args.args_dict["group_query_attention"] else heads - if to_fix_qkv_ordering_dict is not None or _num_query_groups == 1: + if self.src_module_args.args_dict["group_query_attention"]: + num_query_groups = self.src_module_args.args_dict["num_query_groups"] + assert num_query_groups == self.dst_module_args.args_dict["num_query_groups"], ( + f"num_query_groups of src model ({num_query_groups}) must be equal to num_query_groups of " + f"dst model ({self.dst_moduel_args.args_dict['num_query_groups']}). Please double-check your config." + ) + src_num_query_groups_per_replica = num_query_groups // src_tp_size + if dst_tp_size >= num_query_groups: + num_dst_kv_head_replicas = dst_tp_size // num_query_groups + else: + num_dst_kv_head_replicas = 1 + else: + src_num_query_groups_per_replica = heads + num_dst_kv_head_replicas = 1 + + if to_fix_qkv_ordering_dict is not None or src_num_query_groups_per_replica == 1: if len(param_data_shape) == 1: - param_data = param_data.view((heads + 2 * _num_query_groups, hidden_size_per_head)) + param_data = param_data.view((heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head)) else: param_data = param_data.view( - (heads + 2 * _num_query_groups, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"])) + (heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"])) param_data_list = [] head_offset = heads // tp_divition for idx in range(tp_divition): q_start = idx * head_offset q_end = q_start + head_offset - k_start = (heads + idx) if _num_query_groups // tp_divition else heads - k_end = k_start + 1 - v_start = k_start + _num_query_groups - v_end = v_start + 1 + if num_dst_kv_head_replicas == 1: + k_start = (heads + idx) if src_num_query_groups_per_replica // tp_divition else heads + k_end = k_start + 1 + v_start = k_start + src_num_query_groups_per_replica + v_end = v_start + 1 + else: + k_start = heads + idx // num_dst_kv_head_replicas + k_end = k_start + 1 + v_start = k_start + src_num_query_groups_per_replica + v_end = v_start + 1 q_proj = param_data[q_start:q_end].contiguous() k_proj = param_data[k_start:k_end].contiguous() diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py index 2267d80f..d922d603 100644 --- a/chatlearn/synchronizer/parameter_sync.py +++ b/chatlearn/synchronizer/parameter_sync.py @@ -14,11 +14,13 @@ # 
diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py
index 2267d80f..d922d603 100644
--- a/chatlearn/synchronizer/parameter_sync.py
+++ b/chatlearn/synchronizer/parameter_sync.py
@@ -14,11 +14,13 @@
 # ==============================================================================
 """Sync parameters"""

+import os
+import time
 import concurrent.futures
 import traceback
 from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor
-from itertools import cycle
+from itertools import cycle, permutations, combinations
 from typing import List, Dict

 import torch
@@ -85,9 +87,14 @@ def __init__(self, src_model, dst_model, group_name, frequency, error_signal):
             [ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER, ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL], \
             f"Only support 'allgather' or 'alltoall' for routed expert regrouping, while {self._comm_type_to_regroup_routed_experts}"
         if self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL:
-            if self.num_dst_tensor_parallel != self.num_src_tensor_parallel:
-                logger.warning("Only support ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL when src tp eqs dst tp, use 'allgather' instead.")
+            print(f"debug self.num_dst_tensor_parallel: {self.num_dst_tensor_parallel}")
+            print(f"debug self.num_dst_expert_parallel: {self.num_dst_expert_parallel}")
+            print(f"debug self.num_src_tensor_parallel: {self.num_src_tensor_parallel}")
+            print(f"debug self.num_src_expert_parallel: {self.num_src_expert_parallel}")
+            if self.num_dst_tensor_parallel * self.num_dst_expert_parallel != self.num_src_tensor_parallel * self.num_src_expert_parallel:
+                logger.info("Only support ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL when src tp * ep equals dst tp * ep, use 'allgather' instead.")
                 self._comm_type_to_regroup_routed_experts = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER
+            logger.info(f"Set ROUTED_EXPERT_REGROUPING_COMM_TYPE = {self._comm_type_to_regroup_routed_experts}.")
         self.sorted_send_actors = None
         self.sorted_send_actors_stage2 = None
         self.actor2synchronizer = {}
@@ -303,8 +310,8 @@ def build_rank_mapping(self, add_recv_actor_fn=None):
         if self._debug and (src_dp_ranks[0] is None or dst_dp_ranks is None):
             return

-        assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, \
-            f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}"
+        #assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, \
+        #    f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}"
         if self.src_model.colocate_with(self.dst_model) and self.num_src_tensor_parallel % 2 == 1:
             replica_rank_iter = cycle(reversed(src_dp_ranks))
         else:
@@ -330,6 +337,7 @@ def split_ranks_by_tp_and_ep_size(ranks,
                 j = i // pipe_map_interval
                 for src_rank, dst_rank in zip(src_tp_group, dst_replica_ranks_group[j]):
                     add_recv_actor_fn(src_rank, dst_rank)
+                    logger.info(f"20250131 Sending from {src_rank} to {dst_rank}")

     # pylint: disable=unused-argument
     def build_rank_mapping_for_ep(self, add_recv_actor_fn=None):
@@ -457,8 +465,8 @@ def p2p_pair_grouping(tuples):
         for tuples in dst_tp_group:
             p2p_pair_grouping(tuples)

-        logger.debug(f"comm pair_list : {pair_list}")
-        logger.debug(f"comm p2p_list : {p2p_list}")
+        logger.info(f"comm pair_list : {pair_list}")
+        logger.info(f"comm p2p_list : {p2p_list}")

     def _clear_sync_send_recv_parameters(self, rank_mappings:List):
         if len(rank_mappings) == 0:
@@ -594,7 +602,9 @@ def sync_broadcast_two_stage(self, actors, group_name, requires_grad=None, stage
         send_actor = actors[0]
         for rank, recv_actor in enumerate(actors[1:]):
             if stage2:
-                self.set_sync_param_names_stage2(send_actor, recv_actor, self.actor2rank[recv_actor], requires_grad, filter_fn, param_group)
+
s_names, d_names = self.set_sync_param_names_stage2(send_actor, recv_actor, self.actor2rank[recv_actor], requires_grad, filter_fn, param_group) + # for s_name, d_name in zip(s_names, d_names): + # logger.info(f"stage 2 sync: {len(s_names)} -> {len(d_names)}") else: self.set_sync_param_names(send_actor, recv_actor, requires_grad, filter_fn, param_group) pipe_stage = self.get_actor_pipe_rank(send_actor) @@ -661,6 +671,7 @@ def sync_alltoall(self, actors, requires_grad=None, filter_fn=None): self.set_sync_param_names(actor, actor, requires_grad, filter_fn, param_group="routed", should_map_name=False) pipe_stage = self.get_actor_pipe_rank(actors[0]) refs = [] + logger.info(f"debug alltoall among {[self.actor2rank[actor] for actor in actors]}") for actor in actors: ref = actor.alltoall_routed_expert_parameter.remote(pipe_stage) refs.append(ref) @@ -849,13 +860,69 @@ def sort_send_actors(self, send_recv_actor_mappings, sorted_send_actors): assert len(send_recv_actor_mappings) == len(sorted_send_actors) return sorted_send_actors + def sync_broadcast_second_stage_internal(self, group_name, thread_group, requires_grad=None, filter_fn=None, param_group="default"): + max_workers = len(thread_group) + max_workers = min(8, max_workers) + logger.info(f"debug thread_group {group_name}: {[(self.actor2rank[ele[0]], self.actor2rank[ele[1]]) for ele in thread_group]}") + logger.info(f"Use {max_workers} workers for second_stage_internal broadcasting.") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for idx, actor_group in enumerate(thread_group): + send_actor, recv_actor = actor_group + group_name_with_idx = f"{group_name}_{idx}" + actor_groups, finalized_group_name = self.create_broadcast_group( + send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group + ) + futures.append(executor.submit( + self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, True, filter_fn, param_group)) + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + traceback.print_exc() + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) + + + def sync_broadcast_second_stage(self, group_name, thread_groups, requires_grad=None, filter_fn=None, param_group="default"): + + tp_size = self.num_dst_tensor_parallel + num_thread_groups = len(thread_groups) // tp_size + new_thread_groups = [thread_groups[tp_size*i:tp_size*(i+1)] for i in range(num_thread_groups)] + + if not new_thread_groups: + new_thread_groups = [thread_groups] + max_workers = len(new_thread_groups) + + logger.info(f"Use {max_workers} workers for second_stage broadcasting.") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for idx, thread_group in enumerate(new_thread_groups): + # send_actor, recv_actor = actor_group + group_name_with_idx = f"{group_name}_{idx}" + # actor_groups, finalized_group_name = self.create_broadcast_group( + # send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group + # ) + futures.append(executor.submit( + self.sync_broadcast_second_stage_internal, group_name_with_idx, thread_group, requires_grad, filter_fn, param_group)) + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + traceback.print_exc() + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + 
concurrent.futures.wait(futures) + def sync_broadcast_multi_threads( self, sorted_send_actors, send_recv_actor_mappings, max_workers=1, requires_grad=None, group_name=None, stage2=False, filter_fn=None, param_group="default"): - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] + log_debug = bool(int(os.getenv("log_debug", "0"))) + + log_debug = False + if log_debug: for send_actor in sorted_send_actors: recv_actors = send_recv_actor_mappings[send_actor] + logger.info(f"sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]}") if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: if stage2: for idx, recv_actor in enumerate(recv_actors): @@ -863,31 +930,87 @@ def sync_broadcast_multi_threads( actor_groups, finalized_group_name = self.create_broadcast_group( send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group ) - futures.append(executor.submit( - self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group - )) + logger.info(f"stage2 sending from {self.actor2rank[send_actor]} to {[self.actor2rank[recv_actor]]}") + + self.sync_broadcast_two_stage(actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group) + logger.info(f"stage2 sending from {self.actor2rank[send_actor]} to {[self.actor2rank[recv_actor]]} finished.") + else: actor_groups, finalized_group_name = self.create_broadcast_group( send_actor, recv_actors, group_name=group_name, param_group=param_group ) - futures.append(executor.submit( - self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group - )) - else: - raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") - for _future in concurrent.futures.as_completed(futures): - try: - _future.result() - except Exception as e: - traceback.print_exc() - raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from - concurrent.futures.wait(futures) + # logger.info(f"stage 1 sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]}") + self.sync_broadcast_two_stage(actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group) + # logger.info(f"stage 1 sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]} finished.") + else: + if stage2: + thread_group = [] + for send_actor in sorted_send_actors: + recv_actors = send_recv_actor_mappings[send_actor] + for recv_actor in recv_actors: + thread_group.append((send_actor, recv_actor)) + actor_groups_to_sync = [] + for group in thread_group: + new_actor_group_flag = True + for idx, actor_groups in enumerate(actor_groups_to_sync): + in_actor_group = False + for jdx, actor_group in enumerate(actor_groups): + if group[0] in actor_group or group[1] in actor_group: + in_actor_group = True + if not in_actor_group: + new_actor_group_flag = False + actor_groups_to_sync[idx].append(group) + break + if new_actor_group_flag or not actor_groups_to_sync: + actor_groups_to_sync.append([group]) + log_list = [] + for thread_group in actor_groups_to_sync: + log_list.append([(self.actor2rank[ele[0]], self.actor2rank[ele[1]]) for ele in thread_group]) + + logger.info(f"debug actor_groups_to_sync {group_name}: {log_list}") + + for group_idx, actor_groups in enumerate(actor_groups_to_sync): + if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: + 
self.sync_broadcast_second_stage( + f"{group_name}_{group_idx}", + actor_groups, + requires_grad, + filter_fn, + param_group + ) + else: + raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") + else: + max_workers = len(sorted_send_actors) + logger.info(f"Use {max_workers} workers for first_stage broadcasting.") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for send_actor in sorted_send_actors: + recv_actors = send_recv_actor_mappings[send_actor] + logger.info(f"sending from {self.actor2rank[send_actor]} to {[self.actor2rank[actor] for actor in recv_actors]}") + if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: + actor_groups, finalized_group_name = self.create_broadcast_group( + send_actor, recv_actors, group_name=group_name, param_group=param_group + ) + futures.append(executor.submit( + self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group + )) + else: + raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + traceback.print_exc() + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) def sync_allgather_multi_threads( self, send_actors, max_workers=1, requires_grad=None, group_name=None, filter_fn=None ): send_actors_to_allgather_routed_experts = send_actors[0] + logger.info(f"Use {max_workers} workers for allgather multiprocessing.") with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] for allgather_actors in send_actors_to_allgather_routed_experts: @@ -906,18 +1029,27 @@ def sync_alltoall_multi_threads( self, send_actors, max_workers=1, requires_grad=None, filter_fn=None ): send_actors_to_alltoall_routed_experts = send_actors[0] - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] + if False: for actor_groups in send_actors_to_alltoall_routed_experts: - futures.append(executor.submit( - self.sync_alltoall, actor_groups, requires_grad, filter_fn=filter_fn - )) - for _future in concurrent.futures.as_completed(futures): - try: - _future.result() - except Exception as e: - raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from - concurrent.futures.wait(futures) + logger.info(f"1 debug alltoall among {[self.actor2rank[actor] for actor in actor_groups]}") + self.sync_alltoall(actor_groups, requires_grad, filter_fn=filter_fn) + logger.info(f"2 debug alltoall among {[self.actor2rank[actor] for actor in actor_groups]}") + else: + max_workers = len(send_actors_to_alltoall_routed_experts) + logger.info(f"Use {max_workers} workers for alltoall multiprocessing.") + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [] + for actor_groups in send_actors_to_alltoall_routed_experts: + logger.info(f"1 debug alltoall among {[self.actor2rank[actor] for actor in actor_groups]}") + futures.append(executor.submit( + self.sync_alltoall, actor_groups, requires_grad, filter_fn=filter_fn + )) + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from + concurrent.futures.wait(futures) def check_and_setup_collective_group(self): if not 
self._is_collective_group_created: @@ -997,7 +1129,7 @@ def validate_sync_results_parallel(self, actor_mappings_list:List, requires_grad def _calculate_max_workers(self, sorted_send_actors, actor_mappings=None): max_workers = get_args().runtime_args.param_sync_max_workers if max_workers is None: - max_workers = max(self.src_model.total_gpu // 8, 1) + max_workers = max(self.src_model.total_gpu // self.num_src_pipeline_stage, 1) if max_workers == -1: if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: max_workers = len(sorted_send_actors) @@ -1030,20 +1162,30 @@ def _multi_thread_sync_for_tp_num_mapping_gt_1( actor_mappings_stage2 = actor_mappings[1] # stage 1 + s_time = time.time() sorted_send_actors_stage1 = list(actor_mappings_stage1.keys()) max_workers = self._calculate_max_workers(sorted_send_actors_stage1, actor_mappings_stage1) group_name = self.group_name + "_inter_comm" + logger.info(f"1 debug param sync:start to sync {group_name}") self.sync_broadcast_multi_threads( sorted_send_actors_stage1, actor_mappings_stage1, max_workers, requires_grad, group_name=group_name, stage2=False, filter_fn=filter_fn, param_group=param_group ) + e_time = time.time() + logger.info(f"debug param sync cost stage1: {e_time - s_time}") # stage 2 + s_time = time.time() sorted_send_actors_stage2 = list(actor_mappings_stage2.keys()) max_workers = self._calculate_max_workers(sorted_send_actors_stage2, actor_mappings_stage2) group_name = self.group_name + "_intra_comm" + logger.info(f"2 debug param sync:start to sync {group_name} param_group: {param_group}") + os.environ["log_debug"] = "1" self.sync_broadcast_multi_threads( sorted_send_actors_stage2, actor_mappings_stage2, max_workers, requires_grad, group_name=group_name, stage2=True, filter_fn=filter_fn, param_group=param_group) + os.environ["log_debug"] = "0" + e_time = time.time() + logger.info(f"debug param sync cost stage2: {e_time - s_time}") def _multi_thread_sync_for_tp_num_mapping_eq_1( self, send_actors_list:List, actor_mappings_list:List, @@ -1054,6 +1196,7 @@ def _multi_thread_sync_for_tp_num_mapping_eq_1( actor_mappings = actor_mappings_list[0] sorted_send_actors = self.sort_send_actors(actor_mappings, send_actors) + logger.info(f"20250131 sorted_send_actors: {sorted_send_actors}") max_workers = self._calculate_max_workers(sorted_send_actors, actor_mappings) with ThreadPoolExecutor(max_workers=max_workers) as executor: @@ -1186,8 +1329,18 @@ def setup_rank_mapping(self): else: self.build_rank_mapping_for_ep() elif self.tp_num_mapping > 1: - self.build_rank_mapping_for_ep(add_recv_actor_fn=self.empty_add_recv_actor) # only add all-gather actors - self.build_rank_mapping_two_stage() + if self.hep_num_mapping == 1: + logger.info(f"20250131 build ranking mapping for ep") + self.build_rank_mapping_for_ep(add_recv_actor_fn=self.empty_add_recv_actor) # only add all-gather actors + # self.send_actors_to_regroup_routed_experts = self.send_actors_to_regroup_routed_experts[0] + # self.sorted_send_actors_for_routed_experts = self.send_actors_to_regroup_routed_experts + logger.info(f"20250131 build ranking mapping for routed expert") + self.build_rank_mapping_for_routed_experts() + logger.info(f"20250131 build ranking mapping for params except routed expert") + self.build_rank_mapping_for_params_except_routed_expert() + else: + self.build_rank_mapping_for_ep(add_recv_actor_fn=self.empty_add_recv_actor) # only add all-gather actors + self.build_rank_mapping_two_stage() else: raise NotImplementedError( f"ChatLearn does not support synchronizing from larger tp 
size ({self.num_src_tensor_parallel})" @@ -1221,7 +1374,7 @@ def build_rank_mapping_for_ep(self, add_recv_actor_fn=None): return assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, \ - f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}" + f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}" if self.src_model.colocate_with(self.dst_model) and self.num_src_tensor_parallel % 2 == 1: replica_rank_iter = cycle(reversed(src_dp_ranks)) else: @@ -1278,7 +1431,7 @@ def split_ranks_by_ep_and_tp_size(ranks, (src_replica_ranks2offset[tuple(src_replica_ranks)] + len_dst_div_src) % len(src_ep_and_tp_group) ) - if self._debug: + if True:#self._debug: def debug_msg_for_actor_mappings(actor_mapping): if actor_mapping is None: return @@ -1290,12 +1443,14 @@ def debug_msg_for_actor_mappings(actor_mapping): debug_msg_for_actor_mappings(self.send_recv_actor_mappings) debug_msg_for_actor_mappings(self.send_recv_actor_mappings_for_routed_experts) + count = 0 for regroup_actors in self.send_actors_to_regroup_routed_experts: + count += 1 cat_str = "_".join(str(self.actor2rank[actor]) for actor in regroup_actors) - logger.debug(f"{self._comm_type_to_regroup_routed_experts} actors: {cat_str}") + logger.info(f"{self._comm_type_to_regroup_routed_experts} actors_{count}: {cat_str}") for k, v_list in self.send_recv_actor_mappings.items(): for v in v_list: - logger.debug(f"send_recv_actor_mappings: {self.actor2rank[k]} -> {self.actor2rank[v]}") + logger.info(f"send_recv_actor_mappings: {self.actor2rank[k]} -> {self.actor2rank[v]}") def add_recv_actor_for_routed_experts(self, src_rank, dst_rank): src_actor = self.src_model.get_actor(src_rank) @@ -1391,6 +1546,7 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): if self.concurrent_comm: assert self.dst_model.use_vllm_backend + s_time = time.time() max_workers = self._calculate_max_workers(self.send_actors_to_regroup_routed_experts) if self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER: # allgather routed experts only @@ -1407,9 +1563,14 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): max_workers=max_workers, requires_grad=requires_grad, filter_fn=self.routed_experts_filter) + e_time = time.time() + logger.info(f"debug param sync: complete to alltoall router experts. 
") + logger.info(f"debug param sync cost alltoall {e_time-s_time}") # sync everything to inference model if self.tp_num_mapping == 1: + logger.info(f"debug self.tp_num_mapping: {self.tp_num_mapping}") + s_time = time.time() send_actors_list = [self.sorted_send_actors] actor_mappings_list = [self.send_recv_actor_mappings] self._multi_thread_sync_for_tp_num_mapping_eq_1( @@ -1419,16 +1580,53 @@ def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False): filter_fn=None, param_group="default" ) + e_time = time.time() + logger.info(f"debug param sync: complete to alltoall router experts") + elif self.tp_num_mapping > 1: - send_actors_list = [self.sorted_send_actors, self.sorted_send_actors_stage2] - actor_mappings_list = [self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2] - self._multi_thread_sync_for_tp_num_mapping_gt_1( - send_actors_list, - actor_mappings_list, - requires_grad=requires_grad, - filter_fn=None, - param_group="default" + logger.info(f"debug self.tp_num_mapping: {self.tp_num_mapping}") + s_time = time.time() + logger.info(f"debug param sync: start to sync other weights.") + + # send_actors_list = [self.sorted_send_actors, self.sorted_send_actors_stage2] + # actor_mappings_list = [self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2] + # self._multi_thread_sync_for_tp_num_mapping_gt_1( + # send_actors_list, + # actor_mappings_list, + # requires_grad=requires_grad, + # filter_fn=None, + # param_group="default" + # ) + # First, synchronize routed experts. + self._synchronize_routed_experts(requires_grad=requires_grad, validate=validate) + + self.clear_cache( + sorted_send_actors_list = [ + self.send_actors_to_regroup_routed_experts, + self.sorted_send_actors_for_routed_experts + ], + rank_mapping_list=[ + self.send_recv_actor_mappings_for_routed_experts + ] + ) + + # Then, synchronize parameters except routed experts + self._synchronize_params_except_routed_experts(requires_grad=requires_grad, validate=validate) + + self.reset_synchronizer() + + self.clear_cache( + sorted_send_actors_list = [ + self.sorted_send_actors, + self.sorted_send_actors_stage2, + ], + rank_mapping_list = [ + self.send_recv_actor_mappings, + self.send_recv_actor_mappings_stage2 + ] ) + e_time = time.time() + logger.info(f"debug param sync: complete to sync other weights.") else: raise NotImplementedError( f"ChatLearn does not support synchronizing from larger tp size ({self.num_src_tensor_parallel})" @@ -1457,8 +1655,15 @@ def _synchronize_routed_experts(self, requires_grad=None, validate=False): send_actors_list : List = [] actor_mappings_list : List = [] if self.concurrent_comm: + send_actors_list = [self.sorted_send_actors_for_routed_experts] actor_mappings_list = [self.send_recv_actor_mappings_for_routed_experts] + + logger.info(f"send_actors_list: {send_actors_list}") + logger.info(f"actor_mappings_list: {actor_mappings_list}") + # import pdb;pdb.set_trace() + # logger.info(f"sync routed experts from {[self.actor2rank[ele] for ele in send_actors_list[0]]}") + self._multi_thread_sync_for_tp_num_mapping_eq_1( send_actors_list, actor_mappings_list, diff --git a/chatlearn/utils/arguments.py b/chatlearn/utils/arguments.py index c4630cf8..a8f3257a 100644 --- a/chatlearn/utils/arguments.py +++ b/chatlearn/utils/arguments.py @@ -303,7 +303,7 @@ class RuntimeConfig(BaseConfig): #: profiler dir profiler_dir: str = None #: coalesce_buffer size in mb - coalesced_buffer_mb: int = 100 + coalesced_buffer_mb: int = 1024 #: concurrent parameter sync concurrent_comm: 
bool = True #: parameter sync communication type, broadcast/p2p diff --git a/chatlearn/utils/constant.py b/chatlearn/utils/constant.py index 4ee59907..2bae4dc7 100644 --- a/chatlearn/utils/constant.py +++ b/chatlearn/utils/constant.py @@ -39,6 +39,7 @@ class VLLMVersion(str, Enum): v_0_3_0 = "0.3.0" v_0_5_1 = "0.5.1" v_0_6_3 = "0.6.3" + v_0_6_6 = "0.6.6" class QwenVersion(float, Enum): diff --git a/chatlearn/utils/vllm_import_helper.py b/chatlearn/utils/vllm_import_helper.py index 71864503..d618bf7a 100644 --- a/chatlearn/utils/vllm_import_helper.py +++ b/chatlearn/utils/vllm_import_helper.py @@ -47,8 +47,8 @@ from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.detokenizer import Detokenizer -elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: - # imports for vllm-063 +elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: + # imports for vllm-063/-66 from vllm.core.interfaces import BlockSpaceManager from vllm.distributed import parallel_state from vllm.distributed.communication_op import tensor_model_parallel_all_gather @@ -56,7 +56,8 @@ from vllm.distributed.parallel_state import initialize_model_parallel from vllm.distributed.utils import get_pp_indices from vllm.engine.async_llm_engine import _AsyncLLMEngine as LLMEngine - from vllm.engine.llm_engine import _load_generation_config_dict + if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + from vllm.engine.llm_engine import _load_generation_config_dict from vllm.engine.llm_engine import SchedulerContext, SchedulerOutputState from vllm.engine.output_processor.interfaces import SequenceGroupOutputProcessor from vllm.engine.output_processor.stop_checker import StopChecker @@ -101,7 +102,7 @@ def get_block_manager_cls(version): if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0: return BlockSpaceManager - elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: + elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: return BlockSpaceManager.get_block_space_manager_class(version) @@ -110,7 +111,7 @@ def get_model_architecture(config): from vllm.model_executor.model_loader import _get_model_architecture as get_model_architecture_v1 return get_model_architecture_v1(config) - elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: + elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: from vllm.model_executor.model_loader.utils import get_model_architecture as get_model_architecture_v2 return get_model_architecture_v2(config)[0] @@ -119,7 +120,7 @@ def get_pipeline_model_parallel_rank(): if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0: return parallel_state.get_pipeline_model_parallel_rank() - elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: + elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: return parallel_state.get_pp_group().rank_in_group @@ -127,5 +128,5 @@ def get_pipeline_model_parallel_world_size(): if CURRENT_VLLM_VERSION == VLLMVersion.v_0_3_0: return parallel_state.get_pipeline_model_parallel_world_size() - elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: + elif CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: return parallel_state.get_pp_group().world_size diff --git a/chatlearn/utils/vllm_utils.py b/chatlearn/utils/vllm_utils.py index d6622d0c..1bedee15 100644 --- a/chatlearn/utils/vllm_utils.py +++ b/chatlearn/utils/vllm_utils.py @@ -608,7 +608,7 @@ def 
_init_distributed_environment(args): world_size=args.world_size, rank=args.rank, timeout=timedelta(minutes=args.distributed_timeout_minutes)) - if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3]: + if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: _WORLD = None if _WORLD is None: ranks = list(range(torch.distributed.get_world_size())) @@ -874,7 +874,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version # Transformer Layers print("Converting transformer layers") - if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: start_layer_idx, _ = get_pp_indices( hf_config.num_hidden_layers, pp_rank, @@ -1239,7 +1239,7 @@ def convert_qwen_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version= # Transformer Layers print("Converting transformer layers") - if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: + if CURRENT_VLLM_VERSION in [VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6]: start_layer_idx, _ = get_pp_indices( hf_config.num_hidden_layers, pp_rank, diff --git a/examples/megatron/models/vllm_policy_inference.py b/examples/megatron/models/vllm_policy_inference.py index 4c2402de..42b255d1 100644 --- a/examples/megatron/models/vllm_policy_inference.py +++ b/examples/megatron/models/vllm_policy_inference.py @@ -138,6 +138,7 @@ def decode_internal(self, batched_outputs): prompt_sizes = torch.tensor([len(q) for q in no_padded_query_ids], device=all_tokens.device) loss_mask = get_loss_mask(all_tokens, self.tokenizer.tokenizer.eos_token_id, prompt_sizes) loss_mask = loss_mask.to("cpu") + print(f"str_outputs: {len(str_outputs)} {str_outputs}") return {"all_tokens": all_tokens, "str_outputs": str_outputs, "str_prompts": str_prompts, "no_padded_query_ids": no_padded_query_ids, "logprobs": logprobs, "loss_mask": loss_mask}
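The vllm_import_helper.py and vllm_utils.py hunks repeat the same version-membership test (v_0_5_1, v_0_6_3, v_0_6_6) at every call site that relies on parallel_state.get_pp_group(). If that set keeps growing, a small predicate could centralize the check; the helper below is only a sketch of that idea, not something this patch adds:

    from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion

    # Versions that expose pipeline-parallel info through parallel_state.get_pp_group().
    _PP_GROUP_API_VERSIONS = (VLLMVersion.v_0_5_1, VLLMVersion.v_0_6_3, VLLMVersion.v_0_6_6)

    def vllm_uses_pp_group_api() -> bool:
        return CURRENT_VLLM_VERSION in _PP_GROUP_API_VERSIONS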