diff --git a/chatlearn/__init__.py b/chatlearn/__init__.py
index b9073a86..26661539 100644
--- a/chatlearn/__init__.py
+++ b/chatlearn/__init__.py
@@ -26,7 +26,6 @@
 from chatlearn.runtime.engine import Environment
 from chatlearn.runtime.engine import Trainer
 from chatlearn.runtime.evaluator import Evaluator
-from chatlearn.runtime.model_flow import ControlDependencies
 from chatlearn.utils.future import get
 from chatlearn.utils.global_vars import get_args
 from chatlearn.utils.logger import logger
diff --git a/chatlearn/runtime/executor.py b/chatlearn/runtime/executor.py
index 73a1c304..91a81fd2 100644
--- a/chatlearn/runtime/executor.py
+++ b/chatlearn/runtime/executor.py
@@ -169,7 +169,7 @@ def get_merged_data(self,
                         micro_batch_index: Optional[int] = None, # pylint: disable-next=unused-argument
                         model_node: Optional[ModelNode] = None, # pylint: disable-next=unused-argument
                         trainable: bool = False
-                        ):
+    ):
         """
         merge data from different queues, get data from queues for current node by dp-wise.
         It will be executed in form of a for loop for dp size-times.
@@ -215,6 +215,8 @@ def rebatch_all_merged_data(self, model_node: ModelNode, in_queues: List[Queue],
         out_queues = [None] * len(in_queues)
         # in environment, num_consumers is global dp size(replica_dp_size*dp_size)
         num_consumers = self.global_dp_size(model_node.model)
+        #! If model_node has input nodes, then len(in_queues) == len(model_node.input_nodes);
+        #! in other words, self._input_queue is None for every model node that has input nodes.
         for index, (input_node, in_queue) in enumerate(zip(model_node.input_nodes, in_queues)):
             num_producers = self.global_dp_size(input_node.model)
             if num_producers == num_consumers:
@@ -254,32 +256,19 @@ def rebatch_all_merged_data(self, model_node: ModelNode, in_queues: List[Queue],
                     INDEX_TAG: (q_idx % division, division)}))
         return out_queues

-    def get_next_data(self, in_queue: Union[Queue, List[Queue]], model_node: ModelNode, micro_batch_index):
+    def get_next_data(self, in_queue_list: List[Queue], model_node: ModelNode, micro_batch_index):
         """
         get a dp-rank data
         """
-        if isinstance(in_queue, list):
-            if len(in_queue) > 0:
-                # this should happen for inference models, will trigger bug for training models
-                # since training models accept a list of remote object, which has the same
-                # behavior for models accept multiple inputs
-                # we need to deal with it later
-                assert not model_node.trainable
-                data = self.get_merged_data(in_queue, micro_batch_index=micro_batch_index,
-                                            model_node=model_node, trainable=model_node.trainable)
-                mb, query = decode_data(data)
-            else:
-                mb, query = micro_batch_index, []
-        else:
-            data = self.get_merged_data([in_queue], micro_batch_index=micro_batch_index,
-                                        model_node=model_node, trainable=model_node.trainable)
-            assert len(data['data']) == 1
-            data['data'] = data['data'][0]
-            mb, query = decode_data(data)
-            query = [query]
-        return mb, query
-
-    def generate_step_one_model_internal(self, model_node, in_queue, step_num, replica, func_name="forward_step", to_empty_cache=None,
+        data = self.get_merged_data(
+            in_queue_list,
+            micro_batch_index=micro_batch_index,
+            model_node=model_node,
+            trainable=model_node.trainable
+        )
+        return decode_data(data)
+
+    def generate_step_one_model_internal(self, model_node, in_queue_list: List[Queue], step_num, replica, func_name="forward_step", to_empty_cache=None,
                                          is_eval=False, to_onload=None, to_offload=None, micro_batch_index=None):
         """
@@ -301,14 +290,14 @@ def generate_step_one_model_internal(self, model_node, in_queue, step_num, repli
         if isinstance(replica.model, (VLLMModule, SGLangModule)):
             # for rollout we only to pass data to engine for every replica
-            mb, query = self.get_next_data(in_queue, model_node, micro_batch_index)
+            mb, query = self.get_next_data(in_queue_list, model_node, micro_batch_index)
             assert isinstance(query, list)
             ret = replica.call_actor_remote_func(replica.engine, func_name, *query, **kwargs)
             # output length is num replica
             output.append((ret, mb))
         else:
             for _, actors in replica.dp_rank_to_actors.items():
-                mb, query = self.get_next_data(in_queue, model_node, micro_batch_index)
+                mb, query = self.get_next_data(in_queue_list, model_node, micro_batch_index)
                 assert isinstance(query, list)
                 for actor in actors:
                     ret = replica.call_actor_remote_func(actor, func_name, *query, **kwargs)
@@ -316,13 +305,13 @@ def generate_step_one_model_internal(self, model_node, in_queue, step_num, repli
                     output.append((ret, mb))
         return output

-    def generate_step_one_model(self, model_node, replica, in_queue, out_queue, step_num, func_name="forward_step",
+    def generate_step_one_model(self, model_node, replica, in_queue_list: List[Queue], out_queue, step_num, func_name="forward_step",
                                 to_empty_cache=None, is_eval=False, to_onload=None, to_offload=None, micro_batch_index=None):
         """
         forward for a model replica, and only set the output of last rank in dp rank to out_queue
         """
         # output is a list of tuple, each tuple is (remote_refs, mb)
-        output = self.generate_step_one_model_internal(model_node, in_queue, step_num, replica, func_name, to_empty_cache,
+        output = self.generate_step_one_model_internal(model_node, in_queue_list, step_num, replica, func_name, to_empty_cache,
                                                        is_eval, to_onload, to_offload, micro_batch_index)

         # for get the data in last actor of a dp rank
@@ -349,14 +338,12 @@ def generate_step_one_model(self, model_node, replica, in_queue, out_queue, step
             remote_refs = [item[0] for item in output]
         return out_queue, remote_refs

-    def regroup_inqueue(self, model_node: ModelNode, queues, is_eval=False):
+    def regroup_inqueue(self, model_node: ModelNode, queues: List[Queue], is_eval=False):
         """
         re-construct input_queues[node_num, previous_node_global_dp_size] to output_queues[node_num, current_node_global_dp_size]
         """
         if self.args.policy_to_regroup_queue == "global_barrier":
             # barrier to regroup all queues of producer node
-            if not isinstance(queues, list):
-                queues = [queues]
             logger.info(f"{LOG_START} regroup_inqueue in_queue {model_node}: {[ele.qsize() for ele in queues]}")
             out_queues = self.rebatch_all_merged_data(model_node, queues, is_eval=is_eval)
             logger.info(f"{LOG_START} regroup_inqueue out_queues {model_node}: {[ele.qsize() for ele in out_queues]}")
@@ -372,6 +359,7 @@ def compute_loop_one_model(self, model_node: ModelNode, num_batch=None):
         model = model_node.model
         is_eval = self.is_eval

+        # NOTE: each `batch` is the batch of data that one replica computes in a single step
        if num_batch is None:
            num_batch = self.num_iteration(model)

@@ -382,22 +370,23 @@ def compute_loop_one_model(self, model_node: ModelNode, num_batch=None):
             logger.info(f"{LOG_START} complete to wait colocate models to finish for {model_node}")
         replica_num = len(model.replicas)
         last_step_start = max(num_batch - replica_num, 0)
-        in_queue = model_node.get_input_queues()
+        in_queue_list: List[Queue] = model_node.get_input_queues()

-        logger.info(f"{LOG_START} start to regroup in_queue for {model_node}")
-        in_queue = self.regroup_inqueue(model_node, in_queue, is_eval=is_eval)
-        logger.info(f"{LOG_START} complete to regroup in_queue for {model_node}")
+        logger.info(f"{LOG_START} start to regroup in_queue_list for {model_node}")
+        in_queue_list: List[Queue] = self.regroup_inqueue(model_node, in_queue_list, is_eval=is_eval)
+        logger.info(f"{LOG_START} complete to regroup in_queue_list for {model_node}")

-        if isinstance(in_queue, list) and len(in_queue) == 1:
-            in_queue = in_queue[0]
         results = []
         logger.info(f"{LOG_START} start to generate_step_one_model for {model_node}")
         for step in range(num_batch):
+            # NOTE: when onload and offload are enabled, for a model with K replicas and M batches:
+            # ONLOAD step: the first step of each replica, i.e. step < K
+            # OFFLOAD step: the last step of each replica, i.e. step >= max(M - K, 0)
             to_empty_cache = step >= last_step_start and model.is_colocate
             to_onload = step < replica_num and (model.is_colocate and model.enable_offload)
             to_offload = step >= last_step_start and (model.is_colocate and model.enable_offload)
             replica = self._next_model(model)
-            _, data = self.generate_step_one_model(model_node, replica, in_queue, model_node.out_queues, step, func_name, to_empty_cache,
+            _, data = self.generate_step_one_model(model_node, replica, in_queue_list, model_node.out_queues, step, func_name, to_empty_cache,
                                                    is_eval=is_eval, to_onload=to_onload, to_offload=to_offload)
             results.append(data)
             if model_node.next_colocate_node:
diff --git a/chatlearn/runtime/model_flow.py b/chatlearn/runtime/model_flow.py
index e8c79e34..2b62e22e 100644
--- a/chatlearn/runtime/model_flow.py
+++ b/chatlearn/runtime/model_flow.py
@@ -15,31 +15,16 @@
 """Model FLow"""

 from collections import defaultdict, deque
-from typing import List, Callable, Dict
+from typing import List, Callable, Dict, Optional

 from chatlearn.utils import future
 from chatlearn.utils.global_vars import unwrap_func
-from chatlearn.utils.global_vars import reset_dependencies, set_dependencies, get_dependencies
 from chatlearn.utils.utils import flatten
 from chatlearn.runtime.dist_actor import DistModel
 from chatlearn.models.base_module import BaseModule
 from .decorator import decorate_class_func
-
-class ControlDependencies:
-    """ControlDependencies"""
-
-    def __init__(self, dependencies):
-        if not isinstance(dependencies, list):
-            dependencies = [dependencies]
-        self.dependencies = dependencies
-
-    def __enter__(self):
-        set_dependencies(self.dependencies)
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        reset_dependencies()
+from ray.util.queue import Queue


 class DummyData:
@@ -60,9 +45,9 @@ def __init__(self, model: DistModel, func_name):
         self.input_nodes = []
         self.output_nodes = []
         self.out_queues = None
-        self._input_queue = None
+        self._input_queue: Optional[Queue] = None
         # next colocate model node to execute
-        self.next_colocate_node = None
+        self.next_colocate_node: Optional[ModelNode] = None
         # model to wait before the execution of current model
         self.models_to_wait = []
         # remote objects to wait before the execution of current model
@@ -85,15 +70,17 @@ def set_out_queues(self, queues):
     def set_input_queue(self, queue):
         self._input_queue = queue

-    def get_input_queues(self):
+    def get_input_queues(self) -> List[Queue]:
+        """
+        Get all input queues of this model node: one queue per input node,
+        plus one extra queue if self._input_queue is set.
+        """
         input_queues = []
         if self._input_queue is not None:
             input_queues.append(self._input_queue)
         for input_model_node in self.input_nodes:
             out_index = input_model_node.output_nodes.index(self)
             input_queues.append(input_model_node.out_queues[out_index])
-        if len(input_queues) == 1:
-            return input_queues[0]
         return input_queues

     def _find_all_parents(self, model, prev_models_results):
@@ -152,11 +139,11 @@ def __init__(self, cls):
         self.input_consumers = []

     def fake_compute(self, fn):
-        def inner(*args):
+        def inner(*args):  #! args = (self, *inputs)
            assert len(args) > 0
            original_fn = unwrap_func(fn)
            func_name = original_fn.__name__
-            model_node = ModelNode(args[0], func_name)
+            model_node = ModelNode(args[0], func_name)  # args[0] is self, i.e. the model
            dist_model = self.name2remote_model[model_node.name]
            model_node.model = dist_model
            dist_model.model_node = model_node
@@ -166,10 +153,6 @@ def inner(*args):
                 data.to_nodes.append(model_node)
                 if data.from_node:
                     model_node.add_input_node(data.from_node)
-            dependencies = get_dependencies()
-            if dependencies is not None:
-                for dep in dependencies:
-                    dep.from_node.dependent_output_nodes.append(model_node)
             res = DummyData(model_node)
             return res

@@ -194,7 +177,7 @@ def trace(self, models: List[DistModel], compute_flow: Callable):
         dummy_data = DummyData()
         assert compute_flow is not None
-        dummy_output = compute_flow(dummy_data)
+        dummy_output = compute_flow(dummy_data)  # TODO: remove this call, or pass *args?
         # convert decorator back
         for model in local_models:
             for func_name in self.cls.model_to_call_funcs[model]:
diff --git a/chatlearn/runtime/utils.py b/chatlearn/runtime/utils.py
index 8d5e5848..43e699fc 100644
--- a/chatlearn/runtime/utils.py
+++ b/chatlearn/runtime/utils.py
@@ -18,8 +18,17 @@
 import textwrap
 import inspect
 from collections import defaultdict
+from typing import Dict, List, Union, Any

-def encode_data(mb, data):
+def encode_data(mb, data) -> Dict[str, Union[int, List[Any]]]:
+    """
+    Return a dict of the form:
+    {
+        'iter': int,
+        'data': List[ObjectRef]
+    }
+
+    """
     return {"iter": mb, "data": data}
diff --git a/chatlearn/utils/global_vars.py b/chatlearn/utils/global_vars.py
index be0e0748..947a2619 100644
--- a/chatlearn/utils/global_vars.py
+++ b/chatlearn/utils/global_vars.py
@@ -19,7 +19,6 @@
 _EXIT_ACTOR_NAME = "ChatLearnExitActor"
 _DECORATED_MODELS = None
 _DECORATED_OUTER_TO_INNER = {}
-_DEPENDENCIES = None
 _VLLM_ACTORS = None

@@ -65,17 +64,6 @@ def set_wrap_func(func, new_func):
     assert new_func not in _DECORATED_OUTER_TO_INNER
     _DECORATED_OUTER_TO_INNER[new_func] = func

-def set_dependencies(dependencies):
-    global _DEPENDENCIES
-    assert _DEPENDENCIES is None
-    _DEPENDENCIES = dependencies
-
-def reset_dependencies():
-    global _DEPENDENCIES
-    _DEPENDENCIES = None
-
-def get_dependencies():
-    return _DEPENDENCIES

 def set_vllm_actors(actors):
     global _VLLM_ACTORS
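A minimal sketch (illustrative, not part of the patch) of the encode_data/decode_data contract that get_next_data relies on after this change: get_input_queues() always returns a List[Queue], get_merged_data packs one micro-batch with encode_data, and decode_data is assumed below to simply invert encode_data, since its body does not appear in the diff.

# sketch.py -- illustrative only; decode_data here is an assumption, not the repo's code
from typing import Any, Dict, List, Tuple, Union


def encode_data(mb: int, data: List[Any]) -> Dict[str, Union[int, List[Any]]]:
    # same structure as the patched chatlearn/runtime/utils.py: micro-batch index plus payload list
    return {"iter": mb, "data": data}


def decode_data(encoded: Dict[str, Union[int, List[Any]]]) -> Tuple[int, List[Any]]:
    # assumed inverse of encode_data
    return encoded["iter"], encoded["data"]


# After this change every model node consumes a List[Queue]; each queue item is an encoded
# dict, so get_next_data always yields (mb, query) with query being a list.
mb, query = decode_data(encode_data(3, ["ref_a", "ref_b"]))
assert mb == 3 and query == ["ref_a", "ref_b"]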