163 changes: 136 additions & 27 deletions chatlearn/models/base_module.py
@@ -29,8 +29,8 @@
from chatlearn.data.sampler import SingleDataSampler, EpisodeDataSampler
from chatlearn.checkpoint.checkpoint_manager import CheckpointManager
from chatlearn.utils import future
from chatlearn.utils.dist_utils import bucket_tensors, coalesced_comm_dense
from chatlearn.utils.dist_utils import bucket_tensors_two_stage_generator, coalesced_comm_dense_two_stage
from chatlearn.utils.dist_utils import bucket_tensors, bucket_tensor_generator, coalesced_comm_dense
from chatlearn.utils.dist_utils import bucket_tensors_two_stage, bucket_tensors_two_stage_generator, coalesced_comm_dense_two_stage
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.global_vars import set_global_variables
from chatlearn.utils.logger import log_rank_0, debug_rank_0, setup_logger
@@ -766,10 +766,8 @@ def allgather_routed_expert_parameter(self, group_name, pipe_stage=0):
self._expert_sync_buffer.pop(name, "Not Found.")
self._expert_sync_buffer[name] = param

def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
"""
:meta private:
"""
def _broadcast_parameter_opt_level_0(self, rank, src_rank, group_name, pipe_stage=0):
debug_rank_0(">>>>>>>>>>>>>>>>broadcast parameter at memory optimization level 0")
tensors = []
for name, param in self._parameters_to_sync[pipe_stage]:
if self._expert_sync_buffer and name in self._expert_sync_buffer and \
@@ -788,32 +786,113 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):

for param in sparse_bucket:
col.broadcast(param, src_rank, group_name)
self.empty_cache()

def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
"""
Arguments:
to_rank: receive rank in mapping from trainer to inference model.
buffer_rank: index which tensors of sync buffer to be sended in stage2.
rank: destination rank in communication group which enumerate receive ranks.
src_rank: source rank in communication group. always 0.
group_name: communication group name.
pipe_stage: pipeline stage. default 0.
stage2: bool. whether stage2 or not. default False.
Example: trainer_tp = 4, inference_tp = 8. pipeline_size = 1
stage1: [(from_rank, to_rank), ...] = [(0, 8), (1, 10), (2, 12), (3, 14)]
stage2: [(from_rank, to_rank), ...] = [(8, 9), (10, 11), (12, 13), (14, 15)]
def _broadcast_parameter_opt_level_1(self, rank, src_rank, group_name, pipe_stage=0):
debug_rank_0(">>>>>>>>>>>>>>broadcast parameter at memory optimization level 1")
def tensor_generator():
for name, param in self._parameters_to_sync[pipe_stage]:
if self._expert_sync_buffer and name in self._expert_sync_buffer:
yield self._expert_sync_buffer[name]
# move self._expert_sync_buffer[name] to cpu mem to save gpu mem
cpu_expert = self._expert_sync_buffer[name].cpu()
del self._expert_sync_buffer[name]
self._expert_sync_buffer[name] = cpu_expert
else:
yield param.data

For stage1 pair (0, 8):
1. call broadcast func: (0 -> 0). src_rank: 0, rank: 0.
2. call broadcast func: (0 -> 8). src_rank: 0, rank: 1.
bucket_generator = bucket_tensor_generator(tensor_generator, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
dense_bucket_num = 0
sparse_bucket_num = 0
tensor_changed = rank != src_rank
for bucket_or_tensor, is_dense in bucket_generator:
if is_dense:
coalesced_comm_dense(bucket_or_tensor, col.broadcast, extra_args=(src_rank, group_name), tensor_changed=tensor_changed)
dense_bucket_num += 1
else:
col.broadcast(bucket_or_tensor, src_rank, group_name)
sparse_bucket_num += 1

After (0, 8), to_rank 8 received tensor slices of 8 and 9.
debug_rank_0(f"{self.name} Got dense_buckets {dense_bucket_num}, spase_bucket {sparse_bucket_num}", self._logger)
self.empty_cache()

For stage2 pair (8, 9):
1. call broadcast func: (8 -> 8). src_rank: 0, rank: 0.
2. call broadcast func: (8 -> 9). src_rank: 0, rank: 1.
In (8 -> 8), we need to send tp_slice of 'to_rank' 9, so set buffer_rank 9 to fetch tensors in sync buffer.
def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
"""
:meta private:
"""
if self.runtime_args.sync_memory_optimization_level == 0:
self._broadcast_parameter_opt_level_0(rank, src_rank, group_name, pipe_stage)
else:
self._broadcast_parameter_opt_level_1(rank, src_rank, group_name, pipe_stage)

def _broadcast_parameter_two_stage_opt_level_0(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
debug_rank_0(">>>>>>>>>>>>>>broadcast parameter at memory optimization level 0")
tensor_changed = rank != src_rank

if stage2:
if tensor_changed:
parameters_to_sync = self._parameters_to_recv[to_rank]
else:
parameters_to_sync = self._parameters_to_send
else:
del self._sync_buffer
self._sync_buffer = defaultdict(list)
parameters_to_sync = self._parameters_to_sync

tensors = []
buffer_num = []
if stage2 and not tensor_changed and self._sync_buffer:# pylint: disable=too-many-nested-blocks
idx = 0
for name, param in parameters_to_sync[pipe_stage]:
self._logger.debug(
f"Adding {name} to sync for if branch from "
f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}"
)
tensors.append(self._sync_buffer[buffer_rank % self.tp_num_mapping][idx])
buffer_num.append(1)
idx += 1
del self._sync_buffer[buffer_rank % self.tp_num_mapping]
else:
for name, param in parameters_to_sync[pipe_stage]:
self._logger.debug(
f"Adding {name} to sync for else branch from "
f"src_rank: {src_rank} to rank: {rank} in pipe_stage {pipe_stage}"
)
param_data = param.data
if rank and self._buffer_num and not stage2:
assert name in self._buffer_num, f"{name} in self._buffer_num for rank {rank}"
buffer_num.append(self._buffer_num[name])
elif stage2:
buffer_num.append(1)
else:
# regroup src_tensor by tp_rank.
param_data = self._synchronizer.regroup_params_to_sync(name, param_data, self._tp_division[name])
buffer_num.append(1)
tensors.append(param_data)

assert len(tensors) > 0
dense_buckets, sparse_bucket = bucket_tensors_two_stage(
tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb,
buffer_num=None if stage2 else buffer_num, tensor_changed=tensor_changed and not stage2)
debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, sparse_bucket {len(sparse_bucket)}", self._logger)

for bucket in dense_buckets:
index = 0 if stage2 else (to_rank % self.tp_num_mapping)
all_buffers = coalesced_comm_dense_two_stage(
bucket, col.broadcast, rank,
extra_args=(src_rank, group_name), tensor_changed=tensor_changed,
stage2=stage2, index=index)
if tensor_changed and not stage2:
for key, value in all_buffers.items():
self._sync_buffer[key] += value

for param in sparse_bucket:
col.broadcast(param, src_rank, group_name)

self.empty_cache()

def _broadcast_parameter_two_stage_opt_level_1(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
debug_rank_0(">>>>>>>>>>>>>>broadcast parameter at memory optimization level 1")
tensor_changed = rank != src_rank

if stage2:
@@ -904,6 +983,36 @@ def tensor_generator():

self.empty_cache()

def broadcast_parameter_two_stage(self, to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
"""
Arguments:
to_rank: receive rank in the mapping from the trainer model to the inference model.
buffer_rank: index of the sync-buffer tensors to be sent in stage2.
rank: destination rank in the communication group, which enumerates the receive ranks.
src_rank: source rank in the communication group; always 0.
group_name: communication group name.
pipe_stage: pipeline stage. Default 0.
stage2: bool. Whether this call is stage2. Default False.
Example: trainer_tp = 4, inference_tp = 8, pipeline_size = 1.
stage1: [(from_rank, to_rank), ...] = [(0, 8), (1, 10), (2, 12), (3, 14)]
stage2: [(from_rank, to_rank), ...] = [(8, 9), (10, 11), (12, 13), (14, 15)]

For stage1 pair (0, 8):
1. call broadcast func: (0 -> 0). src_rank: 0, rank: 0.
2. call broadcast func: (0 -> 8). src_rank: 0, rank: 1.

After (0, 8), to_rank 8 has received the tensor slices for ranks 8 and 9.

For stage2 pair (8, 9):
1. call broadcast func: (8 -> 8). src_rank: 0, rank: 0.
2. call broadcast func: (8 -> 9). src_rank: 0, rank: 1.
In (8 -> 8), we need to send the tp_slice of 'to_rank' 9, so buffer_rank is set to 9 to fetch the tensors from the sync buffer.
"""
if self.runtime_args.sync_memory_optimization_level == 0:
self._broadcast_parameter_two_stage_opt_level_0(to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage, stage2)
else:
self._broadcast_parameter_two_stage_opt_level_1(to_rank, buffer_rank, rank, src_rank, group_name, pipe_stage, stage2)

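To make the rank mapping described in the docstring above concrete, here is a minimal standalone sketch. It is not part of the PR; the helper name and rank layout are illustrative only, and it simply reproduces the stage1/stage2 pairs listed for trainer_tp = 4, inference_tp = 8.

```python
def two_stage_pairs(trainer_tp, inference_tp, inference_rank_start):
    # How many inference TP slices one trainer TP slice has to cover (2 here).
    tp_num_mapping = inference_tp // trainer_tp
    # Stage 1: trainer rank i broadcasts its slices to the first inference rank of its group.
    stage1 = [(i, inference_rank_start + i * tp_num_mapping) for i in range(trainer_tp)]
    # Stage 2: that inference rank forwards the buffered slices to the rest of its group.
    stage2 = [(dst, dst + k) for _, dst in stage1 for k in range(1, tp_num_mapping)]
    return stage1, stage2

# Using the docstring's layout, where inference ranks start at global rank 8:
stage1, stage2 = two_stage_pairs(trainer_tp=4, inference_tp=8, inference_rank_start=8)
print(stage1)  # [(0, 8), (1, 10), (2, 12), (3, 14)]
print(stage2)  # [(8, 9), (10, 11), (12, 13), (14, 15)]
```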
def send_parameter(self, dst_rank, group_name, pipe_stage=0):
"""
:meta private:
4 changes: 2 additions & 2 deletions chatlearn/runtime/engine.py
@@ -296,9 +296,9 @@ def learn(self):
self.timers("sync_parameters").start()
self.model_manager.sync_parameters(requires_grad=False, validate=self.runtime_args.validate_param_sync)
self.timers("sync_parameters").stop()
logger.info(f"{LOG_START} " + get_full_proc_memory_info('After first param sync'))
logger.info(
f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} " \
+ get_full_proc_memory_info('After first param sync')
f"{LOG_START} {self._name} sync_parameters summary {self.timers.log(names=['sync_parameters'])} "
)
self._data_loader = data_loader
for episode_id in range(self._start_episode, self.runtime_args.num_episode):
3 changes: 3 additions & 0 deletions chatlearn/utils/arguments.py
@@ -312,6 +312,8 @@ class RuntimeConfig(BaseConfig):
param_sync_max_workers: int = None
#: communication type to regroup routed experts, allgather/alltoall
routed_expert_regrouping_comm_type: str = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL
#: memory optimization level in synchronization, deciding whether to save GPU memory or to pursue faster execution, 0/1
sync_memory_optimization_level: int = 1
#: max number of relay episodes, if `max_relay_episode` is set to -1, then relay all episodes
#: if `max_relay_episode` is set to 0, then relay is disabled
max_relay_episode: int = 0
@@ -510,6 +512,7 @@ def _validate_params(self):
assert self.runtime_args.stream_data_loader_type.lower() in ["fixed", "dynamic"]
assert self.runtime_args.cpu_schedule_strategy in [strategy.value for strategy in RAY_PG_STRATEGY]
assert self.runtime_args.param_sync_comm_type in list(PARAM_SYNC_COMM_TYPE)
assert self.runtime_args.sync_memory_optimization_level in [0, 1]
for model_name, model_args in self.models.items():
if model_args.num_gpu >= 1:
if model_args.gpu_per_process is None:
73 changes: 72 additions & 1 deletion chatlearn/utils/dist_utils.py
@@ -53,6 +53,74 @@ def bucket_tensors(tensors, bucket_size_mb):
return dense_buckets, sparse_bucket


def bucket_tensor_generator(tensor_generator, bucket_size_mb):
"""Group tensors into chunks. We seperate sparse and dense tensor,
each containing tensors of same type up to certain byte limit in total size.

Args:
tensor_generator (Generator): A generator of tensors to be separated into chunks.
size_limit (int): The limit of each chunk in bytes.

Yield:
dense_buckets: Blocks of tensors of same type and within size_limit.
sparse_bucket: A list of sparse tensors
"""
size_limit = bucket_size_mb * 1024 * 1024
buf_dict = defaultdict(lambda: [[], 0])
for tensor in tensor_generator():
if tensor.is_sparse:
yield tensor, False
continue
t = tensor.type()
size = tensor.numel() * tensor.element_size()
buf_and_size = buf_dict[t]
if size_limit > 0 and buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: # pylint: disable=chained-comparison
yield buf_and_size[0], True
buf_and_size = buf_dict[t] = [[], 0]
buf_and_size[0].append(tensor)
buf_and_size[1] += size
for buf, _ in buf_dict.values():
if len(buf) > 0:
yield buf, True

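A minimal usage sketch for bucket_tensor_generator (not part of the PR; assumes torch and small dummy tensors), mirroring how _broadcast_parameter_opt_level_1 consumes it: the callable is passed in uncalled, and each yielded item is either a dense bucket or a single sparse tensor.

```python
import torch

from chatlearn.utils.dist_utils import bucket_tensor_generator


def tensor_generator():
    # Stream tensors one at a time instead of materializing a full list,
    # which is what lets memory optimization level 1 keep GPU memory low.
    for _ in range(8):
        yield torch.zeros(1024, 1024, dtype=torch.float16)  # 2 MB each


for bucket_or_tensor, is_dense in bucket_tensor_generator(tensor_generator, bucket_size_mb=4):
    if is_dense:
        # A list of same-dtype tensors whose total size stays within ~4 MB;
        # with 2 MB inputs each dense bucket ends up holding 2 tensors.
        print("dense bucket of", len(bucket_or_tensor), "tensors")
    else:
        print("sparse tensor", bucket_or_tensor.shape)
```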

def bucket_tensors_two_stage(tensors, bucket_size_mb, buffer_num=None, tensor_changed=False):
"""Group tensors into chunks. We seperate sparse and dense tensor,
each containing tensors of same type up to certain byte limit in total size.
Args:
tensors (Sequence): A sequence of tensors to be separated into chunks.
size_limit (int): The limit of each chunk in bytes.
Return:
dense_buckets: Blocks of tensors of same type and within size_limit.
sparse_bucket: A list of sparse tensors
"""
size_limit = bucket_size_mb * 1024 * 1024
buf_dict = defaultdict(lambda: [[], 0])
dense_buckets = []
sparse_bucket = []
for idx, tensor in enumerate(tensors):
buffer_multiple = 1 if buffer_num is None else buffer_num[idx]
if tensor.is_sparse:
sparse_bucket.append(tensor)
continue
t = tensor.type()
# expand buffer size of dst ranks which recv tensor from trainer.
size = tensor.numel() * tensor.element_size() * buffer_multiple
buf_and_size = buf_dict[t]
if size_limit > 0 and buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: # pylint: disable=chained-comparison
dense_buckets.append(buf_and_size[0])
buf_and_size = buf_dict[t] = [[], 0]
buf_and_size[0].append((torch.empty(size=[tensor.numel() * buffer_multiple],
dtype=tensor.dtype,
device=tensor.device) if (tensor_changed and buffer_multiple > 1) else tensor,
[size // tensor.element_size(), buffer_multiple, tensor]))
buf_and_size[1] += size
for buf, size in buf_dict.values():
if len(buf) > 0:
dense_buckets.append(buf)
return dense_buckets, sparse_bucket

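A minimal sketch (not part of the PR; assumes torch) of what the dense bucket entries produced by bucket_tensors_two_stage look like on a receiving rank: with tensor_changed=True and a buffer multiple of 2, each entry carries a freshly allocated flat receive buffer twice the size of the source parameter, plus metadata mapping it back to that parameter.

```python
import torch

from chatlearn.utils.dist_utils import bucket_tensors_two_stage

params = [torch.randn(1024, 1024), torch.randn(4096)]
dense_buckets, sparse_bucket = bucket_tensors_two_stage(
    params, bucket_size_mb=100, buffer_num=[2, 2], tensor_changed=True)

for bucket in dense_buckets:
    for recv_buffer, (flat_numel, multiple, src) in bucket:
        # recv_buffer is a flat tensor `multiple` times larger than the source
        # parameter; src is the original tensor the received slices map back to.
        print(recv_buffer.numel(), flat_numel, multiple, tuple(src.shape))
```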

def bucket_tensors_two_stage_generator(tensor_generator, bucket_size_mb, stage2=False, tensor_changed=False):
"""Group tensors into chunks. We seperate sparse and dense tensor,
each containing tensors of same type up to certain byte limit in total size.
@@ -130,8 +198,11 @@ def coalesced_comm_dense(bucket, comm_call, extra_args, tensor_changed=True):
flat_tensors = _flatten_dense_tensors(bucket)
comm_call(flat_tensors, *extra_args)
if tensor_changed:
all_buffers = _unflatten_dense_tensors(flat_tensors, bucket)
del flat_tensors
for tensor, synced in zip(
bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
bucket, all_buffers
):
tensor.copy_(synced)

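A minimal sketch of the flatten -> communicate -> unflatten -> copy-back pattern behind coalesced_comm_dense (not part of the PR; the collective is replaced by a stand-in callable so the sketch runs without a process group). The unflattened results are views into the flat buffer, so the `del` only drops the Python reference; the storage is freed once the per-tensor copies finish and the views go out of scope.

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def coalesced_comm_sketch(bucket, comm_call, tensor_changed=True):
    flat = _flatten_dense_tensors(bucket)      # one contiguous buffer for the whole bucket
    comm_call(flat)                            # e.g. a broadcast over the flat buffer
    if tensor_changed:                         # receivers copy results back per tensor
        unflat = _unflatten_dense_tensors(flat, bucket)
        del flat                               # views in `unflat` keep the storage alive
        for dst, synced in zip(bucket, unflat):
            dst.copy_(synced)


bucket = [torch.zeros(4), torch.zeros(2, 3)]
coalesced_comm_sketch(bucket, lambda t: t.fill_(1.0))
print(bucket[0].tolist(), bucket[1].tolist())  # both filled with 1.0
```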

1 change: 1 addition & 0 deletions examples/megatron/configs/llama2/grpo_math_vllm.yaml
@@ -66,3 +66,4 @@ runtime:
max_relay_episode: 1
exp_name: ${exp_name:chatlearn}
validate_param_sync: ${validate_param_sync:False}
sync_memory_optimization_level: ${sync_memory_optimization_level:1}
1 change: 1 addition & 0 deletions examples/megatron/configs/llama2/online_dpo_vllm.yaml
@@ -60,3 +60,4 @@ runtime:
output_dir: ${output_dir}
exp_name: ${exp_name:chatlearn}
validate_param_sync: ${validate_param_sync:False}
sync_memory_optimization_level: ${sync_memory_optimization_level:1}
1 change: 1 addition & 0 deletions examples/megatron/configs/llama2/vllm_param_sync.yaml
@@ -49,3 +49,4 @@ runtime:
exp_name: ${exp_name:chatlearn}
debug: ${debug:False}
validate_param_sync: ${validate_param_sync:False}
sync_memory_optimization_level: ${sync_memory_optimization_level:1}
1 change: 1 addition & 0 deletions examples/megatron/configs/llama2/vllm_rlhf.yaml
@@ -82,3 +82,4 @@ runtime:
exp_name: ${exp_name:chatlearn}
debug: ${debug:False}
validate_param_sync: ${validate_param_sync:False}
sync_memory_optimization_level: ${sync_memory_optimization_level:1}