From ae109ffa8a899bdaa63cd701ca03b5184a746069 Mon Sep 17 00:00:00 2001 From: geyuhong Date: Tue, 2 Sep 2025 23:48:02 +0800 Subject: [PATCH 1/6] adapt grouped_linear, layernorm_linear and linear --- .../pytorch/module/grouped_linear.py | 28 ++++++++++++++++--- .../pytorch/module/layernorm_linear.py | 28 ++++++++++++++++--- transformer_engine/pytorch/module/linear.py | 28 ++++++++++++++++--- 3 files changed, 72 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 3d7a5efaca..18aef78c13 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -209,6 +209,24 @@ def forward( if isinstance(weight, QuantizedTensorBase): weight.update_usage(columnwise_usage=True) + offload_activation = False + if hasattr(inp, 'offloading_activation'): + offload_activation = True + for i in range(num_gemms): + inputmats[i].offloading_activation = inp.offloading_activation + ctx.offload_activation = offload_activation + + if offload_activation and cpu_offloading: + raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") + + if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation: + grad_added_to_main_grad_list = [] + for weight in weights: + if weight.requires_grad and hasattr(weight, 'grad_added_to_main_grad'): + grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad) + weight.grad_added_to_main_grad = True + ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list + tensors_to_save, tensor_objects = prepare_for_saving( *inputmats, *weights_fp8, @@ -271,11 +289,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: + if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation: for i in range(ctx.num_gemms): - w = torch.nn.Parameter(weights[i], weights[i].requires_grad) - w.main_grad = main_grads[i] - weights[i] = w + if not ctx.cpu_offloading: + w = torch.nn.Parameter(weights[i], weights[i].requires_grad) + weights[i] = w + weights[i].main_grad = main_grads[i] + weights[i].grad_added_to_main_grad = ctx.grad_added_to_main_grad_list[i] # Preprocess grad output grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index cd02f31132..9af38d8e76 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -424,10 +424,28 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") + offload_activation = False + if hasattr(inp, 'offloading_activation'): + offload_activation = True + if inputmat.is_contiguous(): + inputmat = inputmat.contiguous() + ctx.offload_activation = offload_activation + + if offload_activation and cpu_offloading: + raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") + + if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: + if hasattr(weight, 'grad_added_to_main_grad'): + ctx.has_grad_added_to_main_grad = True + ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad + weight.grad_added_to_main_grad = True + else: + ctx.has_grad_added_to_main_grad = False + if cpu_offloading: - ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") + ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - if ctx.grad_added_to_main_grad: + if ctx.has_grad_added_to_main_grad: # If you are passing torch.nn.Parameter through the Torch hooks, you will # get back torch.Tensor. Torch rips off the Parameter wrapper. # You need to preserve the weight object to have all the attributes user @@ -560,9 +578,11 @@ def backward( # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # we need to connect them into one. - if ctx.cpu_offloading: - if ctx.grad_added_to_main_grad: + if ctx.cpu_offloading or ctx.offload_activation: + if ctx.has_grad_added_to_main_grad: origin_weight = ctx.weight_object + if ctx.offload_activation: + origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2ce6fb4c1d..c725c92c11 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -395,10 +395,28 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") + offload_activation = False + if hasattr(inp, 'offload_activation'): + offload_activation = True + if saved_inputmat.is_contiguous(): + saved_inputmat = saved_inputmat.contiguous() + ctx.offload_activation = offload_activation + + if offload_activation and cpu_offloading: + raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") + + if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: + if hasattr(weight, 'grad_added_to_main_grad'): + ctx.has_grad_added_to_main_grad = True + ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad + weight.grad_added_to_main_grad = True + else: + ctx.has_grad_added_to_main_grad = False + if cpu_offloading: - ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") + ctx.has_grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - if ctx.grad_added_to_main_grad: + if ctx.has_grad_added_to_main_grad: # If you are passing torch.nn.Parameter through the Torch hooks, you will # get back torch.Tensor. Torch rips off the Parameter wrapper. # You need to preserve the weight object to have all the attributes user @@ -493,9 +511,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else None ) - if ctx.cpu_offloading: - if ctx.grad_added_to_main_grad: + if ctx.cpu_offloading or ctx.offload_activation: + if ctx.has_grad_added_to_main_grad: weight = ctx.weight_object + if ctx.offload_activation: + weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: weight.main_grad = main_grad From c44b45d37c7ac15a7ac999c908194587a5488e51 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Sep 2025 16:11:35 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 8 +++++--- transformer_engine/pytorch/module/layernorm_linear.py | 8 +++++--- transformer_engine/pytorch/module/linear.py | 10 ++++++---- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 18aef78c13..c9402cded5 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -210,19 +210,21 @@ def forward( weight.update_usage(columnwise_usage=True) offload_activation = False - if hasattr(inp, 'offloading_activation'): + if hasattr(inp, "offloading_activation"): offload_activation = True for i in range(num_gemms): inputmats[i].offloading_activation = inp.offloading_activation ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: - raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") + raise ValueError( + f"Do not use offload_activation and cpu_offloading at the same time." + ) if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation: grad_added_to_main_grad_list = [] for weight in weights: - if weight.requires_grad and hasattr(weight, 'grad_added_to_main_grad'): + if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"): grad_added_to_main_grad_list.append(weight.grad_added_to_main_grad) weight.grad_added_to_main_grad = True ctx.grad_added_to_main_grad_list = grad_added_to_main_grad_list diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 9af38d8e76..f8aa9a4718 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -425,17 +425,19 @@ def forward( nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") offload_activation = False - if hasattr(inp, 'offloading_activation'): + if hasattr(inp, "offloading_activation"): offload_activation = True if inputmat.is_contiguous(): inputmat = inputmat.contiguous() ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: - raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") + raise ValueError( + f"Do not use offload_activation and cpu_offloading at the same time." + ) if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: - if hasattr(weight, 'grad_added_to_main_grad'): + if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad weight.grad_added_to_main_grad = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c725c92c11..751bc6832b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -396,17 +396,19 @@ def forward( nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") offload_activation = False - if hasattr(inp, 'offload_activation'): + if hasattr(inp, "offload_activation"): offload_activation = True if saved_inputmat.is_contiguous(): saved_inputmat = saved_inputmat.contiguous() ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: - raise ValueError(f"Do not use offload_activation and cpu_offloading at the same time.") - + raise ValueError( + f"Do not use offload_activation and cpu_offloading at the same time." + ) + if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: - if hasattr(weight, 'grad_added_to_main_grad'): + if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad weight.grad_added_to_main_grad = True From 93be702489eb5c588ed6704b8c6fa7f8fe41cd5b Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 18 Sep 2025 07:00:32 -0700 Subject: [PATCH 3/6] Bug fix Signed-off-by: Hongbin Liu --- .../pytorch/module/grouped_linear.py | 33 ++++++++++--------- .../pytorch/module/layernorm_linear.py | 18 ++++++---- transformer_engine/pytorch/module/linear.py | 17 ++++++---- 3 files changed, 41 insertions(+), 27 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c9402cded5..28c92156ee 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -14,6 +14,7 @@ from transformer_engine.common.recipe import Recipe from .base import ( get_multi_stream_cublas_workspace, + get_dummy_wgrad, TransformerEngineBaseModule, _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -80,6 +81,7 @@ def forward( module, skip_fp8_weight_update, save_original_input, + offload_activation, *weights_and_biases, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -209,11 +211,10 @@ def forward( if isinstance(weight, QuantizedTensorBase): weight.update_usage(columnwise_usage=True) - offload_activation = False - if hasattr(inp, "offloading_activation"): - offload_activation = True - for i in range(num_gemms): - inputmats[i].offloading_activation = inp.offloading_activation + for i in range(num_gemms): + weights[i].offloading_activation = False + weights_fp8[i].offloading_activation = False + biases[i].offloading_activation = False ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: @@ -448,18 +449,15 @@ def handle_custom_ddp_from_mcore(weight, wgrad): ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None @@ -506,6 +504,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): None, None, None, + None, *wgrad_list, *grad_biases, ) @@ -587,6 +586,7 @@ def __init__( ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, ub_name: Optional[str] = None, + offload_activation: bool = False, delay_wgrad_compute: bool = False, save_original_input: bool = False, ) -> None: @@ -610,6 +610,8 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name + self.offload_activation = offload_activation + self.wgrad_store = WeightGradStore(delay_wgrad_compute) self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1} @@ -825,6 +827,7 @@ def forward( self.sequence_parallel, self.activation_dtype, torch.is_grad_enabled(), + self.offload_activation, self, skip_fp8_weight_update, self.save_original_input, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index f8aa9a4718..9eaec062b2 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -122,6 +122,7 @@ def forward( ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_name: str, + offload_activation: bool, fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, @@ -424,11 +425,12 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - offload_activation = False - if hasattr(inp, "offloading_activation"): - offload_activation = True - if inputmat.is_contiguous(): - inputmat = inputmat.contiguous() + # Do not offload weights and biases + weight.offloading_activation = False + weightmat.offloading_activation = False + if bias is not None: + bias.offloading_activation = False + ln_weight.offloading_activation = False ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: @@ -441,6 +443,7 @@ def forward( ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad weight.grad_added_to_main_grad = True + ctx.weight_object = weight else: ctx.has_grad_added_to_main_grad = False @@ -1043,6 +1046,7 @@ def wgrad_gemm( None, # ub_bulk_dgrad None, # ub_bulk_wgrad None, # ub_name + None, # offload_activation None, # fsdp_group None, # debug None, # module @@ -1178,6 +1182,7 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, name: str = None, + offload_activation: bool = False, ) -> None: super().__init__() @@ -1194,7 +1199,7 @@ def __init__( self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type - + self.offload_activation = offload_activation self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.name = name @@ -1597,6 +1602,7 @@ def forward( self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_name, + self.offload_activation, self.fsdp_group, self, skip_fp8_weight_update, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 751bc6832b..6d7864dda9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -109,6 +109,7 @@ def forward( ub_bulk_dgrad: bool, ub_bulk_wgrad: bool, ub_name: str, + offload_activation: bool, fp8_output: bool, # pylint: disable=unused-argument fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, @@ -395,11 +396,6 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - offload_activation = False - if hasattr(inp, "offload_activation"): - offload_activation = True - if saved_inputmat.is_contiguous(): - saved_inputmat = saved_inputmat.contiguous() ctx.offload_activation = offload_activation if offload_activation and cpu_offloading: @@ -412,6 +408,7 @@ def forward( ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad weight.grad_added_to_main_grad = True + ctx.weight_object = weight else: ctx.has_grad_added_to_main_grad = False @@ -426,6 +423,11 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight + # Do not offload weights and biases + weight.offloading_activation = False + weightmat.offloading_activation = False + if bias is not None: + bias.offloading_activation = False # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, @@ -990,6 +992,7 @@ def wgrad_gemm( None, # ub_bulk_dgrad None, # ub_bulk_wgrad None, # ub_name + None, # offload_activation None, # fp8_output None, # fsdp_group None, # module @@ -1112,6 +1115,7 @@ def __init__( symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, name: Optional[str] = None, + offload_activation: bool = False, ) -> None: super().__init__() @@ -1127,7 +1131,7 @@ def __init__( self.symmetric_ar_type = symmetric_ar_type self.save_original_input = save_original_input self.name = name - + self.offload_activation = offload_activation self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) if device == "meta": @@ -1474,6 +1478,7 @@ def forward( self.ub_bulk_dgrad, self.ub_bulk_wgrad, self.ub_name, + self.offload_activation, fp8_output, self.fsdp_group, self, From f0726f704a403dd390e86652485a279c4bfc2360 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 18 Sep 2025 22:41:00 -0700 Subject: [PATCH 4/6] renaming Signed-off-by: Hongbin Liu --- .../pytorch/module/grouped_linear.py | 18 +++++++-------- .../pytorch/module/layernorm_linear.py | 22 +++++++++---------- transformer_engine/pytorch/module/linear.py | 22 +++++++++---------- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 28c92156ee..1361c4c217 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -81,7 +81,7 @@ def forward( module, skip_fp8_weight_update, save_original_input, - offload_activation, + fine_grained_activation_offloading, *weights_and_biases, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -215,14 +215,14 @@ def forward( weights[i].offloading_activation = False weights_fp8[i].offloading_activation = False biases[i].offloading_activation = False - ctx.offload_activation = offload_activation + ctx.fine_grained_activation_offloading = fine_grained_activation_offloading - if offload_activation and cpu_offloading: + if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use offload_activation and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." ) - if offload_activation and weights[0].requires_grad and fuse_wgrad_accumulation: + if fine_grained_activation_offloading and weights[0].requires_grad and fuse_wgrad_accumulation: grad_added_to_main_grad_list = [] for weight in weights: if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"): @@ -292,7 +292,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - if (ctx.cpu_offloading or ctx.offload_activation) and ctx.fuse_wgrad_accumulation: + if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation: for i in range(ctx.num_gemms): if not ctx.cpu_offloading: w = torch.nn.Parameter(weights[i], weights[i].requires_grad) @@ -586,7 +586,7 @@ def __init__( ub_overlap_rs: bool = False, ub_overlap_ag: bool = False, ub_name: Optional[str] = None, - offload_activation: bool = False, + fine_grained_activation_offloading: bool = False, delay_wgrad_compute: bool = False, save_original_input: bool = False, ) -> None: @@ -610,7 +610,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name - self.offload_activation = offload_activation + self.fine_grained_activation_offloading = fine_grained_activation_offloading self.wgrad_store = WeightGradStore(delay_wgrad_compute) @@ -827,7 +827,7 @@ def forward( self.sequence_parallel, self.activation_dtype, torch.is_grad_enabled(), - self.offload_activation, + self.fine_grained_activation_offloading, self, skip_fp8_weight_update, self.save_original_input, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 9eaec062b2..55e226066a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -122,7 +122,7 @@ def forward( ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, ub_name: str, - offload_activation: bool, + fine_grained_activation_offloading: bool, fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, skip_fp8_weight_update: bool, @@ -431,14 +431,14 @@ def forward( if bias is not None: bias.offloading_activation = False ln_weight.offloading_activation = False - ctx.offload_activation = offload_activation + ctx.fine_grained_activation_offloading = fine_grained_activation_offloading - if offload_activation and cpu_offloading: + if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use offload_activation and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." ) - if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: + if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad @@ -583,10 +583,10 @@ def backward( # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # we need to connect them into one. - if ctx.cpu_offloading or ctx.offload_activation: + if ctx.cpu_offloading or ctx.fine_grained_activation_offloading: if ctx.has_grad_added_to_main_grad: origin_weight = ctx.weight_object - if ctx.offload_activation: + if ctx.fine_grained_activation_offloading: origin_weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad @@ -1046,7 +1046,7 @@ def wgrad_gemm( None, # ub_bulk_dgrad None, # ub_bulk_wgrad None, # ub_name - None, # offload_activation + None, # fine_grained_activation_offloading None, # fsdp_group None, # debug None, # module @@ -1182,7 +1182,7 @@ def __init__( delay_wgrad_compute: bool = False, symmetric_ar_type: Optional[str] = None, name: str = None, - offload_activation: bool = False, + fine_grained_activation_offloading: bool = False, ) -> None: super().__init__() @@ -1199,7 +1199,7 @@ def __init__( self.return_layernorm_output_gathered = return_layernorm_output_gathered self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type - self.offload_activation = offload_activation + self.fine_grained_activation_offloading = fine_grained_activation_offloading self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.name = name @@ -1602,7 +1602,7 @@ def forward( self.ub_bulk_wgrad, self.ub_bulk_dgrad, self.ub_name, - self.offload_activation, + self.fine_grained_activation_offloading, self.fsdp_group, self, skip_fp8_weight_update, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6d7864dda9..aa449a10a4 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -109,7 +109,7 @@ def forward( ub_bulk_dgrad: bool, ub_bulk_wgrad: bool, ub_name: str, - offload_activation: bool, + fine_grained_activation_offloading: bool, fp8_output: bool, # pylint: disable=unused-argument fsdp_group: Union[dist_group_type, None], module: torch.nn.Module, @@ -396,14 +396,14 @@ def forward( ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") - ctx.offload_activation = offload_activation + ctx.fine_grained_activation_offloading = fine_grained_activation_offloading - if offload_activation and cpu_offloading: + if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use offload_activation and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." ) - if offload_activation and weight.requires_grad and fuse_wgrad_accumulation: + if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad @@ -515,10 +515,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else None ) - if ctx.cpu_offloading or ctx.offload_activation: + if ctx.cpu_offloading or ctx.fine_grained_activation_offloading: if ctx.has_grad_added_to_main_grad: weight = ctx.weight_object - if ctx.offload_activation: + if ctx.fine_grained_activation_offloading: weight.grad_added_to_main_grad = ctx.grad_added_to_main_grad if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: weight.main_grad = main_grad @@ -992,7 +992,7 @@ def wgrad_gemm( None, # ub_bulk_dgrad None, # ub_bulk_wgrad None, # ub_name - None, # offload_activation + None, # fine_grained_activation_offloading None, # fp8_output None, # fsdp_group None, # module @@ -1115,7 +1115,7 @@ def __init__( symmetric_ar_type: Optional[str] = None, save_original_input: bool = False, name: Optional[str] = None, - offload_activation: bool = False, + fine_grained_activation_offloading: bool = False, ) -> None: super().__init__() @@ -1131,7 +1131,7 @@ def __init__( self.symmetric_ar_type = symmetric_ar_type self.save_original_input = save_original_input self.name = name - self.offload_activation = offload_activation + self.fine_grained_activation_offloading = fine_grained_activation_offloading self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) if device == "meta": @@ -1478,7 +1478,7 @@ def forward( self.ub_bulk_dgrad, self.ub_bulk_wgrad, self.ub_name, - self.offload_activation, + self.fine_grained_activation_offloading, fp8_output, self.fsdp_group, self, From a1c6e073fbffd74516a8b6a5611a926b86fa0622 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 22 Sep 2025 04:34:56 -0700 Subject: [PATCH 5/6] minor fix for fp8 Signed-off-by: Hongbin Liu --- transformer_engine/pytorch/module/grouped_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 1361c4c217..311475818a 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -827,10 +827,10 @@ def forward( self.sequence_parallel, self.activation_dtype, torch.is_grad_enabled(), - self.fine_grained_activation_offloading, self, skip_fp8_weight_update, self.save_original_input, + self.fine_grained_activation_offloading, *weight_tensors, *bias_tensors, ) From 37250dbf4eb3ba94613b7aa7f281f5f72bd383a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 25 Sep 2025 09:07:16 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 13 ++++++++++--- .../pytorch/module/layernorm_linear.py | 9 +++++++-- transformer_engine/pytorch/module/linear.py | 9 +++++++-- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 311475818a..247f186af1 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -219,10 +219,15 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same" + f" time." ) - if fine_grained_activation_offloading and weights[0].requires_grad and fuse_wgrad_accumulation: + if ( + fine_grained_activation_offloading + and weights[0].requires_grad + and fuse_wgrad_accumulation + ): grad_added_to_main_grad_list = [] for weight in weights: if weight.requires_grad and hasattr(weight, "grad_added_to_main_grad"): @@ -292,7 +297,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - if (ctx.cpu_offloading or ctx.fine_grained_activation_offloading) and ctx.fuse_wgrad_accumulation: + if ( + ctx.cpu_offloading or ctx.fine_grained_activation_offloading + ) and ctx.fuse_wgrad_accumulation: for i in range(ctx.num_gemms): if not ctx.cpu_offloading: w = torch.nn.Parameter(weights[i], weights[i].requires_grad) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 55e226066a..abb96f11a8 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -435,10 +435,15 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same" + f" time." ) - if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: + if ( + fine_grained_activation_offloading + and weight.requires_grad + and fuse_wgrad_accumulation + ): if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index aa449a10a4..a5ccefbb1d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -400,10 +400,15 @@ def forward( if fine_grained_activation_offloading and cpu_offloading: raise ValueError( - f"Do not use fine_grained_activation_offloading and cpu_offloading at the same time." + f"Do not use fine_grained_activation_offloading and cpu_offloading at the same" + f" time." ) - if fine_grained_activation_offloading and weight.requires_grad and fuse_wgrad_accumulation: + if ( + fine_grained_activation_offloading + and weight.requires_grad + and fuse_wgrad_accumulation + ): if hasattr(weight, "grad_added_to_main_grad"): ctx.has_grad_added_to_main_grad = True ctx.grad_added_to_main_grad = weight.grad_added_to_main_grad