Skip to content

Commit a58078a

Browse files
committed
[pytest] test micro batch overlap in agrs dispatcher
1 parent d8b918b commit a58078a

1 file changed

Lines changed: 97 additions & 0 deletions

File tree

tests/module/dispatcher/test_agrs_all2all.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,103 @@ def test_dispatch_and_combine(self, dtype, device):
6969

7070
self.assertTrue(torch.allclose(all2all_results["hidden_states"], agrs_results["hidden_states"], atol=1e-2, rtol=1e-2))
7171

72+
logits_list = [torch.randn(seq_len, num_experts).cuda() for _ in range(2)]
73+
router_out_list = [router(logits) for logits in logits_list]
74+
hidden_states_list = [torch.rand(seq_len, hidden_size).to(device).to(dtype) for _ in range(2)]
75+
all2all_results_list = self._dispatcher_call_micro_batch(
76+
dispatcher=all2all_dispatcher,
77+
hidden_states_list=hidden_states_list,
78+
topk_ids_list=[router_out["topk_ids"] for router_out in router_out_list],
79+
topk_weights_list=[router_out["topk_weights"] for router_out in router_out_list],
80+
)
81+
agrs_results_list = self._dispatcher_call_micro_batch(
82+
dispatcher=agrs_dispatcher,
83+
hidden_states_list=hidden_states_list,
84+
topk_ids_list=[router_out["topk_ids"] for router_out in router_out_list],
85+
topk_weights_list=[router_out["topk_weights"] for router_out in router_out_list],
86+
)
87+
torch.distributed.breakpoint()
88+
89+
    def _dispatcher_call_micro_batch(
        self,
        dispatcher: DispacherInterface,
        hidden_states_list: list[torch.Tensor],
        topk_ids_list: list[torch.Tensor],
        topk_weights_list: list[torch.Tensor]
    ) -> list[torch.Tensor]:
        """Run a full dispatch/combine cycle for several micro-batches, stage by stage.

        Unlike a per-batch sequential call, each pipeline stage is issued for
        *all* micro-batches (every call with ``async_op=True``) before the
        next stage begins, which is what exercises the intra-layer
        micro-batch overlap path of the dispatcher.

        Args:
            dispatcher: dispatcher under test (``DispacherInterface``).
            hidden_states_list: one hidden-states tensor per micro-batch.
            topk_ids_list: per-micro-batch expert ids selected by the router.
            topk_weights_list: per-micro-batch routing weights.

        Returns:
            The combined output hidden-states tensor of each micro-batch,
            in the same order as the inputs.
        """
        intra_layer_micro_batch = len(hidden_states_list)
        # Stage 1: launch (async) dispatch preprocessing for every
        # micro-batch before any of them is actually dispatched.
        pre_dispatched_list = []
        for hidden_states, topk_ids in zip(hidden_states_list, topk_ids_list):
            pre_dispatched = dispatcher.dispatch_preprocess(
                hidden_states=hidden_states,
                topk_ids=topk_ids,
                async_op=True,
            )
            pre_dispatched_list.append(pre_dispatched)

        dispatched_list = []
        post_dispatched_list = []
        # NOTE(review): experts_out_list is collected but never read back;
        # kept for parity with the original code / possible debugging.
        experts_out_list = []
        pre_combined_list = []
        combined_list = []

        # Stage 2: per micro-batch, run dispatch -> dispatch_postprocess ->
        # (mock) experts -> combine_preprocess, retaining every intermediate
        # result so the remaining stages can overlap across micro-batches.
        for topk_weights, pre_dispatched in zip(topk_weights_list, pre_dispatched_list):
            dispatched = dispatcher.dispatch(
                pre_dispatched=pre_dispatched,
                topk_weights=topk_weights,
                async_op=True,
            )
            post_dispatched = dispatcher.dispatch_postprocess(
                pre_dispatched=pre_dispatched,
                dispatched=dispatched,
                async_op=True,
            )
            experts_results = mock_experts(
                hidden_states=post_dispatched["hidden_states"],
                # NOTE(review): "tokens_per_exprts" (sic) must match the
                # keyword declared by mock_experts — don't "fix" the
                # spelling here without renaming it there too.
                tokens_per_exprts=post_dispatched["tokens_per_expert"],
            )
            pre_combined = dispatcher.combine_preprocess(
                hidden_states=experts_results,
                pre_dispatched=pre_dispatched,
                dispatched=dispatched,
                post_dispatched=post_dispatched,
                async_op=True,
            )
            post_dispatched_list.append(post_dispatched)
            experts_out_list.append(experts_results)
            dispatched_list.append(dispatched)
            pre_combined_list.append(pre_combined)

        # Stage 3: issue the (async) combine for every micro-batch.
        for pre_combined, pre_dispatched, dispatched, post_dispatched in zip(
            pre_combined_list,
            pre_dispatched_list,
            dispatched_list,
            post_dispatched_list,
        ):
            combined = dispatcher.combine(
                pre_combined=pre_combined,
                pre_dispatched=pre_dispatched,
                dispatched=dispatched,
                post_dispatched=post_dispatched,
                async_op=True,
            )
            combined_list.append(combined)

        hidden_states_out_list: list[torch.Tensor] = []

        # Stage 4: finalize each micro-batch and collect its output tokens.
        for i in range(intra_layer_micro_batch):
            post_combined = dispatcher.combine_postprocess(
                pre_dispatched=pre_dispatched_list[i],
                dispatched=dispatched_list[i],
                post_dispatched=post_dispatched_list[i],
                pre_combined=pre_combined_list[i],
                combined=combined_list[i],
                async_op=True,
            )
            hidden_states_out_list.append(post_combined["hidden_states"])
        return hidden_states_out_list
167+
168+
72169
def _dispatcher_call(
73170
self,
74171
dispatcher: DispacherInterface,

0 commit comments

Comments
 (0)