Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions megatron/core/distributed/param_and_grad_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,10 +419,10 @@ def start_grad_sync(self, force_all_reduce: Optional[bool] = False):
# need to overlap communication.
stream_context = torch.cuda.stream(self.communication_stream)

# The RS/AR communication stream needs to wait for the default stream
# The RS/AR communication stream needs to wait for the current stream
# to complete its gradient computation before launching the next
# gradient reduction collective.
self.communication_stream.wait_stream(torch.cuda.default_stream())
self.communication_stream.wait_stream(torch.cuda.current_stream())
else:
stream_context = nullcontext()

Expand Down Expand Up @@ -529,7 +529,7 @@ def finish_grad_sync(self, force_all_reduce: Optional[bool] = False):
# When using multiple DistOpt instances, we don't need to sync here as we launch
# communications on a separate communication stream.
if self.ddp_config.num_distributed_optimizer_instances > 1:
torch.cuda.default_stream().wait_stream(self.communication_stream)
torch.cuda.current_stream().wait_stream(self.communication_stream)
return
assert self.grad_reduce_handle is not None, (
f"Communication call has not been issued for this bucket "
Expand Down