From 8572a2451be519fc5fd6d1350b1cefcc8317cfed Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Tue, 17 Dec 2024 20:49:38 +0800 Subject: [PATCH 1/4] support megatron ckpt for llama. --- chatlearn/models/vllm/hooks/loader.py | 3 +++ chatlearn/utils/vllm_utils.py | 18 ++++++++++-------- .../configs/llama2/vllm_policy_inference.yaml | 3 +++ 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/chatlearn/models/vllm/hooks/loader.py b/chatlearn/models/vllm/hooks/loader.py index 4a4d06c5..e35f8077 100644 --- a/chatlearn/models/vllm/hooks/loader.py +++ b/chatlearn/models/vllm/hooks/loader.py @@ -22,6 +22,7 @@ from vllm.model_executor.model_loader.loader import device_loading_context, _initialize_model from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.models import llama from vllm.model_executor.models import qwen2 from chatlearn.utils.vllm_import_helper import LlamaForCausalLM @@ -87,6 +88,8 @@ def load_model(self, *, model_config, if self.load_config.model_loader_extra_config.get("need_load_ckpt", True): qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict qwen2.Qwen2ForCausalLM.load_weights = load_weights + llama.LlamaForCausalLM.load_state_dict = load_state_dict + llama.LlamaForCausalLM.load_weights = load_weights model.load_weights(self.load_config.model_loader_extra_config) else: # For accurate performance evaluation, we assign diff --git a/chatlearn/utils/vllm_utils.py b/chatlearn/utils/vllm_utils.py index c04880b0..a48bb01a 100644 --- a/chatlearn/utils/vllm_utils.py +++ b/chatlearn/utils/vllm_utils.py @@ -812,6 +812,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version # Convert. print("Start to convert...") + prefix_name = "model" if is_vllm_v2() else "model.model" # Embeddings print("Converting embeddings") @@ -838,7 +839,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version elif word_embeddings is not None: # After training with megatron, word_embeddings is stored differently word_embeddings = word_embeddings.to(hf_config.torch_dtype) - output_state_dict["model.model.embed_tokens.weight"] = word_embeddings + output_state_dict[f"{prefix_name}.embed_tokens.weight"] = word_embeddings # Reset the vocab size hf_config.vocab_size = word_embeddings.shape[0] @@ -879,7 +880,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version # Is it a weight or a bias? weight_or_bias = m.group(3) # The name of the layer. - layer_name = f"model.model.layers.{layer_idx}" + layer_name = f"{prefix_name}.layers.{layer_idx}" params = val.to(hf_config.torch_dtype) @@ -935,7 +936,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version if "final_norm.weight" in params or "final_layernorm.weight" in params: print("Converting final layernorm") final_norm_weight = params["final_norm.weight"] if "final_norm.weight" in params else params["final_layernorm.weight"] - output_state_dict["model.model.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) + output_state_dict[f"{prefix_name}.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) # For LM head, transformers' wants the matrix to weight embeddings. 
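+    # Key naming follows is_vllm_v2(): "lm_head.weight" for vLLM v2,
+    # "model.lm_head.weight" otherwise (consistent with prefix_name above).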
params = get_element_from_dict_by_path(tp_state_dicts[tp_rank], 'model.language_model.output_layer.weight') @@ -943,7 +944,7 @@ def convert_llama_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version assert not params, "weight name of lm_head expect 'model.language_model.output_layer.weight'." elif params is not None: print("Converting LM head") - output_state_dict["model.lm_head.weight"] = params.to(hf_config.torch_dtype) + output_state_dict["lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"] = params.to(hf_config.torch_dtype) # It should be done! print("Conversion from Megatron-LM to Transformers is done!") @@ -993,6 +994,7 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # Convert. print("Start to convert...") + prefix_name = "model" if is_vllm_v2() else "model.model" # Embeddings print("Converting embeddings") @@ -1025,7 +1027,7 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # Convert and store the word embeddings. word_embeddings = tp_state_dicts[tp_rank]['model'].get("embedding.word_embeddings.weight", None) word_embeddings = word_embeddings.to(hf_config.torch_dtype) - output_state_dict["model.model.embed_tokens.weight"] = word_embeddings + output_state_dict[f"{prefix_name}.embed_tokens.weight"] = word_embeddings # Reset the vocab size hf_config.vocab_size = word_embeddings.shape[0] @@ -1051,7 +1053,7 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # Is it a weight or a bias? weight_or_bias = layer_match_res.group(3) # The name of the layer - layer_name = f"model.model.layers.{layer_idx}" + layer_name = f"{prefix_name}.layers.{layer_idx}" params = val.to(hf_config.torch_dtype) @@ -1112,12 +1114,12 @@ def convert_llama_state_dict_from_mcore_to_vllm(args, hf_config, qwen_version=No # The final layernorm. print("Converting final layernorm") final_norm_weight = tp_state_dicts[0]['model'].get("decoder.final_layernorm.weight", None) - output_state_dict["model.model.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) + output_state_dict[f"{prefix_name}.norm.weight"] = final_norm_weight.to(hf_config.torch_dtype) # For LM head, transformers' wants the matrix to weight embeddings. print("Converting LM head") params = tp_state_dicts[tp_rank]['model'].get('output_layer.weight', None) - output_state_dict["model.lm_head.weight"] = params.to(hf_config.torch_dtype) + output_state_dict["lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"] = params.to(hf_config.torch_dtype) # It should be done! print("Conversion from Megatron-Core to Transformers is done!") diff --git a/examples/megatron/configs/llama2/vllm_policy_inference.yaml b/examples/megatron/configs/llama2/vllm_policy_inference.yaml index 1ad565fa..c1f3f4c0 100644 --- a/examples/megatron/configs/llama2/vllm_policy_inference.yaml +++ b/examples/megatron/configs/llama2/vllm_policy_inference.yaml @@ -45,3 +45,6 @@ vllm_prompt_key: prompt tensor_model_parallel_size: ${policy_tp} pipeline_model_parallel_size: ${policy_pp} + +enforce_eager: ${enforce_eager:False} +vllm_load_format: ${vllm_load_format:dummy} From deb85d9819feb876fa933d1e1380cb9224498593 Mon Sep 17 00:00:00 2001 From: "jiangle.jl" Date: Thu, 19 Dec 2024 07:11:47 +0000 Subject: [PATCH 2/4] fix pylint. 
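
Besides the pylint fixes, this patch adds two vllm-0.6.3 hooks
(async_llm_engine.py and llm.py) and reworks VLLMModuleV2.setup_vllm to
construct a vllm.LLM directly instead of parsing hand-built CLI args into
AsyncEngineArgs; generation now goes through LLM.generate rather than
per-request asyncio tasks.

Rough sketch of the resulting setup path (illustrative only; the literal
values below are placeholders, the real ones come from model_args and
module_args):

    from vllm import LLM, SamplingParams
    from vllm.config import LoadFormat

    # Placeholder values: ChatLearn fills these from model_args / module_args,
    # and passes model_loader_extra_config=model_args when load_format is DUMMY
    # so the loader hook can read the Megatron checkpoint.
    llm = LLM(
        model="path/to/hf/model",
        load_format=LoadFormat.DUMMY,
        tensor_parallel_size=2,
        enforce_eager=True,
        disable_custom_all_reduce=True,
        distributed_executor_backend="ray",
    )
    outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16), use_tqdm=True)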
--- chatlearn/models/vllm/hooks/__init__.py | 2 + .../models/vllm/hooks/async_llm_engine.py | 55 ++++++++ chatlearn/models/vllm/hooks/llm.py | 89 ++++++++++++ chatlearn/models/vllm_module_v2.py | 132 +++++++----------- chatlearn/runtime/decorator.py | 4 +- 5 files changed, 200 insertions(+), 82 deletions(-) create mode 100644 chatlearn/models/vllm/hooks/async_llm_engine.py create mode 100644 chatlearn/models/vllm/hooks/llm.py diff --git a/chatlearn/models/vllm/hooks/__init__.py b/chatlearn/models/vllm/hooks/__init__.py index 2a4036bc..9bedfe50 100644 --- a/chatlearn/models/vllm/hooks/__init__.py +++ b/chatlearn/models/vllm/hooks/__init__.py @@ -25,6 +25,8 @@ from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3: from chatlearn.models.vllm.hooks import input_preprocess + from chatlearn.models.vllm.hooks import async_llm_engine + from chatlearn.models.vllm.hooks import llm from chatlearn.models.vllm.hooks import loader else: if importlib.util.find_spec("vllm"): diff --git a/chatlearn/models/vllm/hooks/async_llm_engine.py b/chatlearn/models/vllm/hooks/async_llm_engine.py new file mode 100644 index 00000000..ee84f581 --- /dev/null +++ b/chatlearn/models/vllm/hooks/async_llm_engine.py @@ -0,0 +1,55 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.3 del init_ray_cluster in AsyncLLMEngine.""" + +from typing import Dict, Optional + +# pylint: disable=unused-import,wildcard-import,unused-argument +from vllm.config import EngineConfig +from vllm.engine import async_llm_engine +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.metrics_types import StatLoggerBase +from vllm.usage.usage_lib import UsageContext + +@classmethod +def from_engine_args( + cls, + engine_args: AsyncEngineArgs, + engine_config: Optional[EngineConfig] = None, + start_engine_loop: bool = True, + usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, + stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, +) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. + print(f"debug aaaaa async_llm_engine.") + if engine_config is None: + engine_config = engine_args.create_engine_config() + + executor_class = cls._get_executor_cls(engine_config) + + # Create the async LLM engine. 
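+    # Note: unlike the upstream implementation, this hook does not call
+    # initialize_ray_cluster here; ChatLearn sets up the Ray workers itself.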
+ engine = cls( + **engine_config.to_dict(), + executor_class=executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + start_engine_loop=start_engine_loop, + usage_context=usage_context, + stat_loggers=stat_loggers, + ) + return engine + +async_llm_engine.AsyncLLMEngine.from_engine_args = from_engine_args diff --git a/chatlearn/models/vllm/hooks/llm.py b/chatlearn/models/vllm/hooks/llm.py new file mode 100644 index 00000000..e597e096 --- /dev/null +++ b/chatlearn/models/vllm/hooks/llm.py @@ -0,0 +1,89 @@ +# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hooks of vllm-0.6.3 llm init with AsyncLLMEngine and AsyncEngineArgs.""" + +from typing import Any, Dict, Optional + +# pylint: disable=unused-import,wildcard-import,unused-argument +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints import llm +from vllm.usage.usage_lib import UsageContext +from vllm.utils import Counter + +def init( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: float = 4, + cpu_offload_gb: float = 0, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, + mm_processor_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, +) -> None: + ''' + LLM constructor. + + Note: if enforce_eager is unset (enforce_eager is None) + it defaults to False. 
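+
+    This replacement constructor builds an AsyncLLMEngine from
+    AsyncEngineArgs (instead of the default LLMEngine) and keeps its
+    inner `.engine` as self.llm_engine.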
+ ''' + print(f"debug aaaaa llm") + + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + + engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, + mm_processor_kwargs=mm_processor_kwargs, + **kwargs, + ) + + self.llm_engine = AsyncLLMEngine.from_engine_args( + engine_args, usage_context=UsageContext.LLM_CLASS).engine + self.request_counter = Counter() + +llm.LLM.__init__ = init diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index 2daf2cc4..34a803da 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -18,6 +18,7 @@ import inspect import os import sys +import json from typing import Optional import torch @@ -25,10 +26,12 @@ from vllm import SamplingParams from vllm.config import EngineConfig from vllm.config import LoadFormat +from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import RayWorkerWrapper from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.arg_utils import AsyncEngineArgs from vllm.usage.usage_lib import UsageContext +from vllm.utils import Counter from vllm.utils import FlexibleArgumentParser from chatlearn.utils.global_vars import set_vllm_actors @@ -46,7 +49,7 @@ def __init__(self, *args, **kwargs): if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called os.environ['VLLM_HOST_IP'] = self.get_address() - self.engine = None + self.llm_engine = None self.tokenizer = None def setup(self): @@ -56,87 +59,47 @@ def setup(self): tokenizer.tokenizer = tokenizer self.tokenizer = tokenizer - def _init_args(self, args): - # scheduler config - args.max_num_seqs = self.module_args.generation_batch_size - args.max_num_batched_tokens = self.model_args.get("max_num_batched_tokens") - args.num_scheduler_steps = self.model_args.get("num_scheduler_steps", 1) - - # model config - args.max_seq_len = self.model_args.get("seq_length") - - # logger config - args.disable_log_requests = True - - # load format: 'dummy' for megatron ckpt or mock weight; others for hf ckpt. 
- args.load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY) - if args.load_format == LoadFormat.DUMMY: - args.model_loader_extra_config = self.model_args - self.model_args["need_load_ckpt"] = self.src_parameter_model is None - - # engine config - args.enforce_eager = self.model_args.get("enforce_eager", False) - def setup_vllm(self, workers): # setup vllm engine in rank 0 os.environ['VLLM_HOST_IP'] = self.get_address() set_vllm_actors(workers) - parser = FlexibleArgumentParser() - parser = AsyncEngineArgs.add_cli_args(parser) - backup_sys_argv = sys.argv + dtype = self.model_args.get("dtype", "bfloat16") if self.model_args.get("fp16", False): dtype = "float16" - vllm_sys_argv = ["", - f"--model={self.model_args['tokenizer']}", - f"--tensor_parallel_size={self.module_args.tensor_model_parallel_size}", - f"--pipeline_parallel_size={self.module_args.pipeline_model_parallel_size}", - f"--dtype={dtype}", - "--worker_use_ray", - "--disable_custom_all_reduce"] - sys.argv = vllm_sys_argv - args = parser.parse_args() - self._init_args(args) - engine_args = AsyncEngineArgs.from_cli_args(args) - self.engine = self.from_engine_args(engine_args) - - sys.argv = backup_sys_argv - self.tokenizer = self.engine.engine.tokenizer - def from_engine_args( - self, - engine_args: AsyncEngineArgs, - engine_config: Optional[EngineConfig] = None, - start_engine_loop: bool = True, - usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, - stat_loggers = None, - ) -> "AsyncLLMEngine": - """Creates an async LLM engine from the engine arguments.""" - # Create the engine configs. - - if engine_config is None: - engine_config = engine_args.create_engine_config() - - executor_class = AsyncLLMEngine._get_executor_cls(engine_config) - - # Create the async LLM engine. - engine = AsyncLLMEngine( - **engine_config.to_dict(), - executor_class=executor_class, - log_requests=not engine_args.disable_log_requests, - log_stats=not engine_args.disable_log_stats, - start_engine_loop=start_engine_loop, - usage_context=usage_context, - stat_loggers=stat_loggers, - ) - return engine - - async def generate_one_sample(self, prompt, sampling_param, request_id): - results_generator = self.engine.generate(prompt, sampling_param, request_id) - final_output = None - async for request_output in results_generator: - final_output = request_output - return final_output + load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY) + if load_format == LoadFormat.DUMMY: + self.model_args["need_load_ckpt"] = self.src_parameter_model is None + model_loader_extra_config = self.model_args + else: + model_loader_extra_config = None + + self.llm = LLM( + model=self.model_args['tokenizer'], + tokenizer=self.model_args['tokenizer'], + max_seq_len_to_capture=self.model_args.get("seq_length"), + # load model: 'dummy' for megatron ckpt or mock weight; others for hf ckpt. 
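+            # With the dummy format, model_loader_extra_config carries model_args
+            # (including need_load_ckpt) so the loader hook can pull weights from
+            # the Megatron checkpoint instead of a HF checkpoint.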
+ load_format=load_format, + model_loader_extra_config=model_loader_extra_config, + # parallelism strategy + tensor_parallel_size=self.module_args.tensor_model_parallel_size, + pipeline_parallel_size=self.module_args.pipeline_model_parallel_size, + dtype=dtype, + # scheduling strategy + max_num_seqs=self.module_args.generation_batch_size, + max_num_batched_tokens = self.model_args.get("max_num_batched_tokens", None), + num_scheduler_steps=self.model_args.get("num_scheduler_steps", 1), + gpu_memory_utilization=self.model_args.get("gpu_memory_utilization", 0.90), + # logger + disable_log_requests=self.model_args.get("disable_log_requests", True), + disable_log_stats=self.model_args.get("disable_log_stats", True), + trust_remote_code=True, + # TODO(jiangle.jl): support non-eager mode. + enforce_eager=True, + disable_custom_all_reduce=True, + distributed_executor_backend="ray") + self.tokenizer = self.llm.llm_engine.tokenizer def _get_sampling_params(self, is_eval): temperature = 0.0 @@ -173,7 +136,7 @@ def _get_sampling_params(self, is_eval): sampling_params.use_beam_search = self.model_args.get("use_beam_search") return sampling_params - def convert_v1_inputs(self, prompts, prompt_token_ids): + def _convert_v1_inputs(self, prompts, prompt_token_ids): num_requests = len(prompts) assert num_requests == len(prompt_token_ids), \ ("The lengths of prompts and prompt_token_ids must be the same.") @@ -193,6 +156,9 @@ def convert_v1_inputs(self, prompts, prompt_token_ids): return inputs + async def generate_all(self, prompts, sampling_params): + pass + async def generate_vllm(self, query, is_eval): prompt_key = self.model_args.get("vllm_prompt_key", "prompt") input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids") @@ -202,6 +168,8 @@ async def generate_vllm(self, query, is_eval): seq_len = self.model_args.get("seq_length") final_outputs = [] tasks = [] + parsed_prompts = [] + sampling_params = [] for i, prompt in enumerate(prompts): request_id = i prompt_token_ids = prompts_token_ids[i] @@ -215,14 +183,18 @@ async def generate_vllm(self, query, is_eval): max_tokens = self.model_args.get("max_new_tokens") assert max_tokens < seq_len, "max_new_tokens must less than seq length." 
sampling_param.max_tokens = max_tokens - inputs = self.convert_v1_inputs( + item = self._convert_v1_inputs( prompts=[prompt], prompt_token_ids=[prompt_token_ids], )[0] + parsed_prompts.append(item) + sampling_params.append(sampling_param) - task = asyncio.create_task(self.generate_one_sample(inputs, sampling_param, request_id)) - tasks.append(task) - outputs = await asyncio.gather(*tasks) + outputs = self.llm.generate( + parsed_prompts, + sampling_params, + use_tqdm=True, + ) final_outputs = sorted(outputs, key=lambda x: int(x.request_id)) return final_outputs @@ -253,7 +225,7 @@ def __init__(self, *args, **kwargs): if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called os.environ['VLLM_HOST_IP'] = self.get_address() - self.engine = None + self.llm_engine = None def peak_memory(self): """ diff --git a/chatlearn/runtime/decorator.py b/chatlearn/runtime/decorator.py index 0291ee76..56d89972 100644 --- a/chatlearn/runtime/decorator.py +++ b/chatlearn/runtime/decorator.py @@ -164,7 +164,7 @@ def get_kwarg(key): # for model with DP/EP, we need to return results from all ranks # for model with TP/PP, only return the results from last rank if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \ - or isinstance(self, VLLMModuleV2): + or isinstance(self, VLLMModuleV2): final_results = concat_along_batch(results) else: if 'iteration' in inspect.signature(func).parameters: @@ -176,7 +176,7 @@ def get_kwarg(key): # for model with DP/EP, we need to return results from all ranks # for model with TP/PP, only return the results from last rank if self.is_last_rank() or self.data_parallel_size is None or self.data_parallel_size > 1 \ - or isinstance(self, VLLMModuleV2): + or isinstance(self, VLLMModuleV2): final_results = ret else: if 'iteration' in inspect.signature(func).parameters: From e339fc3139321b9da0d31477f7eee66177cc79e8 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Wed, 22 Jan 2025 19:02:31 +0800 Subject: [PATCH 3/4] fix[benchmark_vllm]: avoid num input tokens overflow --- examples/tests/benchmark_vllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tests/benchmark_vllm.py b/examples/tests/benchmark_vllm.py index 80c8697e..6a0a769d 100644 --- a/examples/tests/benchmark_vllm.py +++ b/examples/tests/benchmark_vllm.py @@ -70,7 +70,7 @@ def example_to_requests(args, tokenizer, examples): if args.max_tokens is not None: max_tokens = args.max_tokens else: - max_tokens = 2048 - len(tokenizer(prompt).input_ids) + max_tokens = 2048 - min(len(tokenizer(prompt).input_ids), 1024) requests.append(SampleRequest(prompt, prompt_len, max_tokens)) sampling_params.append( From 896dd0409b7f1a04a0dcdc57d41c5fd7950078e4 Mon Sep 17 00:00:00 2001 From: "baodong.lh" Date: Wed, 22 Jan 2025 19:07:12 +0800 Subject: [PATCH 4/4] revert: undo changes in vllm_module_v2.py --- chatlearn/models/vllm_module_v2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py index de0c2eac..072d9c24 100644 --- a/chatlearn/models/vllm_module_v2.py +++ b/chatlearn/models/vllm_module_v2.py @@ -52,6 +52,7 @@ def __init__(self, *args, **kwargs): if 'worker_module_name' in kwargs and 'worker_class_name' in kwargs: RayWorkerWrapper.__init__(self, **kwargs) # pylint: disable=non-parent-init-called os.environ['VLLM_HOST_IP'] = self.get_address() + self.tokenizer = None self._model = None self.llm = None