From 1dd2f4e9558c531330bc9fcbe3f6661fd4e228f3 Mon Sep 17 00:00:00 2001
From: Begunner
Date: Thu, 5 Feb 2026 18:12:13 +0800
Subject: [PATCH] add release to checkpoint backward

---
 megatron/core/tensor_parallel/random.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py
index bf00717ab6c..09690dc813c 100644
--- a/megatron/core/tensor_parallel/random.py
+++ b/megatron/core/tensor_parallel/random.py
@@ -564,7 +564,15 @@ def backward(ctx, *args):
             *filter(lambda x: torch.is_tensor(x[0]) and x[0].requires_grad, zip(outputs, args))
         )
         torch.autograd.backward(outputs, args)
-        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
+        grads = tuple(inp.grad.clone() if isinstance(inp, torch.Tensor) and inp.grad is not None else inp.grad if isinstance(inp, torch.Tensor) else inp for inp in detached_inputs)
+        cur_stream = torch.cuda.current_stream()
+        for t in detached_inputs:
+            if isinstance(t, torch.Tensor) and t.requires_grad:
+                t.record_stream(cur_stream)
+                t.untyped_storage().resize_(0)
+                if t.grad is not None:
+                    t.grad.record_stream(cur_stream)
+                    t.grad.untyped_storage().resize_(0)
         _unset_checkpointing()
         return (None, None) + grads
 
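
A minimal standalone sketch (not part of the patch; the tensor names are
illustrative and the example is CPU-only, so the record_stream() guard the
patch uses for CUDA caching-allocator safety is omitted) of why the grad is
cloned before the storage release:

    import torch

    # Hypothetical example, not Megatron code: untyped_storage().resize_(0)
    # frees the bytes backing a tensor, so the grad must be cloned before
    # its original storage is dropped.
    x = torch.randn(1024, 1024, requires_grad=True)
    (x * x).sum().backward()

    grad_copy = x.grad.clone()           # own the data before releasing it
    x.grad.untyped_storage().resize_(0)  # free the original grad's bytes
    x.untyped_storage().resize_(0)       # free the saved input's bytes too

    print(grad_copy.sum())               # the clone is still usable
    print(x.untyped_storage().size())    # 0: storage has been released

On CUDA the patch additionally calls record_stream(cur_stream) before each
resize_(0), so the caching allocator will not hand the freed block to another
stream while kernels already queued on the current stream may still read it.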