diff --git a/chatlearn/synchronizer/megatron_vllm.py b/chatlearn/synchronizer/megatron_vllm.py
index ae334851..ec79a6e5 100644
--- a/chatlearn/synchronizer/megatron_vllm.py
+++ b/chatlearn/synchronizer/megatron_vllm.py
@@ -33,6 +33,7 @@ class MegatronVllmSync(BaseSync):
     def __init__(self, src_model, dst_model):
         super().__init__(src_model, dst_model)
         self.src_module_args = src_model.module_args
+        self.dst_module_args = dst_model.module_args
         self.is_parameter_changed = True

     @abstractmethod
@@ -322,15 +323,17 @@ def transform_parameters(self, params_to_sync_list):
         params_to_sync_list = self.fix_shared_expert_ordering(params_to_sync_list)
         return params_to_sync_list

-    def regroup_qkv_tp_slices(self, name, param_data, tp_divition):
+    def regroup_qkv_tp_slices(self, name, param_data, tp_division):
         param_data_shape = param_data.shape
         # Regroup qkv tensors into different tp slices only for inference model which enables vLLM backend.
         to_fix_qkv_ordering_dict = self.sync_map.to_fix_qkv_ordering_dict
+        # pylint: disable=too-many-nested-blocks
         if "attention.query_key_value" in name or \
                 "self_attention.query_key_value" in name or \
                 "self_attention.linear_qkv" in name:
-            tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
-            heads = self.src_module_args.args_dict["num_attention_heads"] // tp_size
+            src_tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
+            dst_tp_size = self.dst_module_args.args_dict["tensor_model_parallel_size"]
+            heads = self.src_module_args.args_dict["num_attention_heads"] // src_tp_size
             hidden_size_per_head = self.src_module_args.args_dict["hidden_size"] // self.src_module_args.args_dict["num_attention_heads"]

             param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
@@ -340,31 +343,58 @@ def regroup_qkv_tp_slices(self, name, param_data, tp_divition):
             if to_fix_qkv_ordering_dict is not None:
                 param_data = param_data.view(param_shape)
                 param_data_list = []
-                head_offset = heads // tp_divition
-                for idx in range(tp_divition):
+                head_offset = heads // tp_division
+                for idx in range(tp_division):
                     start = idx * head_offset
                     end = start + head_offset
                     param_data_list.append(param_data[:,start:end])
                 param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
                 del param_data_list
             else:
-                _num_query_groups = self.src_module_args.args_dict["num_query_groups"]//tp_size \
-                    if self.src_module_args.args_dict["group_query_attention"] else heads
-                if to_fix_qkv_ordering_dict is not None or _num_query_groups == 1:
+                if self.src_module_args.args_dict["group_query_attention"]:
+                    num_query_groups = self.src_module_args.args_dict["num_query_groups"]
+                    assert num_query_groups == self.dst_module_args.args_dict["num_query_groups"], (
+                        f"num_query_groups of src model ({num_query_groups}) must be equal to num_query_groups of "
+                        f"dst model ({self.dst_module_args.args_dict['num_query_groups']}). Please double-check your config."
+                    )
+                    src_num_query_groups_per_replica = num_query_groups // src_tp_size
+                    if dst_tp_size >= num_query_groups:
+                        num_dst_kv_head_replicas = dst_tp_size // num_query_groups
+                    else:
+                        num_dst_kv_head_replicas = 1
+                else:
+                    src_num_query_groups_per_replica = heads
+                    num_dst_kv_head_replicas = 1
+
+                if to_fix_qkv_ordering_dict is not None or src_num_query_groups_per_replica == 1:
                     if len(param_data_shape) == 1:
-                        param_data = param_data.view((heads + 2 * _num_query_groups, hidden_size_per_head))
+                        param_data = param_data.view((heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head))
                     else:
                         param_data = param_data.view(
-                            (heads + 2 * _num_query_groups, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"]))
+                            (heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"]))
                     param_data_list = []
-                    head_offset = heads // tp_divition
-                    for idx in range(tp_divition):
+                    head_offset = heads // tp_division
+                    for idx in range(tp_division):
                         q_start = idx * head_offset
                         q_end = q_start + head_offset
-                        k_start = (heads + idx) if _num_query_groups // tp_divition else heads
-                        k_end = k_start + 1
-                        v_start = k_start + _num_query_groups
-                        v_end = v_start + 1
+                        if num_dst_kv_head_replicas == 1:
+                            if src_num_query_groups_per_replica > tp_division:
+                                assert src_num_query_groups_per_replica % tp_division == 0, (
+                                    f"num_query_groups per replica of src model ({src_num_query_groups_per_replica}) "
+                                    f"must be divisible by tp_division ({tp_division}). Please double-check your config."
+                                )
+                                kv_offset = src_num_query_groups_per_replica // tp_division
+                            else:
+                                kv_offset = 1
+                            k_start = (heads + idx) if src_num_query_groups_per_replica // tp_division else heads
+                            k_end = k_start + kv_offset
+                            v_start = k_start + src_num_query_groups_per_replica
+                            v_end = v_start + kv_offset
+                        else:
+                            k_start = heads + idx // num_dst_kv_head_replicas
+                            k_end = k_start + 1
+                            v_start = k_start + src_num_query_groups_per_replica
+                            v_end = v_start + 1

                         q_proj = param_data[q_start:q_end].contiguous()
                         k_proj = param_data[k_start:k_end].contiguous()
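For review purposes, here is a minimal standalone sketch of the index arithmetic the new GQA branch performs. It is not code from the patch: the helper name `qkv_slice_bounds` and the demo sizes are made up for illustration, and `tp_division` is treated here simply as the number of destination slices produced per source rank. The point it demonstrates is that when the destination tensor-parallel size exceeds `num_query_groups`, each KV head is replicated to `dst_tp_size // num_query_groups` destination slices instead of being split further.

```python
# Illustrative sketch only -- mirrors the slice bookkeeping of the patched
# regroup_qkv_tp_slices; names and sizes are hypothetical, not chatlearn API.
def qkv_slice_bounds(heads, num_query_groups, src_tp_size, dst_tp_size, tp_division):
    """Yield (q, k, v) index ranges per destination slice for one source TP rank.

    `heads` is the number of attention heads held by one source rank
    (num_attention_heads // src_tp_size); `num_query_groups` is the global value.
    """
    src_groups_per_replica = num_query_groups // src_tp_size
    # If the destination TP size exceeds the number of KV groups, KV heads are
    # replicated across several destination ranks rather than split further.
    num_dst_kv_head_replicas = (dst_tp_size // num_query_groups
                                if dst_tp_size >= num_query_groups else 1)

    head_offset = heads // tp_division
    for idx in range(tp_division):
        q_start, q_end = idx * head_offset, (idx + 1) * head_offset
        if num_dst_kv_head_replicas == 1:
            # KV groups are split (or passed through) across destination slices.
            kv_offset = (src_groups_per_replica // tp_division
                         if src_groups_per_replica > tp_division else 1)
            k_start = heads + idx if src_groups_per_replica // tp_division else heads
        else:
            # Several destination slices share (replicate) the same KV head.
            kv_offset = 1
            k_start = heads + idx // num_dst_kv_head_replicas
        k_end = k_start + kv_offset
        v_start = k_start + src_groups_per_replica
        v_end = v_start + kv_offset
        yield (q_start, q_end), (k_start, k_end), (v_start, v_end)

# Hypothetical example: 32 heads / 8 KV groups, trainer TP=2 -> 16 heads and
# 4 KV groups per source rank; vLLM TP=16 > 8 groups, so each KV head serves 2 slices.
for q, k, v in qkv_slice_bounds(heads=16, num_query_groups=8, src_tp_size=2,
                                dst_tp_size=16, tp_division=8):
    print("q", q, "k", k, "v", v)
```

The printed ranges show the replication pattern introduced by the `num_dst_kv_head_replicas` branch; the removed code always took a single-row KV slice (`k_end = k_start + 1`), whereas the new branches also cover slices that carry several KV groups or that must reuse the same KV head.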