diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index bf00717ab6c..09690dc813c 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -564,7 +564,15 @@ def backward(ctx, *args): *filter(lambda x: torch.is_tensor(x[0]) and x[0].requires_grad, zip(outputs, args)) ) torch.autograd.backward(outputs, args) - grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) + grads = tuple(inp.grad.clone() if isinstance(inp, torch.Tensor) and inp.grad is not None else inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs) + cur_stream = torch.cuda.current_stream() + for t in detached_inputs: + if isinstance(t, torch.Tensor) and t.requires_grad: + t.record_stream(cur_stream) + t.untyped_storage().resize_(0) + if t.grad is not None: + t.grad.record_stream(cur_stream) + t.grad.untyped_storage().resize_(0) _unset_checkpointing() return (None, None) + grads