Gradient synchronization incorrect when --overlap-grad-reduce and --num-distributed-optimizer-instances > 1 due to autograd hook stream affinity #3670

@zyeric

Description

Describe the bug

In Megatron-LM v0.15.2 (and likely the affected commit range), enabling both --overlap-grad-reduce and --num-distributed-optimizer-instances > 1 results in incorrectly synchronized gradients. The root cause is a stream synchronization issue in the distributed data parallel (DDP) gradient post‑hook.

When DistributedDataParallel is constructed, it is wrapped in a custom CUDA stream context (introduced in commit e7c55de). Inside that context, a temporary tensor is created via expand_as to retrieve the gradient accumulator (grad_acc). Because expand_as executes on a non‑default stream, the resulting autograd node, and therefore grad_acc, is bound to that stream. Consequently, the post‑hook registered on grad_acc runs on the same side stream during backward.

The hook contains gradient checks and scaling that must complete before the reduce‑scatter operation on the communication stream starts. However, because the hook runs on a side stream and no explicit synchronization is performed, the communication stream may launch reduce‑scatter before the hook finishes, leading to incorrect gradient values.
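The required ordering can be sketched as follows. This is only an illustration of the fix direction, not Megatron's actual code: make_param_hook, comm_stream, and launch_reduce_scatter are hypothetical names. The key point is that the communication stream must wait on the stream the hook actually runs on (the current stream), not on the default stream.

```python
import torch

# Hypothetical helper for illustration; stands in for Megatron's DDP param hook.
def make_param_hook(comm_stream, launch_reduce_scatter):
    def param_hook(*unused):
        # Gradient checks / scaling enqueued here run on whatever stream the
        # hook executes on: the default stream, or the side stream that the
        # grad accumulator was created on.
        hook_stream = torch.cuda.current_stream()

        # Buggy pattern: comm_stream.wait_stream(torch.cuda.default_stream())
        # orders reduce-scatter only after default-stream work, so if the hook
        # runs on a side stream, reduce-scatter can start too early.

        # Fixed pattern: wait on the stream the hook is actually running on,
        # so all checks/scaling complete before reduce-scatter launches.
        comm_stream.wait_stream(hook_stream)
        with torch.cuda.stream(comm_stream):
            launch_reduce_scatter()
    return param_hook

if torch.cuda.is_available():
    comm_stream = torch.cuda.Stream()
    param = torch.randn(10, requires_grad=True, device="cuda")
    side = torch.cuda.Stream()
    with torch.cuda.stream(side):
        param_tmp = param.expand_as(param)  # binds the node to the side stream
        grad_acc = param_tmp.grad_fn.next_functions[0][0]
        grad_acc.register_hook(make_param_hook(comm_stream, lambda: None))
        loss = (param_tmp * 2).sum()
    loss.backward()
    torch.cuda.synchronize()
```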

This problem only manifests when --num-distributed-optimizer-instances > 1 because that setting changes the way gradient communication is orchestrated, making it sensitive to the stream ordering.


Steps/Code to reproduce bug

A minimal unit test that demonstrates the underlying stream affinity issue (without requiring a full training run) is provided below. Save it as test_stream_affinity.py and run with pytest -s or python -m unittest test_stream_affinity.

import torch
import unittest

class TestAutogradStreamAffinity(unittest.TestCase):
    def setUp(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available")
        self.device = 'cuda'

    def _make_hook(self, label):
        # Node.register_hook post-hooks receive (grad_inputs, grad_outputs);
        # returning None leaves the gradients unchanged.
        def hook(grad_inputs, grad_outputs):
            current_id = torch.cuda.current_stream().stream_id
            print(f"\n[{label}] Hook running on stream {current_id}")
        return hook

    def test_case_1_pure_registration(self):
        """Node created in default stream, hook registered in side stream."""
        param = torch.randn(10, requires_grad=True, device=self.device)
        param_tmp = param.expand_as(param)          # default stream
        grad_acc = param_tmp.grad_fn.next_functions[0][0]

        custom_stream = torch.cuda.Stream()
        with torch.cuda.stream(custom_stream):
            print(f"\n[Case 1] Registering on stream {torch.cuda.current_stream().stream_id}")
            grad_acc.register_hook(self._make_hook("Pure Registration"))

        loss = (param * 2).sum()
        loss.backward()

    def test_case_2_graph_binding(self):
        """Node created and hook registered in side stream."""
        param = torch.randn(10, requires_grad=True, device=self.device)
        custom_stream = torch.cuda.Stream()

        with torch.cuda.stream(custom_stream):
            print(f"\n[Case 2] Creating node and registering on stream {torch.cuda.current_stream().stream_id}")
            param_tmp = param.expand_as(param)      # side stream
            grad_acc = param_tmp.grad_fn.next_functions[0][0]
            grad_acc.register_hook(self._make_hook("Graph Binding"))

            loss = (param_tmp * 2).sum()

        loss.backward()

When run, you will see output similar to:

[Case 1] Registering on stream 3
[Pure Registration] Hook running on stream 0
.
[Case 2] Creating node and registering on stream 35
[Graph Binding] Hook running on stream 35
.

In test_case_2 the hook runs on the side stream (the stream active when expand_as created the node), exactly mirroring the situation in Megatron's DDP construction. Because that side stream is never synchronised with the communication stream, the hook's operations can race with the reduce‑scatter launch.
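One construction-time mitigation (an alternative to synchronising inside the hook, and not necessarily the fix that was adopted) is to pin the expand_as call to the default stream even when DDP construction happens inside a side-stream context, so the accumulator node never acquires side-stream affinity:

```python
import torch

if torch.cuda.is_available():
    param = torch.randn(10, requires_grad=True, device="cuda")
    side = torch.cuda.Stream()
    with torch.cuda.stream(side):
        # Escape the surrounding side-stream context just for the graph-node
        # creation, so the accumulator (and hooks on it) keep default-stream
        # affinity.
        with torch.cuda.stream(torch.cuda.default_stream()):
            param_tmp = param.expand_as(param)
            grad_acc = param_tmp.grad_fn.next_functions[0][0]

        def hook(*unused):
            # Should now report the default stream rather than the side stream.
            print("hook stream:", torch.cuda.current_stream().stream_id)

        grad_acc.register_hook(hook)
    loss = (param * 2).sum()
    loss.backward()
    torch.cuda.synchronize()
```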

To reproduce the full bug in a real training setup:

  • Use Megatron‑LM v0.15.2 (or commit before the fix).
  • Launch a training job with both flags:
    --overlap-grad-reduce
    --num-distributed-optimizer-instances 2
    
  • Observe incorrect gradient values (e.g., the grad norm is much larger than when --overlap-grad-reduce is disabled or when only one optimizer instance is used).

Expected behavior

Gradients should be correctly synchronised regardless of the stream on which the hook executes. The communication stream should wait for all hook operations (gradient checks/scaling) to complete before performing reduce‑scatter.


Additional context

  • The problematic change was introduced in commit e7c55de (wrapping DDP construction in a CUDA stream context).
  • The root cause is that expand_as inside that context makes the gradient accumulator inherit the side stream, causing the post‑hook to run on that stream.
  • A fix has been implemented (or is proposed) that makes the hook synchronise against the current stream (the stream the hook actually runs on) instead of the default stream. This ensures the communication stream waits for the hook’s work before launching reduce‑scatter.
  • The unit test above isolates the stream‑affinity behaviour and can be used to verify the fix.

Tagging @mcore-oncall for awareness.
