diff --git a/megatron/core/distributed/param_and_grad_buffer.py b/megatron/core/distributed/param_and_grad_buffer.py index 088374fbf13..85b9d98a3be 100644 --- a/megatron/core/distributed/param_and_grad_buffer.py +++ b/megatron/core/distributed/param_and_grad_buffer.py @@ -419,10 +419,10 @@ def start_grad_sync(self, force_all_reduce: Optional[bool] = False): # need to overlap communication. stream_context = torch.cuda.stream(self.communication_stream) - # The RS/AR communication stream needs to wait for the default stream + # The RS/AR communication stream needs to wait for the current stream # to complete its gradient computation before launching the next # gradient reduction collective. - self.communication_stream.wait_stream(torch.cuda.default_stream()) + self.communication_stream.wait_stream(torch.cuda.current_stream()) else: stream_context = nullcontext() @@ -529,7 +529,7 @@ def finish_grad_sync(self, force_all_reduce: Optional[bool] = False): # When using multiple DistOpt instances, we don't need to sync here as we launch # communications on a separate communication stream. if self.ddp_config.num_distributed_optimizer_instances > 1: - torch.cuda.default_stream().wait_stream(self.communication_stream) + torch.cuda.current_stream().wait_stream(self.communication_stream) return assert self.grad_reduce_handle is not None, ( f"Communication call has not been issued for this bucket "