
Backward Pass Not Delayed Despite Misaligned Gradient Communication #21

@curious-nizi

Image

In the attached schedule diagram (time steps are numbered from T0, left to right), Device 5 sends the input gradients for microbatch 0 at the end of T10, but Device 4 does not attempt to receive them until the beginning of T12. Why doesn't this misaligned communication delay the weight backward pass for microbatch 0 (the blue 0 at T11), given that it must wait for the communication to complete?

The related code in dualpipe.py is as follows:

        step_3 = num_half_ranks - half_rank - 1
        for i in range(step_3):
            self._backward_chunk(1, enable_zb=True) # enable zero bubble - decouple backward for inputs and backward for weights
            self._recv_forward(1)
            self._weight_chunk()
            self._forward_chunk(1, recv=False)

        # Step 4 (Main step): nF0B1F1B0
        step_4 = half_num_chunks - num_ranks + half_rank + 1
        for i in range(step_4):
            if i == 0:
                if self.is_middle_rank:
                    # NOTE: We don't overlap these two chunks to further reduce bubble size.
                    self._forward_chunk(0, recv=False, send=False)
                    self._send_forward(1)
                    self._backward_chunk(1, send=False)
                    self._send_forward(0)
                    self._send_backward(1)
                else:
                    self._forward_backward_chunk(0, 1, recv0=False)
            else:
                self._forward_backward_chunk(0, 1)
            self._forward_backward_chunk(1, 0)
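
For readers unfamiliar with the `enable_zb=True` comment above, here is a minimal sketch of what "decoupling backward for inputs and backward for weights" can look like. It is a toy model only: `ToyLinear`, `weight_grad_queue`, and `run_deferred_weight_grads` are hypothetical names made up for illustration, the layer is a scalar `y = w * x`, and nothing here is taken from DualPipe's actual implementation. It shows the general idea, not an answer to the timing question.

```python
from collections import deque

# Hypothetical queue of deferred weight-gradient computations.
# This is an illustrative stand-in, not DualPipe's real data structure.
weight_grad_queue = deque()


class ToyLinear:
    """Scalar layer y = w * x, just enough to show the input/weight split."""

    def __init__(self, w):
        self.w = w
        self.grad_w = 0.0
        self.saved_x = None

    def forward(self, x):
        # Keep the input around; the weight gradient needs it later.
        self.saved_x = x
        return self.w * x

    def backward_input(self, grad_y):
        """Compute the input gradient now and defer the weight gradient."""
        x = self.saved_x
        # Only a closure is recorded here; grad_w is not touched yet.
        weight_grad_queue.append(lambda: self._accumulate(grad_y * x))
        # dL/dx = dL/dy * w, available immediately for the previous stage.
        return grad_y * self.w

    def _accumulate(self, g):
        self.grad_w += g


def run_deferred_weight_grads():
    """Run every queued weight-gradient computation, oldest first."""
    while weight_grad_queue:
        weight_grad_queue.popleft()()


if __name__ == "__main__":
    layer = ToyLinear(3.0)
    layer.forward(2.0)
    grad_x = layer.backward_input(0.5)
    print("input grad:", grad_x, "weight grad before flush:", layer.grad_w)
    run_deferred_weight_grads()
    print("weight grad after flush:", layer.grad_w)
```

In this toy version the deferred weight computation depends only on values saved locally at the time `backward_input` ran. Whether the same holds for the blue 0 at T11 in the diagram depends on the real implementation and is exactly what the question above asks.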
