1 change: 0 additions & 1 deletion chatlearn/__init__.py
@@ -26,7 +26,6 @@
from chatlearn.runtime.engine import Environment
from chatlearn.runtime.engine import Trainer
from chatlearn.runtime.evaluator import Evaluator
from chatlearn.runtime.model_flow import ControlDependencies
from chatlearn.utils.future import get
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.logger import logger
65 changes: 27 additions & 38 deletions chatlearn/runtime/executor.py
@@ -169,7 +169,7 @@ def get_merged_data(self,
micro_batch_index: Optional[int] = None, # pylint: disable-next=unused-argument
model_node: Optional[ModelNode] = None, # pylint: disable-next=unused-argument
trainable: bool = False
):
):
"""
Merge data from different queues and fetch the data for the current node, one DP rank at a time.
It is executed in a for loop that runs dp_size times.
@@ -215,6 +215,8 @@ def rebatch_all_merged_data(self, model_node: ModelNode, in_queues: List[Queue],
out_queues = [None] * len(in_queues)
# in environment, num_consumers is global dp size(replica_dp_size*dp_size)
num_consumers = self.global_dp_size(model_node.model)
#! Note: if model_node has input nodes, then len(in_queues) == len(model_node.input_nodes)
#! In other words, self._input_queue is None for any model node that has input nodes
for index, (input_node, in_queue) in enumerate(zip(model_node.input_nodes, in_queues)):
num_producers = self.global_dp_size(input_node.model)
if num_producers == num_consumers:
@@ -254,32 +256,19 @@ def rebatch_all_merged_data(self, model_node: ModelNode, in_queues: List[Queue],
INDEX_TAG: (q_idx % division, division)}))
return out_queues
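
As a side note on the rebatching above: when the producer's global DP size differs from the consumer's, queues are either fanned in or fanned out, and each slice is tagged with an (index, division) pair so the consumer can reassemble its micro-batch. A rough, self-contained sketch of that idea (plain lists instead of Ray queues; the helper name `rebatch` is made up for illustration and only mirrors the `INDEX_TAG` bookkeeping shown in the hunk):

```python
# Rough sketch (not ChatLearn code): regroup the flat list of items produced by
# `num_producers` DP ranks into `num_consumers` buckets, tagging each slice with
# (sub-index, division) so it can be reassembled on the consumer side.
from typing import Any, Dict, List

def rebatch(items: List[Any], num_producers: int, num_consumers: int) -> List[List[Dict[str, Any]]]:
    out: List[List[Dict[str, Any]]] = [[] for _ in range(num_consumers)]
    if num_producers == num_consumers:
        for i, item in enumerate(items):            # 1:1, just forward
            out[i].append({"data": item, "index": (0, 1)})
    elif num_producers > num_consumers:
        division = num_producers // num_consumers   # fan-in: merge producer slices
        for q_idx, item in enumerate(items):
            out[q_idx // division].append({"data": item, "index": (q_idx % division, division)})
    else:
        division = num_consumers // num_producers   # fan-out: share a slice across consumers
        for q_idx in range(num_consumers):
            out[q_idx].append({"data": items[q_idx // division], "index": (q_idx % division, division)})
    return out

# e.g. 4 producer DP ranks feeding 2 consumer DP ranks
print(rebatch(["a", "b", "c", "d"], num_producers=4, num_consumers=2))
```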

def get_next_data(self, in_queue: Union[Queue, List[Queue]], model_node: ModelNode, micro_batch_index):
def get_next_data(self, in_queue_list: List[Queue], model_node: ModelNode, micro_batch_index):
"""
get a dp-rank data
"""
if isinstance(in_queue, list):
if len(in_queue) > 0:
# this should happen for inference models, will trigger bug for training models
# since training models accept a list of remote object, which has the same
# behavior for models accept multiple inputs
# we need to deal with it later
assert not model_node.trainable
data = self.get_merged_data(in_queue, micro_batch_index=micro_batch_index,
model_node=model_node, trainable=model_node.trainable)
mb, query = decode_data(data)
else:
mb, query = micro_batch_index, []
else:
data = self.get_merged_data([in_queue], micro_batch_index=micro_batch_index,
model_node=model_node, trainable=model_node.trainable)
assert len(data['data']) == 1
data['data'] = data['data'][0]
mb, query = decode_data(data)
query = [query]
return mb, query

def generate_step_one_model_internal(self, model_node, in_queue, step_num, replica, func_name="forward_step", to_empty_cache=None,
data = self.get_merged_data(
in_queue_list,
micro_batch_index=micro_batch_index,
model_node=model_node,
trainable=model_node.trainable
)
return decode_data(data)

def generate_step_one_model_internal(self, model_node, in_queue_list: List[Queue], step_num, replica, func_name="forward_step", to_empty_cache=None,
is_eval=False, to_onload=None, to_offload=None, micro_batch_index=None):

"""
@@ -301,28 +290,28 @@

if isinstance(replica.model, (VLLMModule, SGLangModule)):
# for rollout we only need to pass data to the engine once per replica
mb, query = self.get_next_data(in_queue, model_node, micro_batch_index)
mb, query = self.get_next_data(in_queue_list, model_node, micro_batch_index)
assert isinstance(query, list)
ret = replica.call_actor_remote_func(replica.engine, func_name, *query, **kwargs)
# output length is num replica
output.append((ret, mb))
else:
for _, actors in replica.dp_rank_to_actors.items():
mb, query = self.get_next_data(in_queue, model_node, micro_batch_index)
mb, query = self.get_next_data(in_queue_list, model_node, micro_batch_index)
assert isinstance(query, list)
for actor in actors:
ret = replica.call_actor_remote_func(actor, func_name, *query, **kwargs)
# output length is num actor
output.append((ret, mb))
return output
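
For orientation, the loop above dispatches differently by backend: vLLM/SGLang rollout replicas receive a single engine call per replica, while other models receive one call per actor within each DP rank. A minimal sketch of that control flow, using stand-in objects rather than ChatLearn's real replica/actor classes (the attribute names `engine`, `dp_rank_to_actors`, and `call_actor_remote_func` are taken from the hunk above):

```python
# Simplified stand-in for the dispatch rule above (not ChatLearn's real API).
def dispatch(replica, next_data, is_rollout_engine, func_name="forward_step", **kwargs):
    output = []
    if is_rollout_engine:
        mb, query = next_data()                       # one micro-batch per replica
        ret = replica.call_actor_remote_func(replica.engine, func_name, *query, **kwargs)
        output.append((ret, mb))                      # output length == num replicas
    else:
        for actors in replica.dp_rank_to_actors.values():
            mb, query = next_data()                   # one micro-batch per DP rank
            for actor in actors:
                ret = replica.call_actor_remote_func(actor, func_name, *query, **kwargs)
                output.append((ret, mb))              # output length == num actors
    return output
```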

def generate_step_one_model(self, model_node, replica, in_queue, out_queue, step_num, func_name="forward_step",
def generate_step_one_model(self, model_node, replica, in_queue_list: List[Queue], out_queue, step_num, func_name="forward_step",
to_empty_cache=None, is_eval=False, to_onload=None, to_offload=None, micro_batch_index=None):
"""
Run forward for a model replica; only the output of the last rank within each DP rank is put into out_queue
"""
# output is a list of tuple, each tuple is (remote_refs, mb)
output = self.generate_step_one_model_internal(model_node, in_queue, step_num, replica, func_name, to_empty_cache,
output = self.generate_step_one_model_internal(model_node, in_queue_list, step_num, replica, func_name, to_empty_cache,
is_eval, to_onload, to_offload, micro_batch_index)

# get the data from the last actor of each dp rank
@@ -349,14 +338,12 @@ def generate_step_one_model(model_node, replica, in_queue, out_queue, step
remote_refs = [item[0] for item in output]
return out_queue, remote_refs

def regroup_inqueue(self, model_node: ModelNode, queues, is_eval=False):
def regroup_inqueue(self, model_node: ModelNode, queues: List[Queue], is_eval=False):
"""
Re-construct input_queues of shape [node_num, previous_node_global_dp_size] into output_queues of shape [node_num, current_node_global_dp_size]
"""
if self.args.policy_to_regroup_queue == "global_barrier":
# barrier to regroup all queues of producer node
if not isinstance(queues, list):
queues = [queues]
logger.info(f"{LOG_START} regroup_inqueue in_queue {model_node}: {[ele.qsize() for ele in queues]}")
out_queues = self.rebatch_all_merged_data(model_node, queues, is_eval=is_eval)
logger.info(f"{LOG_START} regroup_inqueue out_queues {model_node}: {[ele.qsize() for ele in out_queues]}")
@@ -372,6 +359,7 @@ def compute_loop_one_model(self, model_node: ModelNode, num_batch=None):
model = model_node.model
is_eval = self.is_eval

# NOTE: Each `batch` is a batch of data for a replica to compute one time
if num_batch is None:
num_batch = self.num_iteration(model)

@@ -382,22 +370,23 @@ def compute_loop_one_model(self, model_node: ModelNode, num_batch=None):
logger.info(f"{LOG_START} complete to wait colocate models to finish for {model_node}")
replica_num = len(model.replicas)
last_step_start = max(num_batch - replica_num, 0)
in_queue = model_node.get_input_queues()
in_queue_list: List[Queue] = model_node.get_input_queues()

logger.info(f"{LOG_START} start to regroup in_queue for {model_node}")
in_queue = self.regroup_inqueue(model_node, in_queue, is_eval=is_eval)
logger.info(f"{LOG_START} complete to regroup in_queue for {model_node}")
logger.info(f"{LOG_START} start to regroup in_queue_list for {model_node}")
in_queue_list: List[Queue] = self.regroup_inqueue(model_node, in_queue_list, is_eval=is_eval)
logger.info(f"{LOG_START} complete to regroup in_queue_list for {model_node}")

if isinstance(in_queue, list) and len(in_queue) == 1:
in_queue = in_queue[0]
results = []
logger.info(f"{LOG_START} start to generate_step_one_model for {model_node}")
for step in range(num_batch):
# NOTE: if onload and offload is required, for a model with K replicas and M batches,
# ONLOAD STEP: the first step of each replica, i.e., step < K
# OFFLOAD STEP: the final step of each replica, i.e., step >= max(M - K, 0)
to_empty_cache = step >= last_step_start and model.is_colocate
to_onload = step < replica_num and (model.is_colocate and model.enable_offload)
to_offload = step >= last_step_start and (model.is_colocate and model.enable_offload)
replica = self._next_model(model)
_, data = self.generate_step_one_model(model_node, replica, in_queue, model_node.out_queues, step, func_name, to_empty_cache,
_, data = self.generate_step_one_model(model_node, replica, in_queue_list, model_node.out_queues, step, func_name, to_empty_cache,
is_eval=is_eval, to_onload=to_onload, to_offload=to_offload)
results.append(data)
if model_node.next_colocate_node:
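
To make the onload/offload NOTE in `compute_loop_one_model` concrete: with K replicas scheduled over M batches, a replica onloads on its first visit (step < K) and offloads on its last (step >= max(M - K, 0)). A small self-contained check of that arithmetic, assuming round-robin replica selection and that colocation and offload are enabled:

```python
# Self-contained check of the onload/offload schedule described in the NOTE.
def plan_steps(num_batch: int, replica_num: int):
    last_step_start = max(num_batch - replica_num, 0)
    for step in range(num_batch):
        yield (step,
               step % replica_num,          # replica picked round-robin (assumed)
               step < replica_num,          # onload on a replica's first visit
               step >= last_step_start)     # offload on a replica's last visit

for step, replica, onload, offload in plan_steps(num_batch=5, replica_num=2):
    print(f"step={step} replica={replica} onload={onload} offload={offload}")
# -> onload at steps 0-1 (first visit of each replica), offload at steps 3-4
```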
41 changes: 12 additions & 29 deletions chatlearn/runtime/model_flow.py
@@ -15,31 +15,16 @@
"""Model FLow"""

from collections import defaultdict, deque
from typing import List, Callable, Dict
from typing import List, Callable, Dict, Optional

from chatlearn.utils import future
from chatlearn.utils.global_vars import unwrap_func
from chatlearn.utils.global_vars import reset_dependencies, set_dependencies, get_dependencies
from chatlearn.utils.utils import flatten
from chatlearn.runtime.dist_actor import DistModel
from chatlearn.models.base_module import BaseModule
from .decorator import decorate_class_func


class ControlDependencies:
"""ControlDependencies"""

def __init__(self, dependencies):
if not isinstance(dependencies, list):
dependencies = [dependencies]
self.dependencies = dependencies

def __enter__(self):
set_dependencies(self.dependencies)
return self

def __exit__(self, exc_type, exc_value, traceback):
reset_dependencies()
from ray.util.queue import Queue


class DummyData:
@@ -60,9 +45,9 @@ def __init__(self, model: DistModel, func_name):
self.input_nodes = []
self.output_nodes = []
self.out_queues = None
self._input_queue = None
self._input_queue: Queue = None
# next colocate model node to execute
self.next_colocate_node = None
self.next_colocate_node: Optional[ModelNode] = None
# model to wait before the execution of current model
self.models_to_wait = []
# remote objects to wait before the execution of current model
@@ -85,15 +70,17 @@ def set_out_queues(self, queues):
def set_input_queue(self, queue):
self._input_queue = queue

def get_input_queues(self):
def get_input_queues(self) -> List[Queue]:
"""
Get all input queues of this model node: len(self.input_nodes) queues in total,
plus one if self._input_queue is set.
"""
input_queues = []
if self._input_queue is not None:
input_queues.append(self._input_queue)
for input_model_node in self.input_nodes:
out_index = input_model_node.output_nodes.index(self)
input_queues.append(input_model_node.out_queues[out_index])
if len(input_queues) == 1:
return input_queues[0]
return input_queues

def _find_all_parents(self, model, prev_models_results):
Expand Down Expand Up @@ -152,11 +139,11 @@ def __init__(self, cls):
self.input_consumers = []

def fake_compute(self, fn):
def inner(*args):
def inner(*args): #! args == (self, *inputs)
assert len(args) > 0
original_fn = unwrap_func(fn)
func_name = original_fn.__name__
model_node = ModelNode(args[0], func_name)
model_node = ModelNode(args[0], func_name) # args[0] == self
dist_model = self.name2remote_model[model_node.name]
model_node.model = dist_model
dist_model.model_node = model_node
@@ -166,10 +153,6 @@ def inner(*args):
data.to_nodes.append(model_node)
if data.from_node:
model_node.add_input_node(data.from_node)
dependencies = get_dependencies()
if dependencies is not None:
for dep in dependencies:
dep.from_node.dependent_output_nodes.append(model_node)
res = DummyData(model_node)
return res

@@ -194,7 +177,7 @@ def trace(self, models: List[DistModel], compute_flow: Callable):

dummy_data = DummyData()
assert compute_flow is not None
dummy_output = compute_flow(dummy_data)
dummy_output = compute_flow(dummy_data) # TODO: remove it? *args
# convert decorator back
for model in local_models:
for func_name in self.cls.model_to_call_funcs[model]:
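
For readers unfamiliar with the tracing pattern in `fake_compute`: the decorated call never runs the model; it creates a `ModelNode` and threads `DummyData` through the user's flow function so that input/output edges can be recorded. A stripped-down toy version of the same idea (independent of ChatLearn's classes; names are illustrative):

```python
# Toy reimplementation of DAG tracing via dummy data, not ChatLearn code.
class Dummy:
    def __init__(self, from_node=None):
        self.from_node = from_node
        self.to_nodes = []

class Node:
    def __init__(self, name):
        self.name = name
        self.input_nodes = []

def traced(name):
    """Return a fake compute fn that records graph edges instead of computing."""
    def fn(*args):
        node = Node(name)
        for data in args:
            if isinstance(data, Dummy):
                data.to_nodes.append(node)
                if data.from_node is not None:
                    node.input_nodes.append(data.from_node)
        return Dummy(from_node=node)   # the output remembers which node produced it
    return fn

policy, reward = traced("policy"), traced("reward")
batch = Dummy()
out = reward(policy(batch), batch)
print([n.name for n in out.from_node.input_nodes])  # ['policy']
```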
11 changes: 10 additions & 1 deletion chatlearn/runtime/utils.py
@@ -18,8 +18,17 @@
import textwrap
import inspect
from collections import defaultdict
from typing import Dict, List, Union, Any

def encode_data(mb, data):
def encode_data(mb, data) -> Dict[str, Union[int, List[Any]]]:
"""
return a dict:
{
'iter': int,
'data': List[ObjectRef]
}

"""
return {"iter": mb, "data": data}


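
Given the new docstring, `encode_data` is a thin dict wrapper keyed by the micro-batch index. A minimal sketch of how the encode/decode pair lines up (`decode_data` is not shown in this diff, so the version below is an assumption based on the docstring):

```python
from typing import Any, Dict, List, Tuple, Union

def encode_data(mb: int, data: List[Any]) -> Dict[str, Union[int, List[Any]]]:
    # 'iter' carries the micro-batch index, 'data' the list of (remote) objects
    return {"iter": mb, "data": data}

def decode_data(encoded: Dict[str, Any]) -> Tuple[int, List[Any]]:
    # assumed inverse of encode_data; the real helper may differ in detail
    return encoded["iter"], encoded["data"]

mb, refs = decode_data(encode_data(3, ["ref_a", "ref_b"]))
assert mb == 3 and refs == ["ref_a", "ref_b"]
```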
12 changes: 0 additions & 12 deletions chatlearn/utils/global_vars.py
@@ -19,7 +19,6 @@
_EXIT_ACTOR_NAME = "ChatLearnExitActor"
_DECORATED_MODELS = None
_DECORATED_OUTER_TO_INNER = {}
_DEPENDENCIES = None
_VLLM_ACTORS = None


@@ -65,17 +64,6 @@ def set_wrap_func(func, new_func):
assert new_func not in _DECORATED_OUTER_TO_INNER
_DECORATED_OUTER_TO_INNER[new_func] = func

def set_dependencies(dependencies):
global _DEPENDENCIES
assert _DEPENDENCIES is None
_DEPENDENCIES = dependencies

def reset_dependencies():
global _DEPENDENCIES
_DEPENDENCIES = None

def get_dependencies():
return _DEPENDENCIES

def set_vllm_actors(actors):
global _VLLM_ACTORS