In the attached schedule diagram, Device 5 sends the input gradients for microbatch 0 at the end of T10, but Device 4 only tries to receive them at the beginning of T12. The weight backward pass for microbatch 0 on Device 4 (the blue 0 at T11) appears to need those gradients, so why is it not delayed until the communication completes? Note: time steps are numbered from T0, left to right.
The related code in dualpipe.py is:
    step_3 = num_half_ranks - half_rank - 1
    for i in range(step_3):
        self._backward_chunk(1, enable_zb=True)  # enable zero bubble - decouple backward for inputs and backward for weights
        self._recv_forward(1)
        self._weight_chunk()
        self._forward_chunk(1, recv=False)

    # Step 4 (Main step): nF0B1F1B0
    step_4 = half_num_chunks - num_ranks + half_rank + 1
    for i in range(step_4):
        if i == 0:
            if self.is_middle_rank:
                # NOTE: We don't overlap these two chunks to further reduce bubble size.
                self._forward_chunk(0, recv=False, send=False)
                self._send_forward(1)
                self._backward_chunk(1, send=False)
                self._send_forward(0)
                self._send_backward(1)
            else:
                self._forward_backward_chunk(0, 1, recv0=False)
        else:
            self._forward_backward_chunk(0, 1)
        self._forward_backward_chunk(1, 0)
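
To make the loop bounds concrete, here is a minimal sketch that evaluates `step_3` and `step_4` for each rank. It rests on assumptions that the snippet above does not show: that `num_half_ranks` is `num_ranks // 2`, that `half_rank` is the distance from the rank to the nearest end of the pipeline, that `half_num_chunks` is half the micro-batch count, and that the two central ranks are the "middle" ranks. The helper `step_counts` and the 8-rank, 20-micro-batch configuration are hypothetical and used only for illustration.

```python
def step_counts(rank: int, num_ranks: int, num_chunks: int) -> dict:
    """Evaluate the step 3 / step 4 loop bounds for one rank.

    All derived quantities below are assumed definitions, not taken
    from the snippet in this issue.
    """
    num_half_ranks = num_ranks // 2                # assumed definition
    half_rank = min(rank, num_ranks - 1 - rank)    # assumed definition
    half_num_chunks = num_chunks // 2              # assumed definition
    # Assumed: the two central ranks take the is_middle_rank branch.
    is_middle_rank = rank in (num_half_ranks - 1, num_half_ranks)

    # These two formulas are copied from the snippet above.
    step_3 = num_half_ranks - half_rank - 1
    step_4 = half_num_chunks - num_ranks + half_rank + 1

    return {
        "half_rank": half_rank,
        "is_middle_rank": is_middle_rank,
        "step_3": step_3,
        "step_4": step_4,
    }


if __name__ == "__main__":
    # Hypothetical configuration: 8 ranks, 20 micro-batches.
    for rank in range(8):
        print(rank, step_counts(rank, num_ranks=8, num_chunks=20))
```

Under these assumptions, and if devices are numbered from 0, Device 4 gets `step_3 = 0`, so it skips the zero-bubble loop entirely and enters step 4 through the `is_middle_rank` branch on its first iteration.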