diff --git a/dualpipe/dualpipe.py b/dualpipe/dualpipe.py
index 686fb53..9073dd0 100644
--- a/dualpipe/dualpipe.py
+++ b/dualpipe/dualpipe.py
@@ -382,7 +382,7 @@ def step(
         step_4 = half_num_chunks - num_ranks + half_rank + 1
         for i in range(step_4):
             if i == 0:
-                if self.is_middle_rank:
+                if self.is_middle_rank and not self.is_first_rank:
                     # NOTE: We don't overlap these two chunks to further reduce bubble size.
                     self._forward_chunk(0, recv=False, send=False)
                     self._send_forward(1)
diff --git a/dualpipe/dualpipev.py b/dualpipe/dualpipev.py
index cb8dc76..735b0e7 100644
--- a/dualpipe/dualpipev.py
+++ b/dualpipe/dualpipev.py
@@ -353,7 +353,7 @@ def step(
         step_4 = num_chunks - num_ranks * 2 + rank + 1
         for i in range(step_4):
             if i == 0:
-                if self.is_last_rank:
+                if self.is_last_rank and not self.is_first_rank:
                     # NOTE: We don't overlap these two chunks to further reduce bubble size.
                     self._forward_chunk(0, recv=False, send=False)
                     self._send_forward(1)