diff --git a/chatlearn/synchronizer/parameter_sync.py b/chatlearn/synchronizer/parameter_sync.py
index 2267d80f..6daa5a8f 100644
--- a/chatlearn/synchronizer/parameter_sync.py
+++ b/chatlearn/synchronizer/parameter_sync.py
@@ -1180,6 +1180,15 @@ def setup_rank_mapping(self):
                 f"greater or equal to expert parallel world size for inference ({self.num_dst_expert_parallel}) with HEP enabled."
             )
         if self.dst_model.use_vllm_backend:
+            if (
+                self.hep_num_mapping != 1
+                and get_args().runtime_args.routed_expert_regrouping_comm_type == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL
+            ):
+                raise NotImplementedError(
+                    "All-to-all regrouping of routed expert weights is only supported when src TP size * src EP size = dst TP size. "
+                    "Please consider setting `routed_expert_regrouping_comm_type` to allgather or adjusting the model's parallel sizes."
+                )
+
             if self.tp_num_mapping == 1:
                 if self.ep_num_mapping == 1:
                     self.build_rank_mapping()
diff --git a/chatlearn/utils/arguments.py b/chatlearn/utils/arguments.py
index c7d89e50..320a4d3d 100644
--- a/chatlearn/utils/arguments.py
+++ b/chatlearn/utils/arguments.py
@@ -311,7 +311,7 @@ class RuntimeConfig(BaseConfig):
     #: parameter sync max workers
     param_sync_max_workers: int = None
     #: communication type to regroup routed experts, allgather/alltoall
-    routed_expert_regrouping_comm_type: str = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL
+    routed_expert_regrouping_comm_type: str = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER
     #: max number of relay episodes, if `max_relay_episode` is set to -1, then relay all episodes
     #: if `max_relay_episode` is set to 0, then relay is disabled
     max_relay_episode: int = 0
diff --git a/chatlearn/utils/constant.py b/chatlearn/utils/constant.py
index 4ee59907..efa1e7a6 100644
--- a/chatlearn/utils/constant.py
+++ b/chatlearn/utils/constant.py
@@ -61,5 +61,5 @@ class PARAM_SYNC_COMM_TYPE(str, Enum):
 
 class ROUTED_EXPERT_REGROUPING_COMM_TYPE(str, Enum):
     """communication type of routed expert regrouping."""
-    ALLTOALL = "alltoall"
     ALLGATHER = "allgather"
+    ALLTOALL = "alltoall"
diff --git a/tests/test_hep_eptp_vllm_tp.py b/tests/test_hep_eptp_vllm_tp.py
index 847234c4..af3a61fc 100644
--- a/tests/test_hep_eptp_vllm_tp.py
+++ b/tests/test_hep_eptp_vllm_tp.py
@@ -195,21 +195,21 @@ def test_hep_eptp_vllm_tp_dst_ep1_tp2_pp1_src_ep4_tp2_pp1():
     assert param_sync_group.ep_num_mapping == tuples[0] / tuples[3]
     assert param_sync_group.tp_num_mapping == tuples[1] // tuples[4]
 
-    # Judge alltoall actors
-    alltoall_actors = param_sync_group.send_actors_to_regroup_routed_experts
+    # Judge allgather actors
+    allgather_actors = param_sync_group.send_actors_to_regroup_routed_experts
     actor2rank = param_sync_group.actor2rank
 
-    assert param_sync_group._comm_type_to_regroup_routed_experts == "alltoall"
-    assert len(alltoall_actors) == 1
-    assert len(alltoall_actors[0]) == 8 # all src ranks should all-to-all routed experts
+    assert param_sync_group._comm_type_to_regroup_routed_experts == "allgather"
+    assert len(allgather_actors) == 1
+    assert len(allgather_actors[0]) == 8 # all src ranks should all-gather routed experts
     assert len(actor2rank) == 16 # all of the 16 actors should have rank
     assert len(set(list(actor2rank.values()))) == len(actor2rank) # all ranks should be unique
 
-    alltoall_actor_ranks = []
-    for actor in alltoall_actors[0]:
-        alltoall_actor_ranks.append(actor2rank[actor])
+    allgather_actor_ranks = []
+    for actor in allgather_actors[0]:
+        allgather_actor_ranks.append(actor2rank[actor])
 
-    assert alltoall_actor_ranks == [0, 1, 2, 3, 4, 5, 6, 7]
+    assert allgather_actor_ranks == [0, 1, 2, 3, 4, 5, 6, 7]
 
     # Judge src->dst rank mappings
     comm_pairs = []
diff --git a/tests/test_hep_eptppp_vllm_tp.py b/tests/test_hep_eptppp_vllm_tp.py
index ce767fd0..4cf6ee68 100644
--- a/tests/test_hep_eptppp_vllm_tp.py
+++ b/tests/test_hep_eptppp_vllm_tp.py
@@ -197,24 +197,24 @@ def test_hep_eptppp_vllm_tp_dst_ep1_tp2_pp1_src_ep2_tp2_pp2():
     assert param_sync_group.ep_num_mapping == tuples[0] / tuples[3]
     assert param_sync_group.tp_num_mapping == tuples[1] // tuples[4]
 
-    # Judge alltoall actors
-    alltoall_actors = param_sync_group.send_actors_to_regroup_routed_experts
+    # Judge allgather actors
+    allgather_actors = param_sync_group.send_actors_to_regroup_routed_experts
     actor2rank = param_sync_group.actor2rank
 
-    assert param_sync_group._comm_type_to_regroup_routed_experts == "alltoall"
-    assert len(alltoall_actors) == 2
-    assert len(alltoall_actors[0]) == 4 # prev 4 src ranks should all-to-all routed experts
-    assert len(alltoall_actors[1]) == 4 # last 4 src ranks should all-to-all routed experts
+    assert param_sync_group._comm_type_to_regroup_routed_experts == "allgather"
+    assert len(allgather_actors) == 2
+    assert len(allgather_actors[0]) == 4 # first 4 src ranks should all-gather routed experts
+    assert len(allgather_actors[1]) == 4 # last 4 src ranks should all-gather routed experts
     assert len(actor2rank) == 16 # all of the 16 actors should have rank
     assert len(set(list(actor2rank.values()))) == len(actor2rank) # all ranks should be unique
 
-    alltoall_actor_ranks = []
-    for actor_list in alltoall_actors:
-        alltoall_actor_ranks.append([])
+    allgather_actor_ranks = []
+    for actor_list in allgather_actors:
+        allgather_actor_ranks.append([])
         for actor in actor_list:
-            alltoall_actor_ranks[-1].append(actor2rank[actor])
+            allgather_actor_ranks[-1].append(actor2rank[actor])
 
-    assert alltoall_actor_ranks == [[0, 1, 2, 3], [4, 5, 6, 7]]
+    assert allgather_actor_ranks == [[0, 1, 2, 3], [4, 5, 6, 7]]
 
     # Judge src->dst rank mappings
     comm_pairs = []
diff --git a/tests/test_hep_tp_vllm_tp.py b/tests/test_hep_tp_vllm_tp.py
index 02a1406c..6027c97b 100644
--- a/tests/test_hep_tp_vllm_tp.py
+++ b/tests/test_hep_tp_vllm_tp.py
@@ -195,24 +195,24 @@ def test_hep_tp_vllm_tp_dst_ep1_tp4_pp1_src_ep1_tp4_pp1():
     assert param_sync_group.ep_num_mapping == tuples[0] / tuples[3]
     assert param_sync_group.tp_num_mapping == tuples[1] // tuples[4]
 
-    # Judge alltoall actors
-    alltoall_actors = param_sync_group.send_actors_to_regroup_routed_experts
+    # Judge allgather actors
+    allgather_actors = param_sync_group.send_actors_to_regroup_routed_experts
     actor2rank = param_sync_group.actor2rank
 
-    assert param_sync_group._comm_type_to_regroup_routed_experts == "alltoall"
-    assert len(alltoall_actors) == 2
-    assert len(alltoall_actors[0]) == 4 # prev 4 src ranks should all-to-all routed experts
-    assert len(alltoall_actors[1]) == 4 # last 4 src ranks should all-to-all routed experts
+    assert param_sync_group._comm_type_to_regroup_routed_experts == "allgather"
+    assert len(allgather_actors) == 2
+    assert len(allgather_actors[0]) == 4 # first 4 src ranks should all-gather routed experts
+    assert len(allgather_actors[1]) == 4 # last 4 src ranks should all-gather routed experts
    assert len(actor2rank) == 16 # all of the 16 actors should have rank
     assert len(set(list(actor2rank.values()))) == len(actor2rank) # all ranks should be unique
 
-    alltoall_actor_ranks = []
-    for actor_list in alltoall_actors:
-        alltoall_actor_ranks.append([])
+    allgather_actor_ranks = []
+    for actor_list in allgather_actors:
+        allgather_actor_ranks.append([])
         for actor in actor_list:
-            alltoall_actor_ranks[-1].append(actor2rank[actor])
+            allgather_actor_ranks[-1].append(actor2rank[actor])
 
-    assert alltoall_actor_ranks == [[0, 1, 2, 3], [4, 5, 6, 7]]
+    assert allgather_actor_ranks == [[0, 1, 2, 3], [4, 5, 6, 7]]
 
     # Judge src->dst rank mappings
     comm_pairs = []
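
Note on the new guard: a minimal sketch of the condition it enforces, assuming `hep_num_mapping` is the ratio of dst TP size to src TP size * src EP size. The helper name and formula below are hypothetical illustrations, not ChatLearn APIs.

```python
# Hypothetical sketch (not ChatLearn source) of the condition enforced by the
# new guard in setup_rank_mapping(); names and formula are assumptions.
def check_routed_expert_regrouping(src_tp: int, src_ep: int, dst_tp: int, comm_type: str) -> None:
    # Assumed for illustration: hep_num_mapping compares dst TP against src TP * src EP.
    hep_num_mapping = dst_tp / (src_tp * src_ep)
    if comm_type == "alltoall" and hep_num_mapping != 1:
        # Mirrors the new NotImplementedError: alltoall regrouping needs
        # src TP * src EP == dst TP; otherwise fall back to allgather.
        raise NotImplementedError("use allgather or adjust the parallel sizes")


# src EP=4, TP=2 -> dst TP=2 (as in test_hep_eptp_vllm_tp) violates the condition,
# so alltoall would be rejected there; the tests above therefore expect allgather.
check_routed_expert_regrouping(src_tp=2, src_ep=4, dst_tp=2, comm_type="allgather")  # passes
```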