@@ -69,6 +69,103 @@ def test_dispatch_and_combine(self, dtype, device):
6969
7070 self .assertTrue (torch .allclose (all2all_results ["hidden_states" ], agrs_results ["hidden_states" ], atol = 1e-2 , rtol = 1e-2 ))
7171
72+ logits_list = [torch .randn (seq_len , num_experts ).cuda () for _ in range (2 )]
73+ router_out_list = [router (logits ) for logits in logits_list ]
74+ hidden_states_list = [torch .rand (seq_len , hidden_size ).to (device ).to (dtype ) for _ in range (2 )]
75+ all2all_results_list = self ._dispatcher_call_micro_batch (
76+ dispatcher = all2all_dispatcher ,
77+ hidden_states_list = hidden_states_list ,
78+ topk_ids_list = [router_out ["topk_ids" ] for router_out in router_out_list ],
79+ topk_weights_list = [router_out ["topk_weights" ] for router_out in router_out_list ],
80+ )
81+ agrs_results_list = self ._dispatcher_call_micro_batch (
82+ dispatcher = agrs_dispatcher ,
83+ hidden_states_list = hidden_states_list ,
84+ topk_ids_list = [router_out ["topk_ids" ] for router_out in router_out_list ],
85+ topk_weights_list = [router_out ["topk_weights" ] for router_out in router_out_list ],
86+ )
87+ torch .distributed .breakpoint ()
88+
def _dispatcher_call_micro_batch(
    self,
    dispatcher: DispacherInterface,
    hidden_states_list: list[torch.Tensor],
    topk_ids_list: list[torch.Tensor],
    topk_weights_list: list[torch.Tensor],
) -> list[torch.Tensor]:
    """Drive the full dispatch/combine pipeline for several micro-batches.

    Each pipeline stage is issued for *every* micro-batch (with
    ``async_op=True``) before the next stage begins, so the async
    communication of one micro-batch can overlap with the work of the
    others — this mirrors intra-layer micro-batch overlapping.

    Args:
        dispatcher: dispatcher under test implementing ``DispacherInterface``.
        hidden_states_list: one hidden-states tensor per micro-batch.
        topk_ids_list: router top-k expert ids, one tensor per micro-batch.
        topk_weights_list: router top-k weights, one tensor per micro-batch.

    Returns:
        The combined ``hidden_states`` tensor of each micro-batch, in input
        order.
    """
    intra_layer_micro_batch = len(hidden_states_list)

    # Stage 1: launch dispatch pre-processing for every micro-batch up
    # front so the async ops are queued back to back.
    pre_dispatched_list = []
    for hidden_states, topk_ids in zip(hidden_states_list, topk_ids_list):
        pre_dispatched = dispatcher.dispatch_preprocess(
            hidden_states=hidden_states,
            topk_ids=topk_ids,
            async_op=True,
        )
        pre_dispatched_list.append(pre_dispatched)

    dispatched_list = []
    post_dispatched_list = []
    experts_out_list = []
    pre_combined_list = []
    combined_list = []

    # Stage 2: per micro-batch, run dispatch -> post-process -> experts ->
    # combine pre-processing, keeping every intermediate for the later stages.
    for topk_weights, pre_dispatched in zip(topk_weights_list, pre_dispatched_list):
        dispatched = dispatcher.dispatch(
            pre_dispatched=pre_dispatched,
            topk_weights=topk_weights,
            async_op=True,
        )
        post_dispatched = dispatcher.dispatch_postprocess(
            pre_dispatched=pre_dispatched,
            dispatched=dispatched,
            async_op=True,
        )
        # NOTE(review): keyword is spelled "tokens_per_exprts" — it must
        # match mock_experts' parameter name; confirm against its definition
        # before renaming.
        experts_results = mock_experts(
            hidden_states=post_dispatched["hidden_states"],
            tokens_per_exprts=post_dispatched["tokens_per_expert"],
        )
        pre_combined = dispatcher.combine_preprocess(
            hidden_states=experts_results,
            pre_dispatched=pre_dispatched,
            dispatched=dispatched,
            post_dispatched=post_dispatched,
            async_op=True,
        )
        post_dispatched_list.append(post_dispatched)
        experts_out_list.append(experts_results)
        dispatched_list.append(dispatched)
        pre_combined_list.append(pre_combined)

    # Stage 3: launch the combine communication for every micro-batch.
    for pre_combined, pre_dispatched, dispatched, post_dispatched in zip(
        pre_combined_list,
        pre_dispatched_list,
        dispatched_list,
        post_dispatched_list,
    ):
        combined = dispatcher.combine(
            pre_combined=pre_combined,
            pre_dispatched=pre_dispatched,
            dispatched=dispatched,
            post_dispatched=post_dispatched,
            async_op=True,
        )
        combined_list.append(combined)

    # Stage 4: finalize each micro-batch and collect its output tensor.
    hidden_states_out_list: list[torch.Tensor] = []
    for i in range(intra_layer_micro_batch):
        post_combined = dispatcher.combine_postprocess(
            pre_dispatched=pre_dispatched_list[i],
            dispatched=dispatched_list[i],
            post_dispatched=post_dispatched_list[i],
            pre_combined=pre_combined_list[i],
            combined=combined_list[i],
            async_op=True,
        )
        hidden_states_out_list.append(post_combined["hidden_states"])
    return hidden_states_out_list
167+
168+
72169 def _dispatcher_call (
73170 self ,
74171 dispatcher : DispacherInterface ,
0 commit comments