From 74d982d3ad00fdc60ab915ed7936c3ab6d877a15 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 2 Feb 2026 16:45:50 -0800 Subject: [PATCH 01/45] Add NVTE_KEEP_BACKWARD_UNQUANTIZED Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/base.py | 4 +- .../pytorch/module/grouped_linear.py | 36 +++-- .../pytorch/module/layernorm_linear.py | 80 +++++++--- .../pytorch/module/layernorm_mlp.py | 147 +++++++++++------- transformer_engine/pytorch/module/linear.py | 65 +++++--- .../pytorch/ops/basic/basic_linear.py | 48 ++++-- .../pytorch/ops/basic/quantize.py | 6 +- .../ops/fused/backward_activation_bias.py | 7 +- .../fused/forward_linear_bias_activation.py | 18 ++- .../ops/fused/forward_linear_bias_add.py | 18 ++- .../ops/fused/forward_linear_scale_add.py | 18 ++- .../ops/fused/userbuffers_forward_linear.py | 49 +++++- transformer_engine/pytorch/ops/fuser.py | 16 +- transformer_engine/pytorch/quantization.py | 5 + 14 files changed, 375 insertions(+), 142 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 09b12afa21..48b02acb01 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1135,9 +1135,11 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Non-FP8 case: bgrad is fused with wgrad for this case. 
- if not ctx.fp8 and not ctx.debug: + if not use_fp8_bwd and not ctx.debug: if gather_grad_output: if not ctx.ub_overlap_ag: # Perform NCCL all-gather grad_output, _ = gather_along_first_dim(grad_output, ctx.tp_group) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 2f859e748b..d214ce6a54 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -98,6 +98,9 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + if keep_backward_unquantized: + save_original_input = True num_gemms = len(m_splits) weights = weights_and_biases[:num_gemms] @@ -291,6 +294,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -299,7 +303,11 @@ def forward( ctx.inp_shape = inp.shape ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad(inp, weights[0], biases[0]) + ): ctx.reduce_and_update_bwd_fp8_tensors = ( ctx.reduce_and_update_bwd_fp8_tensors or FP8GlobalStateManager.is_first_fp8_module() @@ -323,6 +331,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized if 
ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -338,7 +348,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms - if ctx.fp8 and not ctx.debug: + if use_fp8_bwd and not ctx.debug: if ctx.use_bias: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) recipe = ctx.fp8_recipe @@ -392,7 +402,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8 or ctx.debug: + if use_fp8_bwd or ctx.debug: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = ( @@ -403,13 +413,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - # Make sure weights are available in column-wise format - # for dgrad computation. - for weight in weights: - if isinstance(weight, QuantizedTensorStorage): - weight.update_usage(columnwise_usage=True) + weights_for_dgrad = weights if use_fp8_bwd else origin_weights + if use_fp8_bwd: + # Make sure weights are available in column-wise format + # for dgrad computation. 
+ for weight in weights_for_dgrad: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) general_grouped_gemm( - weights, + weights_for_dgrad, grad_output, [dgrad], ctx.grad_input_quantizers, @@ -423,7 +435,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): wgrad_gemm_use_split_accumulator = ( @@ -451,7 +463,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list - if ctx.fp8 and not ctx.debug: + if use_fp8_bwd and not ctx.debug: inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -528,7 +540,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() - and not ctx.fp8 + and not use_fp8_bwd ): grad_biases = [None] * ctx.num_gemms diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 702916696b..28842fc315 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,6 +141,7 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" @@ -200,7 +201,10 @@ def forward( if fp8: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and not 
keep_backward_unquantized, + ) if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) @@ -213,6 +217,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not keep_backward_unquantized and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) @@ -236,6 +241,7 @@ def forward( ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out + ln_out_hp = ln_out if keep_backward_unquantized else None # ------------------------------------------------------ # Prepare GEMM input tensor @@ -409,13 +415,14 @@ def forward( # ------------------------------------------------------ if is_grad_enabled: + ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) # Input with column-wise usage is needed for wgrad GEMM. - if backward_needs_input: + if backward_needs_input and not keep_backward_unquantized: if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. 
For MXFP8, columnwise only data @@ -427,7 +434,7 @@ def forward( ln_out.update_usage(rowwise_usage=False) if cpu_offloading: - mark_activation_offload(inputmat, mu, rsigma, ln_out) + mark_activation_offload(inputmat, mu, rsigma, ln_out_to_save) # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -439,7 +446,7 @@ def forward( mu, rsigma, weightmat if fp8 and not is_weight_param_quantized else None, - ln_out if weight.requires_grad else None, + ln_out_to_save if weight.requires_grad else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -466,7 +473,7 @@ def forward( weight, bias, ln_weight, - ln_out, + ln_out_to_save, mu, rsigma, ) @@ -493,6 +500,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -515,7 +523,11 @@ def forward( ctx.requires_dgrad = inp_requires_grad ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad(inp, ln_weight, ln_bias, weight, bias) + ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): @@ -592,6 +604,15 @@ def backward( if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + if keep_backward_unquantized: + ctx.ub_overlap_ag 
= False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -601,23 +622,23 @@ def backward( dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -628,7 +649,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None: + if ctx.grad_output_quantizer is not None and use_quantized_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -665,7 +686,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None: + if ctx.input_quantizer is not None 
and use_quantized_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -703,18 +724,22 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): + if ( + use_quantized_bwd + and ctx.weight_quantizer is not None + and isinstance(weight, QuantizedTensorStorage) + ): weight.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None: + if ctx.grad_input_quantizer is not None and use_quantized_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -730,12 +755,13 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight if use_quantized_bwd else origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight, + weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer, + quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -782,7 +808,11 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) + ): # UB does not support pipelined 
overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -794,7 +824,7 @@ def backward( dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -820,14 +850,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -836,7 +866,7 @@ def backward( # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -862,7 +892,9 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.grad_weight_quantizer, + "quantization_params": ( + ctx.grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -870,7 +902,7 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), "use_split_accumulator": 
use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4532ea60e7..e08561617e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -235,6 +235,7 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -353,8 +354,10 @@ def _forward( # bwd needs fc1 input when grad is enabled, fc1 needs grad, and either # 1) no checkpointing # or 2) doing the recomputation with checkpointing - backwards_needs_fc1_input = fc1_weight.requires_grad and ( - (is_grad_enabled and not checkpoint) or is_recomputation + backwards_needs_fc1_input = ( + fc1_weight.requires_grad + and ((is_grad_enabled and not checkpoint) or is_recomputation) + and not keep_backward_unquantized ) device = inp.device @@ -397,6 +400,7 @@ def _forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not keep_backward_unquantized and not custom ) @@ -418,6 +422,7 @@ def _forward( # do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing if (return_layernorm_output or return_layernorm_output_gathered) and not is_recomputation: ln_out_return = ln_out + ln_out_hp = ln_out if keep_backward_unquantized else None # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication @@ -614,6 +619,10 @@ def _forward( if fc2_input_quantizer is not None: fc2_input_quantizer.calibrate(act_out) + act_out_hp = act_out + if keep_backward_unquantized and is_grad_enabled and fc1_out is not None: + act_out_hp = activation_func(fc1_out, None, **act_params) + # we want to skip fc2 
computation if we are checkpointing and recomputing, # otherwise we compute fc2 if not (is_recomputation and checkpoint): @@ -689,22 +698,30 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: + ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + act_out_to_save = act_out_hp if keep_backward_unquantized else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: - clear_tensor_data(ln_out) - ln_out = None + clear_tensor_data(ln_out_to_save) + ln_out_to_save = None if not fc2_weight.requires_grad: - clear_tensor_data(act_out) - act_out = None + clear_tensor_data(act_out_to_save) + act_out_to_save = None if not checkpoint: # regular path, no selective activation checkpointing if cpu_offloading: mark_activation_offload( - inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out + inputmat, + mu, + rsigma, + ln_out_to_save, + fc1_out, + fc1_out_without_bias, + act_out_to_save, ) # Scatter intermediate/activation tensors saved for the backward pass @@ -717,9 +734,9 @@ def _forward( fsdp_group, mu, rsigma, - ln_out, + ln_out_to_save, fc1_out_without_bias if bias_gelu_fusion else fc1_out, - act_out, + act_out_to_save, ( fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) @@ -747,13 +764,13 @@ def _forward( tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, - ln_out, + ln_out_to_save, fc1_weight_final, fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, - act_out, + act_out_to_save, fc2_weight_final, fc2_weight, fc2_bias, @@ -801,6 +818,7 @@ def _forward( ctx.activation_params = activation_params ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation 
ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -829,8 +847,12 @@ def _forward( ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.fp8 and requires_grad( - inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + if ( + ctx.fp8 + and not ctx.keep_backward_unquantized + and requires_grad( + inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + ) ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() @@ -999,6 +1021,16 @@ def backward( origin_fc1_weight.main_grad = fc1_weight_main_grad origin_fc2_weight.main_grad = fc2_weight_main_grad + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1018,7 +1050,7 @@ def backward( # Choose whether to use GEMM kernel with split accumulator dgrad_use_split_accumulator = _2X_ACC_DGRAD wgrad_use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator @@ -1032,7 +1064,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None: + if ctx.fc2_grad_output_quantizer is not None and use_quantized_bwd: quantizer = 
ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1045,7 +1077,7 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8) + ub_obj_fc2_dgrad = get_ub("fc2_dgrad", use_fp8_bwd) ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, @@ -1060,7 +1092,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1069,7 +1101,7 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) if ctx.ub_bulk_dgrad: - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc1_dgrad, ln_out, @@ -1106,7 +1138,7 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not ctx.fp8 + not use_fp8_bwd and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) and (not ctx.debug) @@ -1115,20 +1147,23 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.fc2_weight_quantizer is not None and isinstance( - ctx.fc2_weight, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.fc2_weight_quantizer is not None + and isinstance(ctx.fc2_weight, QuantizedTensorStorage) ): ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM + fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight 
gemm_output, *_ = general_gemm( - fc2_weight, + fc2_weight_for_dgrad, grad_output, layout="NN", grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if fc2_dgrad_gemm_gelu_fusion or ctx.debug + if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_quantized_bwd else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1160,7 +1195,11 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer) + ): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -1173,7 +1212,7 @@ def backward( ub_obj_fc2_dgrad.get_communication_stream() ) - ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8) + ub_obj_fc2_wgrad = get_ub("fc2_wgrad", use_fp8_bwd) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1196,14 +1235,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1212,7 +1251,7 @@ def backward( # Whether to set grad arg in general_gemm grad_arg = True - if ctx.fp8 and ctx.fp8_recipe.float8_block_scaling(): + if use_fp8_bwd and fp8_recipe_bwd.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1222,7 +1261,9 @@ def backward( if 
ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision + "quantization_params": ( + ctx.fc2_grad_weight_quantizer if use_quantized_bwd else None + ), # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc1_weight, "overwrite_main_grad", False) @@ -1259,8 +1300,8 @@ def fc2_wgrad_gemm( # Update grad bias if needed if fc2_bias_grad is None: if ( - ctx.fp8 - and ctx.fp8_recipe.float8_block_scaling() + use_fp8_bwd + and fp8_recipe_bwd.float8_block_scaling() and fc2_bias is not None ): # BGRAD not fused with GEMM for float8 blockwise gemm. @@ -1280,12 +1321,12 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None: + if ctx.fc1_grad_output_quantizer is not None and use_quantized_bwd: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" - assert not ctx.fp8 + assert not use_fp8_bwd fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.fc1_grad_output_quantizer is not None: dact = ctx.fc1_grad_output_quantizer(dact) @@ -1295,13 +1336,10 @@ def fc2_wgrad_gemm( fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) elif ( - _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None - and ctx.fp8 + _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd ): # Fusion: gemm, bias + gelu + quantize - dbias_dact_quantize_func = _act_func( - ctx.activation, ctx.fp8_recipe if ctx.fp8 else None - )[2] + dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] fc1_bias_grad, dact = dbias_dact_quantize_func( fc2_dgrad, fc1_out.to(ctx.activation_dtype), @@ -1311,18 +1349,16 @@ def fc2_wgrad_gemm( else: # Fusion: gemm + gelu, if not 
fc2_dgrad_gemm_gelu_fusion: - activation_func_bwd = _act_func( - ctx.activation, ctx.fp8_recipe if ctx.fp8 else None - )[1] + activation_func_bwd = _act_func(ctx.activation, fp8_recipe_bwd)[1] dact = activation_func_bwd( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision - if ctx.fp8: + if use_fp8_bwd: # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) - or ctx.fp8_recipe.custom() + or fp8_recipe_bwd.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) @@ -1350,16 +1386,16 @@ def fc2_wgrad_gemm( fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] if ctx.ub_overlap_rs_dgrad: # Overlap DGRAD+RS - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ub_type_fc1_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap ln_out all-gather with DGRAD compute - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) ub_type_fc1_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap FC1 DGRAD reduce-scatter with WGRAD compute - ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) + ub_obj_fc1_wgrad = get_ub("fc1_wgrad", use_fp8_bwd) ub_type_fc1_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -1367,8 +1403,10 @@ def fc2_wgrad_gemm( # -------------------------------------------------- # Make sure required data is available - if ctx.fc1_weight_quantizer is not None and isinstance( - ctx.fc1_weight_quantizer, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.fc1_weight_quantizer is not None + and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1383,12 +1421,13 @@ def fc2_wgrad_gemm( gemm_out = 
ub_obj_fc1_wgrad.get_buffer(local_chunk=False) # dgrad GEMM + fc1_weight_for_dgrad = fc1_weight if use_fp8_bwd else origin_fc1_weight gemm_out, *_, reduce_scatter_out = general_gemm( - fc1_weight, + fc1_weight_for_dgrad, dact, out=gemm_out, out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer, + quantization_params=ctx.fc1_grad_input_quantizer if use_quantized_bwd else None, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1437,7 +1476,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1447,7 +1486,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1469,7 +1508,9 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ctx.fc1_grad_weight_quantizer, + "quantization_params": ( + ctx.fc1_grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 23ad8cacb0..b4bad849c1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,6 +129,9 @@ def forward( save_original_input, debug, ) = non_tensor_args + keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + if keep_backward_unquantized: + save_original_input = True # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" @@ -443,6 
+446,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer @@ -479,7 +483,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and requires_grad(inp, weight, bias): + if ctx.fp8 and not ctx.keep_backward_unquantized and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): @@ -536,6 +540,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") + keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) + use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_quantized_bwd = use_fp8_bwd or ctx.debug + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -545,23 +558,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + 
ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -575,7 +588,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None: + if ctx.grad_output_quantizer is not None and use_quantized_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -594,6 +607,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None + and use_quantized_bwd ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -623,7 +637,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -649,7 +663,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: 
quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -690,20 +704,22 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance( - weight_fp8, QuantizedTensorStorage + if ( + use_quantized_bwd + and ctx.weight_quantizer is not None + and isinstance(weight_fp8, QuantizedTensorStorage) ): weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None: + if ctx.grad_input_quantizer is not None and use_quantized_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -720,12 +736,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") + weight_for_dgrad = weight_fp8 if use_quantized_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( - weight_fp8, + weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer, + quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -774,7 +791,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(inputmat_total, 
QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -784,7 +801,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): + if ( + use_fp8_bwd + and ctx.ub_overlap_ag + and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) + ): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -796,7 +817,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -816,7 +837,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if ctx.fp8 or ctx.debug: + if use_quantized_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -825,7 +846,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: + if use_fp8_bwd: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -851,7 +872,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": 
ctx.grad_weight_quantizer, + "quantization_params": ( + ctx.grad_weight_quantizer if use_quantized_bwd else None + ), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -859,7 +882,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 48376a297f..f2b8ba106e 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,12 +332,14 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
weight_requires_grad = requires_grad and self.weight.requires_grad + keep_backward_unquantized = FP8GlobalStateManager.keep_backward_unquantized() + columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) grad_output_quantizer = self.get_quantizer("backward", 0) - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) weight_quantizer.set_usage(rowwise=True, columnwise=False) - grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + grad_output_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -420,6 +422,7 @@ def _functional_forward( tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, + keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -459,6 +462,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = False Whether to perform compute with quantized data. + keep_backward_unquantized: bool, default = `False` + Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. 
weight_quantizer: Quantizer, optional @@ -510,7 +515,10 @@ def _functional_forward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) x, x_async = gather_along_first_dim( @@ -542,7 +550,10 @@ def _functional_forward( elif with_quantized_compute and not is_quantized_tensor(w): if weight_quantizer is None: raise ValueError("Missing quantizer for weight tensor") - weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + weight_quantizer.set_usage( + rowwise=True, + columnwise=input_requires_grad and not keep_backward_unquantized, + ) w = weight_quantizer(w) # Check output tensor @@ -611,14 +622,23 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if w is not weight and with_quantized_compute and is_quantized_tensor(w): + if ( + w is not weight + and with_quantized_compute + and is_quantized_tensor(w) + and not keep_backward_unquantized + ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if with_quantized_compute and is_quantized_tensor(x_local): + if ( + with_quantized_compute + and is_quantized_tensor(x_local) + and not keep_backward_unquantized + ): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -968,6 +988,9 @@ def op_forward( grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + 
with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -984,6 +1007,7 @@ def op_forward( tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -993,10 +1017,16 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = self.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - ctx.save_for_backward(x_local, w) - ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + ctx.save_for_backward(saved_input, saved_weight) + ctx.with_quantized_compute = with_quantized_compute and not keep_backward_unquantized ctx.input_quantizer = input_quantizer ctx.weight_quantizer = weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index fa3efc3807..7dd8f1a7ac 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -57,7 +57,11 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() quantize_forward = fp8_enabled and self._quantize_forward - quantize_backward = fp8_enabled and self._quantize_backward + quantize_backward = ( + fp8_enabled + and self._quantize_backward + and not FP8GlobalStateManager.keep_backward_unquantized() + ) # Quantize if needed out = input_ diff --git 
a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 4ab082d32b..59e9af14f4 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.quantization import Recipe +from transformer_engine.pytorch.quantization import Recipe, FP8GlobalStateManager from transformer_engine.pytorch.ops.basic import Bias from transformer_engine.pytorch.ops.basic.activation import ( _ActivationOperation, @@ -105,7 +105,10 @@ def fuse_backward_ops( """ # Check if recipe supports bias activation fusion - if recipe is None: + if recipe is None or ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ): return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index dfc11a19e7..0a28d00706 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,6 +92,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -109,6 +112,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, 
weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -118,10 +122,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 2dfc0566b7..41ae096e54 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,6 +86,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -106,6 +109,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, 
output_quantizer=output_quantizer, @@ -115,10 +119,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index ae4bdd4b19..b06f5ad36a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,6 +65,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) # Get extra input tensor for add operation extra_input = basic_op_extra_inputs[2][0] @@ -87,6 +90,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, 
@@ -96,10 +100,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 0d3e1d0416..3e5a389246 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -94,6 +94,7 @@ def _functional_forward( tensor_parallel_size: Optional[int] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, + keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -126,6 +127,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = False Whether to perform compute with quantized data. + keep_backward_unquantized: bool, default = `False` + Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. 
weight_quantizer: Quantizer, optional @@ -200,7 +203,10 @@ def _functional_forward( if with_ub_all_gather: if input_quantizer is not None: if not is_quantized_tensor(x_local): - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -216,7 +222,10 @@ def _functional_forward( else: if with_quantized_compute: if not is_quantized_tensor(x_local): - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.set_usage( + rowwise=True, + columnwise=weight_requires_grad and not keep_backward_unquantized, + ) x_local = input_quantizer(x_local) else: x_local = maybe_dequantize(x_local, dtype) @@ -227,7 +236,10 @@ def _functional_forward( if not with_quantized_compute: w = maybe_dequantize(w, dtype) elif with_quantized_compute and not is_quantized_tensor(w): - weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + weight_quantizer.set_usage( + rowwise=True, + columnwise=input_requires_grad and not keep_backward_unquantized, + ) w = weight_quantizer(w) # Construct output tensor if needed @@ -257,14 +269,23 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if w is not weight and with_quantized_compute and is_quantized_tensor(w): + if ( + w is not weight + and with_quantized_compute + and is_quantized_tensor(w) + and not keep_backward_unquantized + ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if with_quantized_compute and is_quantized_tensor(x_local): + if ( + with_quantized_compute + and is_quantized_tensor(x_local) + and not keep_backward_unquantized + ): if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather): # FP8 does not support all-gather 
of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -311,6 +332,9 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + keep_backward_unquantized = ( + with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + ) if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): @@ -340,6 +364,7 @@ def fuser_forward( tensor_parallel_size=self.tensor_parallel_size, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, + keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=None, # Not supported @@ -352,10 +377,18 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + saved_input = input_ if keep_backward_unquantized else x_local + if not weight_requires_grad: + saved_input = None + saved_weight = linear_op.weight if keep_backward_unquantized else w + if not input_requires_grad: + saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(x_local) - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute + mark_activation_offload(saved_input) + linear_op_ctx.save_for_backward(saved_input, saved_weight) + linear_op_ctx.with_quantized_compute = ( + with_quantized_compute and not keep_backward_unquantized + ) linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index bd3bc94b60..465091aecb 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ 
b/transformer_engine/pytorch/ops/fuser.py @@ -109,6 +109,10 @@ def forward( # Apply forward ops x = input_ extra_outputs = [None] * fuser._num_basic_ops + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ) for op, basic_op_idxs in fuser._forward_ops: # Set if backward op is required @@ -120,7 +124,7 @@ def forward( prev_op_idx = basic_op_idxs[0] - 1 prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None prev_op_grad_output_quantizer = None - if prev_op is not None: + if prev_op is not None and not keep_backward_unquantized: prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer() next_op_idx = basic_op_idxs[-1] + 1 next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None @@ -286,7 +290,15 @@ def backward( grad_extra_inputs_flat.extend(dxs) # Update FP8 scaling factors - if func_ctx.is_first_module and not _is_graph_capturing(): + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() + and FP8GlobalStateManager.keep_backward_unquantized() + ) + if ( + func_ctx.is_first_module + and not keep_backward_unquantized + and not _is_graph_capturing() + ): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index eba547afb0..9806871ef6 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -430,6 +430,11 @@ def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" return cls.HIGH_PRECISION_INIT_VAL + @classmethod + def keep_backward_unquantized(cls) -> bool: + """Should backward skip FP8 quantization and use high precision""" + return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) + @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" 
From f04ae5230b18fe5076d7a602bdfb692ad67666bb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 00:49:22 +0000 Subject: [PATCH 02/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +--- transformer_engine/pytorch/ops/fuser.py | 6 +----- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index e08561617e..3b5b3bca9f 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1335,9 +1335,7 @@ def fc2_wgrad_gemm( dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params) fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) - elif ( - _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd - ): + elif _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd: # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] fc1_bias_grad, dact = dbias_dact_quantize_func( diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 465091aecb..7b31811869 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -294,11 +294,7 @@ def backward( FP8GlobalStateManager.is_fp8_enabled() and FP8GlobalStateManager.keep_backward_unquantized() ) - if ( - func_ctx.is_first_module - and not keep_backward_unquantized - and not _is_graph_capturing() - ): + if func_ctx.is_first_module and not keep_backward_unquantized and not _is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( From 28fadbc9d75d2cfdef20e0ca0063e27d736d6ee3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: 
Tue, 3 Feb 2026 09:36:13 -0800 Subject: [PATCH 03/45] Disable ub and clean up Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 9 ++-- .../pytorch/module/layernorm_mlp.py | 13 ++--- transformer_engine/pytorch/module/linear.py | 17 +++---- .../ops/fused/userbuffers_forward_linear.py | 49 +++---------------- 4 files changed, 25 insertions(+), 63 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 28842fc315..66e67522f6 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -608,6 +608,7 @@ def backward( use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -622,23 +623,23 @@ def backward( dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad 
reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 3b5b3bca9f..ca65b57b4a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1026,6 +1026,7 @@ def backward( use_quantized_bwd = use_fp8_bwd or ctx.debug fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -1077,7 +1078,7 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - ub_obj_fc2_dgrad = get_ub("fc2_dgrad", use_fp8_bwd) + ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8) ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, @@ -1101,7 +1102,7 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) if ctx.ub_bulk_dgrad: - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc1_dgrad, ln_out, @@ -1195,11 +1196,7 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. 
Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -1212,7 +1209,7 @@ def backward( ub_obj_fc2_dgrad.get_communication_stream() ) - ub_obj_fc2_wgrad = get_ub("fc2_wgrad", use_fp8_bwd) + ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b4bad849c1..a03e9ac4d5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -544,6 +544,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -558,23 +559,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", use_fp8_bwd) + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if 
ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -801,11 +802,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -817,7 +814,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 3e5a389246..0d3e1d0416 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -94,7 +94,6 @@ def _functional_forward( tensor_parallel_size: Optional[int] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, - keep_backward_unquantized: bool = False, input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -127,8 +126,6 
@@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = False Whether to perform compute with quantized data. - keep_backward_unquantized: bool, default = `False` - Whether to skip quantized backward and use high precision. input_quantizer: Quantizer, optional Builder class for quantized input tensor. weight_quantizer: Quantizer, optional @@ -203,10 +200,7 @@ def _functional_forward( if with_ub_all_gather: if input_quantizer is not None: if not is_quantized_tensor(x_local): - input_quantizer.set_usage( - rowwise=True, - columnwise=weight_requires_grad and not keep_backward_unquantized, - ) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -222,10 +216,7 @@ def _functional_forward( else: if with_quantized_compute: if not is_quantized_tensor(x_local): - input_quantizer.set_usage( - rowwise=True, - columnwise=weight_requires_grad and not keep_backward_unquantized, - ) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) x_local = input_quantizer(x_local) else: x_local = maybe_dequantize(x_local, dtype) @@ -236,10 +227,7 @@ def _functional_forward( if not with_quantized_compute: w = maybe_dequantize(w, dtype) elif with_quantized_compute and not is_quantized_tensor(w): - weight_quantizer.set_usage( - rowwise=True, - columnwise=input_requires_grad and not keep_backward_unquantized, - ) + weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) w = weight_quantizer(w) # Construct output tensor if needed @@ -269,23 +257,14 @@ def _functional_forward( # Prepare weight tensor for backward pass if input_requires_grad: - if ( - w is not weight - and with_quantized_compute - and is_quantized_tensor(w) - and not keep_backward_unquantized - ): + if w is not weight and with_quantized_compute and is_quantized_tensor(w): w.update_usage(rowwise_usage=False, 
columnwise_usage=True) else: w = None # Prepare input tensor for backward pass if weight_requires_grad: - if ( - with_quantized_compute - and is_quantized_tensor(x_local) - and not keep_backward_unquantized - ): + if with_quantized_compute and is_quantized_tensor(x_local): if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) @@ -332,9 +311,6 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() - ) if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): @@ -364,7 +340,6 @@ def fuser_forward( tensor_parallel_size=self.tensor_parallel_size, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=None, # Not supported @@ -377,18 +352,10 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None if is_cpu_offload_enabled(): - mark_activation_offload(saved_input) - linear_op_ctx.save_for_backward(saved_input, saved_weight) - linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and not keep_backward_unquantized - ) + mark_activation_offload(x_local) + linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = 
with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer From 8d7cbbbdd94323aebe74087cd0af3df199146c59 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 09:37:57 -0800 Subject: [PATCH 04/45] Drop fuser changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/fuser.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 7b31811869..bd3bc94b60 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -109,10 +109,6 @@ def forward( # Apply forward ops x = input_ extra_outputs = [None] * fuser._num_basic_ops - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ) for op, basic_op_idxs in fuser._forward_ops: # Set if backward op is required @@ -124,7 +120,7 @@ def forward( prev_op_idx = basic_op_idxs[0] - 1 prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None prev_op_grad_output_quantizer = None - if prev_op is not None and not keep_backward_unquantized: + if prev_op is not None: prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer() next_op_idx = basic_op_idxs[-1] + 1 next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None @@ -290,11 +286,7 @@ def backward( grad_extra_inputs_flat.extend(dxs) # Update FP8 scaling factors - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ) - if func_ctx.is_first_module and not keep_backward_unquantized and not _is_graph_capturing(): + if func_ctx.is_first_module and not _is_graph_capturing(): FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) return ( From 7eda43339812c1e4cdf2cfc5d81a44be7ee109d9 Mon Sep 17 
00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 09:56:43 -0800 Subject: [PATCH 05/45] Replace use_quantized_bwd with use_fp8_bwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 19 +++++++------ .../pytorch/module/layernorm_mlp.py | 27 +++++++++---------- transformer_engine/pytorch/module/linear.py | 23 ++++++++-------- 3 files changed, 33 insertions(+), 36 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 66e67522f6..b759c152ec 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -606,7 +606,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -650,7 +649,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_quantized_bwd: + if ctx.grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -687,7 +686,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None and use_quantized_bwd: + if ctx.input_quantizer is not None and use_fp8_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -726,7 +725,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.weight_quantizer is not None 
and isinstance(weight, QuantizedTensorStorage) ): @@ -740,7 +739,7 @@ def backward( use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_quantized_bwd: + if ctx.grad_input_quantizer is not None and use_fp8_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -756,13 +755,13 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight if use_quantized_bwd else origin_weight + weight_for_dgrad = weight if use_fp8_bwd else origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -851,14 +850,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -894,7 +893,7 @@ def backward( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ( - ctx.grad_weight_quantizer if use_quantized_bwd else None + ctx.grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ca65b57b4a..0e4d06e86b 100644 
--- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1023,7 +1023,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True @@ -1065,7 +1064,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None and use_quantized_bwd: + if ctx.fc2_grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1093,7 +1092,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_quantized_bwd: + if use_fp8_bwd: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1149,7 +1148,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensorStorage) ): @@ -1164,7 +1163,7 @@ def backward( grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_quantized_bwd + if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_fp8_bwd else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1232,14 +1231,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make 
sure required data is available - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1259,7 +1258,7 @@ def backward( else ctx.activation_dtype ), "quantization_params": ( - ctx.fc2_grad_weight_quantizer if use_quantized_bwd else None + ctx.fc2_grad_weight_quantizer if use_fp8_bwd else None ), # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad @@ -1318,7 +1317,7 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None and use_quantized_bwd: + if ctx.fc1_grad_output_quantizer is not None and use_fp8_bwd: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu @@ -1399,7 +1398,7 @@ def fc2_wgrad_gemm( # Make sure required data is available if ( - use_quantized_bwd + use_fp8_bwd and ctx.fc1_weight_quantizer is not None and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) ): @@ -1422,7 +1421,7 @@ def fc2_wgrad_gemm( dact, out=gemm_out, out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.fc1_grad_input_quantizer if use_fp8_bwd else None, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1471,7 +1470,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1481,7 +1480,7 @@ def fc2_wgrad_gemm( # 
Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1504,7 +1503,7 @@ def fc2_wgrad_gemm( else ctx.activation_dtype ), "quantization_params": ( - ctx.fc1_grad_weight_quantizer if use_quantized_bwd else None + ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a03e9ac4d5..6ecc647626 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -542,7 +542,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - use_quantized_bwd = use_fp8_bwd or ctx.debug if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -589,7 +588,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_quantized_bwd: + if ctx.grad_output_quantizer is not None and use_fp8_bwd: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -608,7 +607,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None - and use_quantized_bwd + and use_fp8_bwd ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -638,7 +637,7 @@ def backward(ctx, grad_output: 
torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -664,7 +663,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_quantized_bwd: + if use_fp8_bwd: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -706,7 +705,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_quantized_bwd + use_fp8_bwd and ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorStorage) ): @@ -720,7 +719,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_quantized_bwd: + if ctx.grad_input_quantizer is not None and use_fp8_bwd: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -737,13 +736,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight_fp8 if use_quantized_bwd else weight + weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_quantized_bwd else None, + quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, out=gemm_out, 
out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -792,7 +791,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -834,7 +833,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_quantized_bwd: + if use_fp8_bwd: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -870,7 +869,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ( - ctx.grad_weight_quantizer if use_quantized_bwd else None + ctx.grad_weight_quantizer if use_fp8_bwd else None ), "accumulate": ( accumulate_wgrad_into_param_main_grad From dbc60c52eab99694299084793d2456ed4b3972b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:57:32 +0000 Subject: [PATCH 06/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_linear.py | 4 +--- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +--- transformer_engine/pytorch/module/linear.py | 4 +--- 3 files changed, 3 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index b759c152ec..bdfeff056b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -892,9 +892,7 @@ def backward( "out_dtype": ( main_grad.dtype if 
ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 0e4d06e86b..ef539d60ad 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1502,9 +1502,7 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6ecc647626..1ce4fac445 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -868,9 +868,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.grad_weight_quantizer if use_fp8_bwd else None - ), + "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) From 40d42d4f01f5ec73811d738828e874f0d19ed1e5 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 10:30:04 -0800 Subject: [PATCH 07/45] Ignore keep_backward_unquantized if delayed scaling Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 1 + 
transformer_engine/pytorch/module/linear.py | 1 + transformer_engine/pytorch/quantization.py | 3 +++ 3 files changed, 5 insertions(+) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index d214ce6a54..99a80754f7 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -100,6 +100,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: + # Note, keep_backward_unquantized is ignored when delayed scaling is used save_original_input = True num_gemms = len(m_splits) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1ce4fac445..49b78382d2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -131,6 +131,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: + # Note, keep_backward_unquantized is ignored when delayed scaling is used save_original_input = True # NVTX label for profiling diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 9806871ef6..e8f6dafdb5 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -433,6 +433,9 @@ def with_high_precision_init_val(cls) -> bool: @classmethod def keep_backward_unquantized(cls) -> bool: """Should backward skip FP8 quantization and use high precision""" + recipe = cls.get_fp8_recipe() + if recipe.delayed(): + return False return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) @classmethod From f87e17b9be461a1d9231a9a32f6157b1c4635bfb Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 10:39:02 -0800 Subject: [PATCH 08/45] Refactor ignoring 
NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- transformer_engine/pytorch/quantization.py | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 99a80754f7..d7df41d3c3 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -100,7 +100,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: - # Note, keep_backward_unquantized is ignored when delayed scaling is used + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True num_gemms = len(m_splits) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 49b78382d2..0bf560c7b7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -131,7 +131,7 @@ def forward( ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() if keep_backward_unquantized: - # Note, keep_backward_unquantized is ignored when delayed scaling is used + # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True # NVTX label for profiling diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index e8f6dafdb5..aab7ed2d1c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -434,7 +434,8 @@ def with_high_precision_init_val(cls) -> bool: def keep_backward_unquantized(cls) -> bool: """Should backward skip FP8 quantization and use high precision""" recipe = 
cls.get_fp8_recipe() - if recipe.delayed(): + if recipe is not None and recipe.delayed(): + # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used return False return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) From 455a905b52572eb88ac590a0748a3bc026a95540 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:13:57 -0800 Subject: [PATCH 09/45] Add back missing ctx.debug Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 4 ++-- transformer_engine/pytorch/module/layernorm_mlp.py | 10 +++++----- transformer_engine/pytorch/module/linear.py | 8 ++++---- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index bdfeff056b..fd458a34b4 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -850,14 +850,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ef539d60ad..93a3606de8 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1092,7 +1092,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: quantizer = 
ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes manually @@ -1231,14 +1231,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1470,7 +1470,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1480,7 +1480,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0bf560c7b7..930fbe061d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -638,7 +638,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -664,7 +664,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, 
ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -792,7 +792,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -834,7 +834,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_fp8_bwd: + if use_fp8_bwd or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: From 41415ff79873e3129d6f53ca7bd2cef593f37979 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:43:45 -0800 Subject: [PATCH 10/45] Refactor changes under fused Signed-off-by: Ziang Li --- .../ops/fused/backward_activation_bias.py | 7 ++----- .../ops/fused/forward_linear_bias_activation.py | 17 +++++++++++------ .../ops/fused/forward_linear_bias_add.py | 17 +++++++++++------ .../ops/fused/forward_linear_scale_add.py | 17 +++++++++++------ 4 files changed, 35 insertions(+), 23 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 59e9af14f4..4ab082d32b 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.quantization import Recipe, FP8GlobalStateManager +from transformer_engine.pytorch.quantization import Recipe from 
transformer_engine.pytorch.ops.basic import Bias from transformer_engine.pytorch.ops.basic.activation import ( _ActivationOperation, @@ -105,10 +105,7 @@ def fuse_backward_ops( """ # Check if recipe supports bias activation fusion - if recipe is None or ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.keep_backward_unquantized() - ): + if recipe is None: return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 0a28d00706..6e7c85988f 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -122,12 +122,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = linear_op.weight if keep_backward_unquantized else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 41ae096e54..f3b4533848 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ 
b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -119,12 +119,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = linear_op.weight if keep_backward_unquantized else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index b06f5ad36a..53e7327873 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -100,12 +100,17 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None - saved_weight = linear_op.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None + saved_input = x_local + saved_weight = w + if keep_backward_unquantized: + saved_input = input_ if input_requires_grad else None + saved_weight = linear_op.weight if weight_requires_grad else None + # saved_input = input_ if keep_backward_unquantized else x_local + # if not weight_requires_grad: + # saved_input = None + # saved_weight = 
linear_op.weight if keep_backward_unquantized else w + # if not input_requires_grad: + # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From 65d44ff3893834d8b880b89eb198472ef688f1c8 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 11:44:30 -0800 Subject: [PATCH 11/45] Clean up Signed-off-by: Ziang Li --- .../pytorch/ops/fused/forward_linear_bias_activation.py | 6 ------ .../pytorch/ops/fused/forward_linear_bias_add.py | 6 ------ .../pytorch/ops/fused/forward_linear_scale_add.py | 6 ------ 3 files changed, 18 deletions(-) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 6e7c85988f..2458d4d072 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -127,12 +127,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index f3b4533848..efa543e555 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -124,12 +124,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None 
saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 53e7327873..2804534968 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -105,12 +105,6 @@ def fuser_forward( if keep_backward_unquantized: saved_input = input_ if input_requires_grad else None saved_weight = linear_op.weight if weight_requires_grad else None - # saved_input = input_ if keep_backward_unquantized else x_local - # if not weight_requires_grad: - # saved_input = None - # saved_weight = linear_op.weight if keep_backward_unquantized else w - # if not input_requires_grad: - # saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From e3a651cd1c550d22037e8bd0afe441b911d135d3 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 14:11:07 -0800 Subject: [PATCH 12/45] Refactor high-precision overwrite if keep_backward_unquantized Signed-off-by: Ziang Li --- .../pytorch/module/grouped_linear.py | 17 ++++++++++------- .../pytorch/module/layernorm_linear.py | 10 ++++++++-- .../pytorch/module/layernorm_mlp.py | 14 +++++++++++--- transformer_engine/pytorch/module/linear.py | 5 ++++- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index d7df41d3c3..c5908ea24c 100644 
--- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -414,13 +414,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - weights_for_dgrad = weights if use_fp8_bwd else origin_weights - if use_fp8_bwd: - # Make sure weights are available in column-wise format - # for dgrad computation. - for weight in weights_for_dgrad: - if isinstance(weight, QuantizedTensorStorage): - weight.update_usage(columnwise_usage=True) + # weights_for_dgrad = weights if use_fp8_bwd else origin_weights + # if use_fp8_bwd: + weights_for_dgrad = weights + if keep_backward_unquantized: + weights_for_dgrad = origin_weights + # Make sure weights are available in column-wise format + # for dgrad computation. + for weight in weights_for_dgrad: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) general_grouped_gemm( weights_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index fd458a34b4..70d8936ce3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -415,7 +415,10 @@ def forward( # ------------------------------------------------------ if is_grad_enabled: - ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + ln_out_to_save = ln_out + if keep_backward_unquantized: + ln_out_to_save = ln_out_hp + # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel @@ -755,7 +758,10 @@ def backward( # dgrad GEMM # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - weight_for_dgrad = weight if use_fp8_bwd else origin_weight + # weight_for_dgrad = weight if use_fp8_bwd 
else origin_weight + weight_for_dgrad = weight + if keep_backward_unquantized: + weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 93a3606de8..5dc648d7f2 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -698,8 +698,13 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: - ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out - act_out_to_save = act_out_hp if keep_backward_unquantized else act_out + ln_out_to_save = ln_out + act_out_to_save = act_out + if keep_backward_unquantized: + ln_out_to_save = ln_out_hp + act_out_to_save = act_out_hp + # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out + # act_out_to_save = act_out_hp if keep_backward_unquantized else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer @@ -1155,7 +1160,10 @@ def backward( ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM - fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight + fc2_weight_for_dgrad = fc2_weight + if keep_backward_unquantized: + fc2_weight_for_dgrad = origin_fc2_weight + # fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight gemm_output, *_ = general_gemm( fc2_weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 930fbe061d..496bfd45b7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -737,7 +737,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - 
weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight + weight_for_dgrad = weight_fp8 + if keep_backward_unquantized: + weight_for_dgrad = weight + # weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, From fba242b89303cfde58cdae81ec84b4f4f70c4972 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 3 Feb 2026 14:14:22 -0800 Subject: [PATCH 13/45] Clean up Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 2 -- transformer_engine/pytorch/module/layernorm_linear.py | 2 -- transformer_engine/pytorch/module/layernorm_mlp.py | 3 --- transformer_engine/pytorch/module/linear.py | 1 - 4 files changed, 8 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c5908ea24c..f14512bb06 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -414,8 +414,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - # weights_for_dgrad = weights if use_fp8_bwd else origin_weights - # if use_fp8_bwd: weights_for_dgrad = weights if keep_backward_unquantized: weights_for_dgrad = origin_weights diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 70d8936ce3..e3aab9b304 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -418,7 +418,6 @@ def forward( ln_out_to_save = ln_out if keep_backward_unquantized: ln_out_to_save = ln_out_hp - # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel @@ -758,7 +757,6 @@ def backward( # dgrad GEMM # Note: dx 
= dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") - # weight_for_dgrad = weight if use_fp8_bwd else origin_weight weight_for_dgrad = weight if keep_backward_unquantized: weight_for_dgrad = origin_weight diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 5dc648d7f2..d7732cc396 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -703,8 +703,6 @@ def _forward( if keep_backward_unquantized: ln_out_to_save = ln_out_hp act_out_to_save = act_out_hp - # ln_out_to_save = ln_out_hp if keep_backward_unquantized else ln_out - # act_out_to_save = act_out_hp if keep_backward_unquantized else act_out ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer @@ -1163,7 +1161,6 @@ def backward( fc2_weight_for_dgrad = fc2_weight if keep_backward_unquantized: fc2_weight_for_dgrad = origin_fc2_weight - # fc2_weight_for_dgrad = fc2_weight if use_fp8_bwd else origin_fc2_weight gemm_output, *_ = general_gemm( fc2_weight_for_dgrad, grad_output, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 496bfd45b7..10ea095c16 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -740,7 +740,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], weight_for_dgrad = weight_fp8 if keep_backward_unquantized: weight_for_dgrad = weight - # weight_for_dgrad = weight_fp8 if use_fp8_bwd else weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, From d25fc4774e435874c81ec39e631beca08f78f48c Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 10:56:41 -0800 Subject: [PATCH 14/45] Drop redundant fp8_recipe_bwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_mlp.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 
deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d7732cc396..c3962ab529 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1026,7 +1026,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - fp8_recipe_bwd = ctx.fp8_recipe if use_fp8_bwd else None if keep_backward_unquantized: # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True ctx.ub_overlap_ag = False @@ -1252,7 +1251,7 @@ def backward( # Whether to set grad arg in general_gemm grad_arg = True - if use_fp8_bwd and fp8_recipe_bwd.float8_block_scaling(): + if use_fp8_bwd and ctx.fp8_recipe.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1302,7 +1301,7 @@ def fc2_wgrad_gemm( if fc2_bias_grad is None: if ( use_fp8_bwd - and fp8_recipe_bwd.float8_block_scaling() + and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None ): # BGRAD not fused with GEMM for float8 blockwise gemm. 
@@ -1336,9 +1335,14 @@ def fc2_wgrad_gemm( dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params) fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) - elif _act_func(ctx.activation, fp8_recipe_bwd)[2] is not None and use_fp8_bwd: + elif ( + _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None + and use_fp8_bwd + ): # Fusion: gemm, bias + gelu + quantize - dbias_dact_quantize_func = _act_func(ctx.activation, fp8_recipe_bwd)[2] + dbias_dact_quantize_func = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( fc2_dgrad, fc1_out.to(ctx.activation_dtype), @@ -1348,7 +1352,9 @@ def fc2_wgrad_gemm( else: # Fusion: gemm + gelu, if not fc2_dgrad_gemm_gelu_fusion: - activation_func_bwd = _act_func(ctx.activation, fp8_recipe_bwd)[1] + activation_func_bwd = _act_func( + ctx.activation, ctx.fp8_recipe if ctx.fp8 else None + )[1] dact = activation_func_bwd( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision @@ -1357,7 +1363,7 @@ def fc2_wgrad_gemm( # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) - or fp8_recipe_bwd.custom() + or ctx.fp8_recipe.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) From 4df62fa9c8ee8df1b8131f824a9c9d03a34ffffb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 18:57:29 +0000 Subject: [PATCH 15/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py 
b/transformer_engine/pytorch/module/layernorm_mlp.py index c3962ab529..e8063e9d76 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1338,7 +1338,7 @@ def fc2_wgrad_gemm( elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None and use_fp8_bwd - ): + ): # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None From 06e70604f64acd1d9938bd810a3ce2e691ce26e0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:02:24 -0800 Subject: [PATCH 16/45] Drop redundant ub changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_mlp.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index e8063e9d76..93daf99917 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1391,16 +1391,16 @@ def fc2_wgrad_gemm( fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] if ctx.ub_overlap_rs_dgrad: # Overlap DGRAD+RS - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap ln_out all-gather with DGRAD compute - ub_obj_fc1_dgrad = get_ub("fc1_dgrad", use_fp8_bwd) + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap FC1 DGRAD reduce-scatter with WGRAD compute - ub_obj_fc1_wgrad = get_ub("fc1_wgrad", use_fp8_bwd) + ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) ub_type_fc1_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- From e26d3188df3604687d8020af17b12533d6cccca5 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:07:16 
-0800 Subject: [PATCH 17/45] Drop more redundant ub changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index e3aab9b304..60c4e1d8b2 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -812,11 +812,7 @@ def backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ( - use_fp8_bwd - and ctx.ub_overlap_ag - and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer) - ): + if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -828,7 +824,7 @@ def backward( dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", use_fp8_bwd) + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) From 986f173c9538efdeb14c5e45e8a8e1c9afc02ada Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 11:25:01 -0800 Subject: [PATCH 18/45] Drop redundant delayed scaling changes Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 6 +----- transformer_engine/pytorch/module/layernorm_mlp.py | 6 +----- transformer_engine/pytorch/module/linear.py | 2 +- 3 files changed, 3 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f14512bb06..a469753bbf 100644 --- 
a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -304,11 +304,7 @@ def forward( ctx.inp_shape = inp.shape ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad(inp, weights[0], biases[0]) - ): + if ctx.fp8 and requires_grad(inp, weights[0], biases[0]): ctx.reduce_and_update_bwd_fp8_tensors = ( ctx.reduce_and_update_bwd_fp8_tensors or FP8GlobalStateManager.is_first_fp8_module() diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 93daf99917..d899879226 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -850,12 +850,8 @@ def _forward( ) ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad( + if ctx.fp8 and requires_grad( inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias - ) ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 10ea095c16..535d2e75e5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -484,7 +484,7 @@ def forward( ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and not ctx.keep_backward_unquantized and requires_grad(inp, weight, bias): + if ctx.fp8 and requires_grad(inp, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): From 
5019d3bb42707cbd65d429ff7544d2cde9b172b7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 19:25:49 +0000 Subject: [PATCH 19/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d899879226..95d6b3e837 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -851,7 +851,7 @@ def _forward( ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False if ctx.fp8 and requires_grad( - inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias + inp, ln_weight, ln_bias, fc1_weight, fc2_weight, fc1_bias, fc2_bias ): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() From 70ef66ca22b1eed05e3a2ca8377d5006d9f78301 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 12:01:36 -0800 Subject: [PATCH 20/45] Drop unneeded backwards_needs_fc1_input Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_mlp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 95d6b3e837..9bb464ab98 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -354,10 +354,8 @@ def _forward( # bwd needs fc1 input when grad is enabled, fc1 needs grad, and either # 1) no checkpointing # or 2) doing the recomputation with checkpointing - backwards_needs_fc1_input = ( - fc1_weight.requires_grad - and ((is_grad_enabled and not checkpoint) or 
is_recomputation) - and not keep_backward_unquantized + backwards_needs_fc1_input = fc1_weight.requires_grad and ( + (is_grad_enabled and not checkpoint) or is_recomputation ) device = inp.device From 88c58fa4b625dcfe5a84cd8b5dc3ba8984b089ed Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 4 Feb 2026 14:01:43 -0800 Subject: [PATCH 21/45] Drop and disallow LayerNormMLP implementation Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_mlp.py | 104 ++++++------------ 1 file changed, 34 insertions(+), 70 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9bb464ab98..38e56530da 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -236,6 +236,7 @@ def _forward( recompute_for_bwd, ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + assert not keep_backward_unquantized, "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -398,7 +399,6 @@ def _forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not keep_backward_unquantized and not custom ) @@ -420,7 +420,6 @@ def _forward( # do not return layernorm output unless 1) no checkpointing or 2) checkpointing but not recomputing if (return_layernorm_output or return_layernorm_output_gathered) and not is_recomputation: ln_out_return = ln_out - ln_out_hp = ln_out if keep_backward_unquantized else None # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication @@ -617,10 +616,6 @@ def _forward( if fc2_input_quantizer is not None: fc2_input_quantizer.calibrate(act_out) - act_out_hp = act_out - if keep_backward_unquantized and is_grad_enabled and fc1_out is not None: - 
act_out_hp = activation_func(fc1_out, None, **act_params) - # we want to skip fc2 computation if we are checkpointing and recomputing, # otherwise we compute fc2 if not (is_recomputation and checkpoint): @@ -696,33 +691,22 @@ def _forward( # if we are not checkpointing, then we must save this if grad is enabled if is_grad_enabled and not save_for_checkpoint: - ln_out_to_save = ln_out - act_out_to_save = act_out - if keep_backward_unquantized: - ln_out_to_save = ln_out_hp - act_out_to_save = act_out_hp ctx.fc1_weight_quantizer = fc1_weight_quantizer ctx.fc2_weight_quantizer = fc2_weight_quantizer if not fc1_weight.requires_grad: if not return_layernorm_output: - clear_tensor_data(ln_out_to_save) - ln_out_to_save = None + clear_tensor_data(ln_out) + ln_out = None if not fc2_weight.requires_grad: - clear_tensor_data(act_out_to_save) - act_out_to_save = None + clear_tensor_data(act_out) + act_out = None if not checkpoint: # regular path, no selective activation checkpointing if cpu_offloading: mark_activation_offload( - inputmat, - mu, - rsigma, - ln_out_to_save, - fc1_out, - fc1_out_without_bias, - act_out_to_save, + inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out ) # Scatter intermediate/activation tensors saved for the backward pass @@ -735,9 +719,9 @@ def _forward( fsdp_group, mu, rsigma, - ln_out_to_save, + ln_out, fc1_out_without_bias if bias_gelu_fusion else fc1_out, - act_out_to_save, + act_out, ( fc1_weight_final if fp8 and not isinstance(fc1_weight, Float8Tensor) @@ -765,13 +749,13 @@ def _forward( tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, - ln_out_to_save, + ln_out, fc1_weight_final, fc1_weight, fc1_bias, fc1_out, fc1_out_without_bias, - act_out_to_save, + act_out, fc2_weight_final, fc2_weight, fc2_bias, @@ -819,7 +803,6 @@ def _forward( ctx.activation_params = activation_params ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = 
keep_backward_unquantized ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -1018,15 +1001,6 @@ def backward( origin_fc1_weight.main_grad = fc1_weight_main_grad origin_fc2_weight.main_grad = fc2_weight_main_grad - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - # TODO: Fix this # pylint: disable=fixme # Gather saved autograd context tensors when running with FSDP # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -1046,7 +1020,7 @@ def backward( # Choose whether to use GEMM kernel with split accumulator dgrad_use_split_accumulator = _2X_ACC_DGRAD wgrad_use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator @@ -1060,7 +1034,7 @@ def backward( # Configure quantizer for FC2 grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.fc2_grad_output_quantizer is not None and use_fp8_bwd: + if ctx.fc2_grad_output_quantizer is not None: quantizer = ctx.fc2_grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -1088,7 +1062,7 @@ def backward( ub_obj_fc1_dgrad = None if ctx.fc1_weight_requires_grad and ctx.tensor_parallel and ctx.sequence_parallel: quantizer = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: quantizer = ctx.fc1_input_quantizer if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # If data is in FP8, we compute FP8 transposes 
manually @@ -1134,7 +1108,7 @@ def backward( # 5 high-precision unfused: gemm, activation, FC1_bias + FC1_gemm # 6 fp8 unfused: gemm, activation, FC1_bias + FC1_gemm fc2_dgrad_gemm_gelu_fusion = ( - not use_fp8_bwd + not ctx.fp8 and (ctx.activation == "gelu") and (not ctx.bias_gelu_fusion) and (not ctx.debug) @@ -1143,25 +1117,20 @@ def backward( # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ( - use_fp8_bwd - and ctx.fc2_weight_quantizer is not None - and isinstance(ctx.fc2_weight, QuantizedTensorStorage) + if ctx.fc2_weight_quantizer is not None and isinstance( + ctx.fc2_weight, QuantizedTensorStorage ): ctx.fc2_weight.update_usage(columnwise_usage=True) # Perform GEMM - fc2_weight_for_dgrad = fc2_weight - if keep_backward_unquantized: - fc2_weight_for_dgrad = origin_fc2_weight gemm_output, *_ = general_gemm( - fc2_weight_for_dgrad, + fc2_weight, grad_output, layout="NN", grad=True, quantization_params=( ctx.fc1_grad_input_quantizer - if (fc2_dgrad_gemm_gelu_fusion or ctx.debug) and use_fp8_bwd + if fc2_dgrad_gemm_gelu_fusion or ctx.debug else None ), # high precision to activation out_dtype=ctx.activation_dtype, @@ -1229,14 +1198,14 @@ def backward( # Prepare input tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -1245,7 +1214,7 @@ def backward( # Whether to set grad arg in general_gemm grad_arg = True - if use_fp8_bwd and ctx.fp8_recipe.float8_block_scaling(): + if ctx.fp8 and 
ctx.fp8_recipe.float8_block_scaling(): grad_arg = False # Arguments to include in wgrad GEMM closure @@ -1255,9 +1224,7 @@ def backward( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": ( - ctx.fc2_grad_weight_quantizer if use_fp8_bwd else None - ), # wgrad in high precision + "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc1_weight, "overwrite_main_grad", False) @@ -1294,7 +1261,7 @@ def fc2_wgrad_gemm( # Update grad bias if needed if fc2_bias_grad is None: if ( - use_fp8_bwd + ctx.fp8 and ctx.fp8_recipe.float8_block_scaling() and fc2_bias is not None ): @@ -1315,12 +1282,12 @@ def fc2_wgrad_gemm( act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False - if ctx.fc1_grad_output_quantizer is not None and use_fp8_bwd: + if ctx.fc1_grad_output_quantizer is not None: ctx.fc1_grad_output_quantizer.set_usage(rowwise=True, columnwise=True) if ctx.bias_gelu_fusion: # Fusion: gemm, bias + gelu assert ctx.activation == "gelu" - assert not use_fp8_bwd + assert not ctx.fp8 fc1_bias_grad, dact = bgrad_dgelu_fused(fc2_dgrad, fc1_out_without_bias, fc1_bias) if ctx.fc1_grad_output_quantizer is not None: dact = ctx.fc1_grad_output_quantizer(dact) @@ -1331,7 +1298,7 @@ def fc2_wgrad_gemm( dact = ctx.fc1_grad_output_quantizer(dact) elif ( _act_func(ctx.activation, ctx.fp8_recipe if ctx.fp8 else None)[2] is not None - and use_fp8_bwd + and ctx.fp8 ): # Fusion: gemm, bias + gelu + quantize dbias_dact_quantize_func = _act_func( @@ -1353,7 +1320,7 @@ def fc2_wgrad_gemm( fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision - if use_fp8_bwd: + if ctx.fp8: # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now if ( isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) @@ -1402,10 +1369,8 @@ def fc2_wgrad_gemm( # 
-------------------------------------------------- # Make sure required data is available - if ( - use_fp8_bwd - and ctx.fc1_weight_quantizer is not None - and isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) + if ctx.fc1_weight_quantizer is not None and isinstance( + ctx.fc1_weight_quantizer, QuantizedTensorStorage ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1420,13 +1385,12 @@ def fc2_wgrad_gemm( gemm_out = ub_obj_fc1_wgrad.get_buffer(local_chunk=False) # dgrad GEMM - fc1_weight_for_dgrad = fc1_weight if use_fp8_bwd else origin_fc1_weight gemm_out, *_, reduce_scatter_out = general_gemm( - fc1_weight_for_dgrad, + fc1_weight, dact, out=gemm_out, out_dtype=ctx.activation_dtype, - quantization_params=ctx.fc1_grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.fc1_grad_input_quantizer, layout="NN", grad=True, use_split_accumulator=dgrad_use_split_accumulator, @@ -1475,7 +1439,7 @@ def fc2_wgrad_gemm( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: @@ -1485,7 +1449,7 @@ def fc2_wgrad_gemm( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: @@ -1507,7 +1471,7 @@ def fc2_wgrad_gemm( if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.fc1_grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": ctx.fc1_grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(fc2_weight, "overwrite_main_grad", False) From a097f3e4272b10e5cefb4ee03a257c5160d1b20e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 4 Feb 2026 22:02:31 +0000 Subject: [PATCH 22/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 38e56530da..55578a3b23 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -236,7 +236,9 @@ def _forward( recompute_for_bwd, ) = non_tensor_args keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() - assert not keep_backward_unquantized, "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" + assert ( + not keep_backward_unquantized + ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: From 84143580b8546cf0394a662538e4a898107fa503 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:10:10 -0800 Subject: [PATCH 23/45] Move interface changes to recipe Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 67 +++++++++++++++++-- .../pytorch/module/grouped_linear.py | 2 +- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- .../pytorch/ops/basic/basic_linear.py | 6 +- .../pytorch/ops/basic/quantize.py | 2 +- .../fused/forward_linear_bias_activation.py | 2 +- .../ops/fused/forward_linear_bias_add.py | 2 +- .../ops/fused/forward_linear_scale_add.py | 2 +- transformer_engine/pytorch/quantization.py | 39 +++++++---- 11 files changed, 99 insertions(+), 29 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py 
b/transformer_engine/common/recipe/__init__.py index 18577b0eb4..341f23972c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,6 +11,11 @@ from pydantic.dataclasses import dataclass +def _default_quantize_backward() -> bool: + """Default backward quantization setting.""" + return not bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) + + class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -188,6 +193,11 @@ def scaling_factor_compute(amax: Tensor, `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. When `fp8_mha = True, fp8_dpa = True`, it becomes `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. Delayed scaling + always quantizes backward; setting this to False is not supported. Notes ----- @@ -211,6 +221,8 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -223,7 +235,9 @@ def __repr__(self) -> str: f"amax_history_len={self.amax_history_len}, " f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -237,6 +251,10 @@ class Float8CurrentScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. 
+ quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" @@ -249,6 +267,10 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -264,7 +286,9 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -291,12 +315,18 @@ class MXFP8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
@@ -305,7 +335,9 @@ def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " - f"format={str(self.fp8_format).split('.')[1]}" + f"format={str(self.fp8_format).split('.')[1]}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -334,6 +366,10 @@ class Float8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. """ use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" @@ -386,7 +422,9 @@ def __repr__(self) -> str: f"fp8_gemm_dgrad={self.fp8_gemm_dgrad}, " f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " - f"fp8_mha={self.fp8_mha}" + f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" ) @@ -435,6 +473,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. 
""" # Configuration envvars @@ -450,6 +492,8 @@ class NVFP4BlockScaling(Recipe): # Not applying quantization to attention for now fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" @@ -481,6 +525,8 @@ def __repr__(self) -> str: f"fp8_format={str(self.fp8_format).split('.')[1]}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " @@ -512,12 +558,23 @@ class CustomRecipe(Recipe): - forward: "linear_input", "linear_weight", "linear_output" - backward: "linear_grad_output", "linear_grad_input" + quantize_forward : bool, default = True + Whether to quantize tensors in the forward pass. + quantize_backward : bool, default = True + Whether to quantize tensors in the backward pass. 
""" qfactory: Callable[..., Any] fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = field(default_factory=_default_quantize_backward) def __repr__(self) -> str: - return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}" + return ( + f"recipe_type={self.__class__.__name__}, " + f"qfactory={self.qfactory}, " + f"quantize_forward={self.quantize_forward}, " + f"quantize_backward={self.quantize_backward}" + ) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a469753bbf..6afdb5673a 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -98,7 +98,7 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 60c4e1d8b2..4173c76216 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,7 +141,7 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 55578a3b23..28b80d6a60 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ 
b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -235,7 +235,7 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) assert ( not keep_backward_unquantized ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 535d2e75e5..76ff5dd1d4 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,7 +129,7 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) if keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index f2b8ba106e..d73fceeaf0 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,7 +332,9 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
weight_requires_grad = requires_grad and self.weight.requires_grad - keep_backward_unquantized = FP8GlobalStateManager.keep_backward_unquantized() + keep_backward_unquantized = ( + FP8GlobalStateManager.is_fp8_enabled() and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + ) columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) @@ -989,7 +991,7 @@ def op_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 7dd8f1a7ac..4c67cd8cce 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -60,7 +60,7 @@ def op_forward( quantize_backward = ( fp8_enabled and self._quantize_backward - and not FP8GlobalStateManager.keep_backward_unquantized() + and FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Quantize if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 2458d4d072..80cb5647d7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -93,7 +93,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not 
FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index efa543e555..cf29140a20 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -87,7 +87,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 2804534968..0caae13af9 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -66,7 +66,7 @@ def fuser_forward( grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() keep_backward_unquantized = ( - with_quantized_compute and FP8GlobalStateManager.keep_backward_unquantized() + with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) ) # Get extra input tensor for add operation diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index aab7ed2d1c..fb0553056a 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -87,6 +87,21 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ) +def _validate_recipe_quantization_flags(recipe: Recipe) -> None: + """Validate forward/backward quantization flags on a recipe.""" + quantize_forward = 
getattr(recipe, "quantize_forward", True) + quantize_backward = getattr(recipe, "quantize_backward", True) + if not quantize_forward and quantize_backward: + raise ValueError( + "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." + ) + if recipe.delayed() and not quantize_backward: + raise ValueError( + "Invalid recipe configuration: delayed scaling does not support " + "quantize_backward=False." + ) + + def check_recipe_support(recipe: Recipe) -> None: """Check if the given recipe is supported.""" recipe_supported = True @@ -430,15 +445,6 @@ def with_high_precision_init_val(cls) -> bool: """Should the high precision initial values be stored with FP8 parameters""" return cls.HIGH_PRECISION_INIT_VAL - @classmethod - def keep_backward_unquantized(cls) -> bool: - """Should backward skip FP8 quantization and use high precision""" - recipe = cls.get_fp8_recipe() - if recipe is not None and recipe.delayed(): - # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used - return False - return bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) - @classmethod def fp8_graph_capturing(cls) -> bool: """Is CUDA graph capture under way?""" @@ -851,16 +857,21 @@ def autocast( are reduced at the end of each training step. """ - if enabled: - check_recipe_support(recipe) + fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe + if enabled or calibrating: + _validate_recipe_quantization_flags(fp8_recipe) + quantize_forward = getattr(fp8_recipe, "quantize_forward", True) + effective_enabled = enabled and quantize_forward + if effective_enabled: + check_recipe_support(fp8_recipe) # Save current state so we always restore it on exit. 
fp8_state = FP8GlobalStateManager.get_autocast_state() FP8GlobalStateManager.autocast_enter( - enabled=enabled, + enabled=effective_enabled, calibrating=calibrating, - fp8_recipe=recipe, + fp8_recipe=fp8_recipe, fp8_group=amax_reduction_group, _graph=_graph, ) @@ -868,7 +879,7 @@ def autocast( yield finally: FP8GlobalStateManager.set_autocast_state(fp8_state) - FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) + FP8GlobalStateManager.autocast_exit(effective_enabled, _graph=_graph) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: From eecfcf87485d1adf6df336d63999465959eccc6e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 21:11:01 +0000 Subject: [PATCH 24/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 4 +++- transformer_engine/pytorch/module/layernorm_linear.py | 4 +++- transformer_engine/pytorch/module/layernorm_mlp.py | 4 +++- transformer_engine/pytorch/module/linear.py | 4 +++- transformer_engine/pytorch/ops/basic/basic_linear.py | 8 ++++---- .../pytorch/ops/fused/forward_linear_bias_activation.py | 4 ++-- .../pytorch/ops/fused/forward_linear_bias_add.py | 4 ++-- .../pytorch/ops/fused/forward_linear_scale_add.py | 4 ++-- 8 files changed, 22 insertions(+), 14 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 6afdb5673a..8db7eab0e2 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -98,7 +98,9 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) if 
keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4173c76216..3016d41c5f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,7 +141,9 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 28b80d6a60..9148babff9 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -235,7 +235,9 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) assert ( not keep_backward_unquantized ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 76ff5dd1d4..c8feddf5af 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,7 +129,9 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = fp8 and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) if 
keep_backward_unquantized: # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used save_original_input = True diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index d73fceeaf0..307b2e1624 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,8 +332,8 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. weight_requires_grad = requires_grad and self.weight.requires_grad - keep_backward_unquantized = ( - FP8GlobalStateManager.is_fp8_enabled() and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) columnwise_usage = weight_requires_grad and not keep_backward_unquantized input_quantizer = self.get_quantizer("forward", 0) @@ -990,8 +990,8 @@ def op_forward( grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 80cb5647d7..2bccabb306 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,8 +92,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) 
grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index cf29140a20..03e3bff6f3 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,8 +86,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) # Get autocast dtype if needed diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 0caae13af9..8cebcec53a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,8 +65,8 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = ( - with_quantized_compute and (not FP8GlobalStateManager.get_fp8_recipe().quantize_backward) + keep_backward_unquantized = with_quantized_compute and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward ) 
# Get extra input tensor for add operation From bfb840ce226b6e841c5245beed3804aeb28f2c11 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:43:08 -0800 Subject: [PATCH 25/45] Move ub overrides to fwd Signed-off-by: Ziang Li --- .../pytorch/module/layernorm_linear.py | 14 ++++++++------ transformer_engine/pytorch/module/linear.py | 13 +++++++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 3016d41c5f..f39fb45608 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -539,6 +539,14 @@ def forward( ctx.wgrad_store = wgrad_store ctx.debug = debug + # keep_backward_unquantized overrides + if keep_backward_unquantized: + # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # ------------------------------------------------------ # Cached state for backward pass is ready... 
# ------------------------------------------------------ @@ -610,12 +618,6 @@ def backward( keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c8feddf5af..3ed78e85da 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -493,6 +493,13 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module ctx.wgrad_store = wgrad_store + # keep_backward_unquantized overrides + if keep_backward_unquantized: + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False + # ------------------------------------------------------ # Cached state for backward pass is ready... 
# ------------------------------------------------------ @@ -545,12 +552,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized - if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None From 23f14eaa01a9ca504b8a516d630bc4a0339ff1b2 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:44:22 -0800 Subject: [PATCH 26/45] Remove duplication Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 341f23972c..1307302180 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -269,8 +269,6 @@ class Float8CurrentScaling(Recipe): fp8_mha: bool = False quantize_forward: bool = True quantize_backward: bool = field(default_factory=_default_quantize_backward) - quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
From ae25deea7ed47a36e2136d01809b7717deaa33d1 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 13:59:39 -0800 Subject: [PATCH 27/45] Simplify use_fp8_bwd logic in bwd Signed-off-by: Ziang Li --- .../pytorch/module/grouped_linear.py | 19 +++++++++---- .../pytorch/module/layernorm_linear.py | 25 ++++++++--------- transformer_engine/pytorch/module/linear.py | 28 +++++++++---------- 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 8db7eab0e2..f06df6b81b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -315,6 +315,14 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers + + # keep_backward_unquantized overrides + if keep_backward_unquantized: + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + ctx.ub_overlap_ag = False + ctx.ub_overlap_rs_dgrad = False + ctx.ub_bulk_dgrad = False + ctx.ub_bulk_wgrad = False # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -331,7 +339,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -347,7 +354,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) grad_output = [None] * ctx.num_gemms grad_biases = [None] * ctx.num_gemms - if use_fp8_bwd and not ctx.debug: + if ctx.fp8 and not ctx.debug: if ctx.use_bias: grad_output_mats = 
torch.split(grad_output_view, ctx.m_splits) recipe = ctx.fp8_recipe @@ -401,7 +408,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: dgrad_gemm_use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): dgrad_gemm_use_split_accumulator = ( @@ -435,7 +442,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): wgrad_gemm_use_split_accumulator = ( @@ -463,7 +470,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: input_quantizer.set_usage(rowwise=False, columnwise=True) inputmats: list - if use_fp8_bwd and not ctx.debug: + if ctx.fp8 and not ctx.debug: inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -540,7 +547,7 @@ def handle_custom_ddp_from_mcore(weight, wgrad): if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() - and not use_fp8_bwd + and not ctx.fp8 ): grad_biases = [None] * ctx.num_gemms diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index f39fb45608..1ef8536e4f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -541,7 +541,7 @@ def forward( # keep_backward_unquantized overrides if keep_backward_unquantized: - # Disable Userbuffers communication for backward pass when keep_backward_unquantized is True + ctx.fp8 = ctx.fp8 and not keep_backward_unquantized ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -617,7 +617,6 @@ def backward( 
origin_weight.main_grad = main_grad keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None @@ -655,7 +654,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_fp8_bwd: + if ctx.grad_output_quantizer is not None and ctx.fp8: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -692,7 +691,7 @@ def backward( ln_out_total_work = None if ctx.ln_out_needs_gather: quantizer = None - if ctx.input_quantizer is not None and use_fp8_bwd: + if ctx.input_quantizer is not None and ctx.fp8: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -731,7 +730,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_fp8_bwd + ctx.fp8 and ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage) ): @@ -739,13 +738,13 @@ def backward( # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_fp8_bwd: + if ctx.grad_input_quantizer is not None and ctx.fp8: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -769,7 +768,7 @@ def backward( grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.grad_input_quantizer if ctx.fp8 else 
None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -854,14 +853,14 @@ def backward( if ln_out_total_work is not None: ln_out_total_work.wait() ln_out_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -870,7 +869,7 @@ def backward( # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -896,7 +895,7 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -904,7 +903,7 @@ def backward( ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), + "bias": (bias if (grad_bias is None and not ctx.fp8) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3ed78e85da..a97ba398e0 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -495,6 +495,7 @@ def forward( # keep_backward_unquantized overrides if keep_backward_unquantized: + 
ctx.fp8 = ctx.fp8 and not keep_backward_unquantized ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -551,7 +552,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], nvtx_range_pop(f"{nvtx_label}.fsdp_gather") keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None @@ -592,7 +592,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and use_fp8_bwd: + if ctx.grad_output_quantizer is not None and ctx.fp8: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -611,7 +611,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None - and use_fp8_bwd + and ctx.fp8 ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -641,7 +641,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass @@ -667,7 +667,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat = cast_if_needed(inputmat, ctx.activation_dtype) if ctx.backward_input_needs_gather: quantizer = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: quantizer = ctx.input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -709,7 +709,7 @@ def backward(ctx, 
grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - use_fp8_bwd + ctx.fp8 and ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorStorage) ): @@ -717,13 +717,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and use_fp8_bwd: + if ctx.grad_input_quantizer is not None and ctx.fp8: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -748,7 +748,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if use_fp8_bwd else None, + quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -797,7 +797,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -839,7 +839,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if use_fp8_bwd or ctx.debug: + if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -848,7 +848,7 @@ def backward(ctx, 
grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if use_fp8_bwd: + if ctx.fp8: recipe = ctx.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator @@ -874,7 +874,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if use_fp8_bwd else None), + "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) @@ -882,7 +882,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not use_fp8_bwd) else None), + "bias": (bias if (grad_bias is None and not ctx.fp8) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, From 491dd4413a05adbf7df947747a8170fc27369525 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 22:00:24 +0000 Subject: [PATCH 28/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/grouped_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f06df6b81b..2ac655657e 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -315,7 +315,7 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers - + # 
keep_backward_unquantized overrides if keep_backward_unquantized: ctx.fp8 = ctx.fp8 and not keep_backward_unquantized From 764ee6fb058df771f50d1fdb144f7da261a26832 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 14:28:06 -0800 Subject: [PATCH 29/45] Set grad quantizers to none if keep bwd unquantized Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 3 +++ .../pytorch/module/layernorm_linear.py | 11 +++++++---- transformer_engine/pytorch/module/linear.py | 13 ++++++++----- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 2ac655657e..b32067056f 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -323,6 +323,9 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 1ef8536e4f..4de6afa38b 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -546,6 +546,9 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None # ------------------------------------------------------ # Cached state for backward pass is ready... 
@@ -654,7 +657,7 @@ def backward( # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and ctx.fp8: + if ctx.grad_output_quantizer is not None: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -744,7 +747,7 @@ def backward( use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and ctx.fp8: + if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -768,7 +771,7 @@ def backward( grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, + quantization_params=ctx.grad_input_quantizer, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -895,7 +898,7 @@ def backward( "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), + "quantization_params": ctx.grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a97ba398e0..1fd2fcba8d 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -500,6 +500,10 @@ def forward( ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False ctx.ub_bulk_wgrad = False + ctx.grad_input_quantizer = None + ctx.grad_weight_quantizer = None + ctx.grad_output_quantizer = None + # ------------------------------------------------------ # Cached state for backward pass is ready... 
@@ -592,7 +596,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Configure quantizer for grad output tensor # Note: dgrad GEMM requires row-wise usage, wgrad GEMM # requires column-wise usage - if ctx.grad_output_quantizer is not None and ctx.fp8: + if ctx.grad_output_quantizer is not None: quantizer = ctx.grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) if ctx.ub_overlap_ag: @@ -611,7 +615,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], not ctx.use_bias and not ctx.requires_wgrad and ctx.grad_output_quantizer is not None - and ctx.fp8 ): ctx.grad_output_quantizer.set_usage(columnwise=False) @@ -723,7 +726,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator # Update grad input quantizer - if ctx.grad_input_quantizer is not None and ctx.fp8: + if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # Output buffers for Userbuffers reduce-scatter @@ -748,7 +751,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], grad_output, layout="NN", grad=True, - quantization_params=ctx.grad_input_quantizer if ctx.fp8 else None, + quantization_params=ctx.grad_input_quantizer, out=gemm_out, out_dtype=ctx.activation_dtype, use_split_accumulator=use_split_accumulator, @@ -874,7 +877,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], "out_dtype": ( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), - "quantization_params": (ctx.grad_weight_quantizer if ctx.fp8 else None), + "quantization_params": ctx.grad_weight_quantizer, "accumulate": ( accumulate_wgrad_into_param_main_grad if not getattr(weight, "overwrite_main_grad", False) From bbaa6c5ab7a389d226223fafb4f79d405a6fb855 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Feb 2026 22:28:55 +0000 Subject: [PATCH 30/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/linear.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 1fd2fcba8d..3e8c4c146f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -503,7 +503,6 @@ def forward( ctx.grad_input_quantizer = None ctx.grad_weight_quantizer = None ctx.grad_output_quantizer = None - # ------------------------------------------------------ # Cached state for backward pass is ready... From 5e638e04add2723aa0b7c1a660cba41248dffbda Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Thu, 5 Feb 2026 17:28:04 -0800 Subject: [PATCH 31/45] Drop delayed scaling change Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/layernorm_linear.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4de6afa38b..26b14c2d8a 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -527,11 +527,7 @@ def forward( ctx.requires_dgrad = inp_requires_grad ctx.normalization = normalization ctx.reduce_and_update_bwd_fp8_tensors = False - if ( - ctx.fp8 - and not ctx.keep_backward_unquantized - and requires_grad(inp, ln_weight, ln_bias, weight, bias) - ): + if ctx.fp8 and requires_grad(inp, ln_weight, ln_bias, weight, bias): _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): From 75079a4babef8ba302120900fd77e6b52853f425 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 
Feb 2026 11:29:24 -0800 Subject: [PATCH 32/45] Simplify env var logic Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 1307302180..e76256cce3 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,11 +11,6 @@ from pydantic.dataclasses import dataclass -def _default_quantize_backward() -> bool: - """Default backward quantization setting.""" - return not bool(int(os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0"))) - - class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -222,7 +217,7 @@ def scaling_factor_compute(amax: Tensor, fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -268,7 +263,7 @@ class Float8CurrentScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
@@ -324,7 +319,7 @@ class MXFP8BlockScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." @@ -491,7 +486,7 @@ class NVFP4BlockScaling(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" @@ -567,7 +562,7 @@ class CustomRecipe(Recipe): fp8_dpa: bool = False fp8_mha: bool = False quantize_forward: bool = True - quantize_backward: bool = field(default_factory=_default_quantize_backward) + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __repr__(self) -> str: return ( From 0874804f9acbe9cde71bdb5765592756f08d8197 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:41:01 -0800 Subject: [PATCH 33/45] Move validation check to recipe Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 18 ++++++++++++++++++ transformer_engine/pytorch/quantization.py | 17 ----------------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index e76256cce3..b7cdbe818a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -221,6 +221,12 @@ def scaling_factor_compute(amax: Tensor, def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
+ assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." + assert ( + not self.quantize_backward + ), "Delayed scaling does not support quantize_backward=False." def __repr__(self) -> str: return ( @@ -267,6 +273,9 @@ class Float8CurrentScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -323,6 +332,9 @@ class MXFP8BlockScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -400,6 +412,9 @@ def __post_init__(self) -> None: not self.fp8_dpa and not self.fp8_mha ), "FP8 attention is not supported for Float8BlockScaling." assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -491,6 +506,9 @@ class NVFP4BlockScaling(Recipe): def __post_init__(self) -> None: assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" + assert not ( + not self.quantize_forward and self.quantize_backward + ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." 
# Quantization params # Note: RHT is currently only applied to column-wise usage so that diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index fb0553056a..bbffe51eec 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -87,21 +87,6 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: ) -def _validate_recipe_quantization_flags(recipe: Recipe) -> None: - """Validate forward/backward quantization flags on a recipe.""" - quantize_forward = getattr(recipe, "quantize_forward", True) - quantize_backward = getattr(recipe, "quantize_backward", True) - if not quantize_forward and quantize_backward: - raise ValueError( - "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." - ) - if recipe.delayed() and not quantize_backward: - raise ValueError( - "Invalid recipe configuration: delayed scaling does not support " - "quantize_backward=False." - ) - - def check_recipe_support(recipe: Recipe) -> None: """Check if the given recipe is supported.""" recipe_supported = True @@ -858,8 +843,6 @@ def autocast( """ fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe - if enabled or calibrating: - _validate_recipe_quantization_flags(fp8_recipe) quantize_forward = getattr(fp8_recipe, "quantize_forward", True) effective_enabled = enabled and quantize_forward if effective_enabled: From 134757851e182abb113527e0b235f1aa7bf79553 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:55:28 -0800 Subject: [PATCH 34/45] Simplify effective_enabled Signed-off-by: Ziang Li --- transformer_engine/pytorch/quantization.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index bbffe51eec..00196c584f 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -842,11 +842,9 
@@ def autocast( are reduced at the end of each training step. """ - fp8_recipe = get_default_fp8_recipe() if recipe is None else recipe - quantize_forward = getattr(fp8_recipe, "quantize_forward", True) - effective_enabled = enabled and quantize_forward + effective_enabled = enabled and getattr(recipe, "quantize_forward", True) if effective_enabled: - check_recipe_support(fp8_recipe) + check_recipe_support(recipe) # Save current state so we always restore it on exit. fp8_state = FP8GlobalStateManager.get_autocast_state() @@ -854,7 +852,7 @@ def autocast( FP8GlobalStateManager.autocast_enter( enabled=effective_enabled, calibrating=calibrating, - fp8_recipe=fp8_recipe, + fp8_recipe=recipe, fp8_group=amax_reduction_group, _graph=_graph, ) From 2e81568fc419602872bc09deebabee733910d97a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 11:56:33 -0800 Subject: [PATCH 35/45] Fix inverted assertion logic Signed-off-by: Ziang Li --- transformer_engine/common/recipe/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b7cdbe818a..1c14e5e42c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -224,9 +224,7 @@ def __post_init__(self) -> None: assert not ( not self.quantize_forward and self.quantize_backward ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." - assert ( - not self.quantize_backward - ), "Delayed scaling does not support quantize_backward=False." + assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False." 
def __repr__(self) -> str: return ( From cf04de655b9a276aeb853297c86cf8fc062fa39a Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 12:33:38 -0800 Subject: [PATCH 36/45] Simplify changes under ops Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/basic/basic_linear.py | 4 ---- transformer_engine/pytorch/ops/basic/quantize.py | 11 ++++++----- .../ops/fused/forward_linear_bias_activation.py | 7 ++----- .../pytorch/ops/fused/forward_linear_bias_add.py | 7 ++----- .../pytorch/ops/fused/forward_linear_scale_add.py | 7 ++----- 5 files changed, 12 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 307b2e1624..15a6815d2e 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -1020,11 +1020,7 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: saved_input = input_ if keep_backward_unquantized else x_local - if not weight_requires_grad: - saved_input = None saved_weight = self.weight if keep_backward_unquantized else w - if not input_requires_grad: - saved_weight = None if is_cpu_offload_enabled(): mark_activation_offload(saved_input) ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 4c67cd8cce..9dcd33f9b3 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -57,11 +57,12 @@ def op_forward( # Check if FP8 is enabled fp8_enabled = FP8GlobalStateManager.is_fp8_enabled() quantize_forward = fp8_enabled and self._quantize_forward - quantize_backward = ( - fp8_enabled - and self._quantize_backward - and FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + quantize_backward = fp8_enabled and self._quantize_backward + + # Recipe quantize overrides + if 
FP8GlobalStateManager.get_fp8_recipe() is not None: + quantize_forward = quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward + quantize_backward = quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward # Quantize if needed out = input_ diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 2bccabb306..860407904c 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -122,11 +122,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 03e3bff6f3..0729291d55 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -119,11 +119,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): 
mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 8cebcec53a..dfdd11a231 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -100,11 +100,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = x_local - saved_weight = w - if keep_backward_unquantized: - saved_input = input_ if input_requires_grad else None - saved_weight = linear_op.weight if weight_requires_grad else None + saved_input = input_ if keep_backward_unquantized else x_local + saved_weight = linear_op.weight if keep_backward_unquantized else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) From edcb1f603c3182e137ceea2dd7b7644d345f6e17 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Feb 2026 20:34:39 +0000 Subject: [PATCH 37/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/ops/basic/quantize.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 9dcd33f9b3..b2a36d1daa 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -61,8 +61,12 @@ def op_forward( # Recipe quantize overrides if FP8GlobalStateManager.get_fp8_recipe() is not None: - quantize_forward = quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward - quantize_backward = quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward + 
quantize_forward = ( + quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward + ) + quantize_backward = ( + quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) # Quantize if needed out = input_ From 891fc7b3db0788cfd52d0e93e099eef62c7c9377 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 12:52:01 -0800 Subject: [PATCH 38/45] Simplify ctx.keep_backward_unquantized Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/base.py | 3 +-- transformer_engine/pytorch/module/grouped_linear.py | 3 +-- transformer_engine/pytorch/module/layernorm_linear.py | 4 +--- transformer_engine/pytorch/module/linear.py | 4 +--- 4 files changed, 4 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 48b02acb01..fe5be68034 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1135,8 +1135,7 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - use_fp8_bwd = ctx.fp8 and not keep_backward_unquantized + use_fp8_bwd = ctx.fp8 and not ctx.keep_backward_unquantized # Non-FP8 case: bgrad is fused with wgrad for this case. 
if not use_fp8_bwd and not ctx.debug: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index b32067056f..530e8c2075 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -341,7 +341,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], origin_weights = saved_tensors[2 * N : 3 * N] biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: @@ -423,7 +422,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], device=ctx.device, ) weights_for_dgrad = weights - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weights_for_dgrad = origin_weights # Make sure weights are available in column-wise format # for dgrad computation. 
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 26b14c2d8a..187fd70f92 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -615,8 +615,6 @@ def backward( if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: origin_weight.main_grad = main_grad - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -760,7 +758,7 @@ def backward( # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3e8c4c146f..7d960102ec 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -554,8 +554,6 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") - keep_backward_unquantized = getattr(ctx, "keep_backward_unquantized", False) - # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -743,7 +741,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight_fp8 - if keep_backward_unquantized: + if ctx.keep_backward_unquantized: weight_for_dgrad = weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, From 5d20f771d5fae01580b374f8d4bb18c02f62e20e Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Mon, 9 Feb 2026 15:07:48 -0800 Subject: [PATCH 39/45] Fix missing attribute Signed-off-by: Ziang Li --- 
transformer_engine/common/recipe/__init__.py | 2 ++ transformer_engine/pytorch/module/layernorm_mlp.py | 1 + 2 files changed, 3 insertions(+) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 1c14e5e42c..46a19652f1 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -389,6 +389,8 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False + quantize_forward: bool = True + quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") def __post_init__(self) -> None: assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9148babff9..ae80694587 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -787,6 +787,7 @@ def _forward( ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + ctx.keep_backward_unquantized = keep_backward_unquantized ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer From bbd22c701f4361837b7665d7101d1c72c95524f7 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 10 Feb 2026 14:02:10 -0800 Subject: [PATCH 40/45] Add unit tests Signed-off-by: Ziang Li --- qa/L0_pytorch_unittest/test.sh | 1 + .../pytorch/test_keep_backward_unquantized.py | 701 ++++++++++++++++++ 2 files changed, 702 insertions(+) create mode 100644 tests/pytorch/test_keep_backward_unquantized.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index cd2d85c91c..7829620608 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ 
b/qa/L0_pytorch_unittest/test.sh @@ -40,6 +40,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +NVTE_KEEP_BACKWARD_UNQUANTIZED=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_keep_backward_unquantized.xml $TE_PATH/tests/pytorch/test_keep_backward_unquantized.py || test_fail "test_keep_backward_unquantized.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py new file mode 100644 index 0000000000..a5ef00e34c --- /dev/null +++ b/tests/pytorch/test_keep_backward_unquantized.py @@ -0,0 +1,701 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +from __future__ import annotations + +from contextlib import nullcontext +import os +from typing import Optional + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common import recipe +from transformer_engine.pytorch.ops.fused import ( + BackwardActivationBias, + ForwardLinearBiasActivation, + ForwardLinearBiasAdd, + ForwardLinearScaleAdd, +) + +from utils import quantization_tols, reset_rng_states + + +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +# This file is intended to run in dedicated keep-backward-unquantized mode. +pytestmark = pytest.mark.skipif( + os.environ.get("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") != "1", + reason="Requires NVTE_KEEP_BACKWARD_UNQUANTIZED=1", +) + + +_quantized_numerics_recipe_list = [ + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + id="Float8CurrentScaling", + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + id="MXFP8BlockScaling", + ), + pytest.param( + "fp8_block_scaling", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + id="Float8BlockScaling", + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4BlockScaling", + ), +] + +_shape_test_cases = [ + pytest.param((32, 64), 64, id="2d_m32_k64_n64"), + pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), + pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), +] + +_bias_activation_shape_cases = [ + pytest.param((32, 64), 
id="2d_m32_k64"), + pytest.param((8, 4, 64), id="3d_m32_k64"), +] + + +def _make_recipe(recipe_name: str, quantize_backward: Optional[bool]) -> recipe.Recipe: + kwargs = {} + if quantize_backward is not None: + kwargs = {"quantize_forward": True, "quantize_backward": quantize_backward} + + if recipe_name == "fp8_current_scaling": + return recipe.Float8CurrentScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "mxfp8": + return recipe.MXFP8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "fp8_block_scaling": + return recipe.Float8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "nvfp4": + return recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + **kwargs, + ) + + raise ValueError(f"Unsupported recipe for keep-backward-unquantized test: {recipe_name}") + + +def _build_keep_backward_unquantized_recipe(recipe_name: str) -> recipe.Recipe: + fp8_recipe = _make_recipe(recipe_name, quantize_backward=None) + assert fp8_recipe.quantize_forward + assert not fp8_recipe.quantize_backward + return fp8_recipe + + +def _build_quantized_reference_recipe(recipe_name: str) -> recipe.Recipe: + return _make_recipe(recipe_name, quantize_backward=True) + + +def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None: + src_params = dict(src_module.named_parameters()) + with torch.no_grad(): + for name, dst_param in dst_module.named_parameters(): + if name not in src_params: + raise RuntimeError(f"Parameter {name} missing in source module") + dst_param.copy_(src_params[name]) + + +def _fprop_tolerances(recipe_name: str) -> dict[str, float]: + if recipe_name == "mxfp8": + return quantization_tols("mxfp8") + if recipe_name in ("fp8_current_scaling", "fp8_block_scaling"): + return quantization_tols("fp8_current_scaling") + if recipe_name == "nvfp4": + return quantization_tols("nvfp4") + raise ValueError(f"Unsupported recipe for 
keep-backward-unquantized test: {recipe_name}") + + +def _make_linear_like_module( + module_type: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + bias: bool = False, +) -> torch.nn.Module: + if module_type == "linear": + return te.Linear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "layernorm_linear": + return te.LayerNormLinear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "ops_linear": + return te_ops.Linear( + in_features, + out_features, + bias=bias, + dtype=dtype, + device="cuda", + ) + raise ValueError(f"Unsupported module type: {module_type}") + + +def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: + if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": + pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") + + +def _run_single_step( + module: torch.nn.Module, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + y.backward(dy) + assert x_run.grad is not None + assert module.weight.grad is not None + return ( + y.detach().clone(), + x_run.grad.detach().clone(), + module.weight.grad.detach().clone(), + ) + + +def _extract_bias_grad(module: torch.nn.Module) -> Optional[torch.Tensor]: + bias = getattr(module, "bias", None) + if bias is None or bias.grad is None: + return None + return bias.grad.detach().clone() + + +def _run_grouped_linear_single_step( + module: te.GroupedLinear, + x: torch.Tensor, + m_splits: list[int], + dy: torch.Tensor, + fp8_recipe: 
Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[Optional[torch.Tensor]]]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run, m_splits) + y.backward(dy) + assert x_run.grad is not None + weight_grads = [getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms)] + bias_grads: list[Optional[torch.Tensor]] = [] + for i in range(module.num_gemms): + if module.use_bias: + bias_grads.append(getattr(module, f"bias{i}").grad.detach().clone()) + else: + bias_grads.append(None) + return y.detach().clone(), x_run.grad.detach().clone(), weight_grads, bias_grads + + +def _make_fused_model( + pattern: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + scale: float = 0.5, +) -> te_ops.Sequential: + if pattern == "bias_activation": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.ReLU(), + ) + if pattern == "bias_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.AddExtraInput(in_place=True), + ) + if pattern == "scale_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=False, device="cuda", dtype=dtype), + te_ops.ConstantScale(scale), + te_ops.AddExtraInput(in_place=True), + ) + raise ValueError(f"Unsupported fused test pattern: {pattern}") + + +def _run_fused_single_step( + pattern: str, + model: te_ops.Sequential, + x1: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], + x2: Optional[torch.Tensor] = None, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: + model.zero_grad(set_to_none=True) + x1_run = x1.detach().clone().requires_grad_(True) + x2_run = 
x2.detach().clone().requires_grad_(True) if x2 is not None else None + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + if pattern in ("bias_add", "scale_add"): + assert x2_run is not None + y = model(x1_run, x2_run) + else: + y = model(x1_run) + y.backward(dy) + assert x1_run.grad is not None + weight_grad = model[0].weight.grad.detach().clone() + bias_grad = None + if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None: + bias_grad = model[0].bias.grad.detach().clone() + x2_grad = x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None + return y.detach().clone(), x1_run.grad.detach().clone(), x2_grad, weight_grad, bias_grad + + +def _run_quantize_op_single_step( + model: te_ops.Sequential, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor]: + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = model(x_run) + y.backward(dy) + assert x_run.grad is not None + return y.detach().clone(), x_run.grad.detach().clone() + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +def test_keep_backward_unquantized_recipe_defaults(recipe_name: str): + _ = _build_keep_backward_unquantized_recipe(recipe_name) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize( + "module_type", + ("linear", "layernorm_linear", "ops_linear"), +) +@pytest.mark.parametrize( + "input_shape,out_features", + _shape_test_cases, +) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads( + recipe_name: str, + module_type: str, + input_shape: tuple[int, ...], + 
out_features: int, + use_bias: bool, +): + reset_rng_states() + _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + dtype = torch.bfloat16 + in_features = input_shape[-1] + + module_quantized_ref = _make_linear_like_module( + module_type, in_features, out_features, dtype, bias=use_bias + ) + module_keep_bwd_hp = _make_linear_like_module( + module_type, in_features, out_features, dtype, bias=use_bias + ) + module_unquantized_ref = _make_linear_like_module( + module_type, in_features, out_features, dtype, bias=use_bias + ) + + # Start all runs from identical parameters. + _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + + output_shape = input_shape[:-1] + (out_features,) + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*output_shape, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _ = _run_single_step(module_quantized_ref, x, dy, quantized_ref_recipe) + y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp = _run_single_step( + module_keep_bwd_hp, x, dy, keep_bwd_hp_recipe + ) + _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step(module_unquantized_ref, x, dy, None) + + # Forward pass should still match quantized reference when only backward is unquantized. + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + + # Backward pass should match unquantized reference for dgrad and wgrad. 
+ torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) + if use_bias: + bgrad_keep = _extract_bias_grad(module_keep_bwd_hp) + bgrad_unquantized = _extract_bias_grad(module_unquantized_ref) + assert bgrad_keep is not None + assert bgrad_unquantized is not None + torch.testing.assert_close(bgrad_keep, bgrad_unquantized, rtol=0, atol=0) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +@pytest.mark.parametrize( + "m_splits", + ([32, 32, 32, 32], [64, 0, 32, 32]), + ids=("uniform_splits", "with_empty_split"), +) +def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_unquantized_grads( + recipe_name: str, + use_bias: bool, + m_splits: list[int], +): + if recipe_name == "nvfp4": + pytest.skip("NVFP4 not supported for grouped linear") + + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + num_gemms = len(m_splits) + num_tokens = sum(m_splits) + + module_quantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + module_keep_bwd_hp = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + module_unquantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + + _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + + x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = 
_build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _ = _run_grouped_linear_single_step( + module_quantized_ref, x, m_splits, dy, quantized_ref_recipe + ) + y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = _run_grouped_linear_single_step( + module_keep_bwd_hp, x, m_splits, dy, keep_bwd_hp_recipe + ) + _, dx_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = _run_grouped_linear_single_step( + module_unquantized_ref, x, m_splits, dy, None + ) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) + for test_dw, ref_dw in zip(dw_keep_bwd_hp, dw_unquantized_ref): + torch.testing.assert_close(test_dw, ref_dw, rtol=0, atol=0) + if use_bias: + for test_db, ref_db in zip(db_keep_bwd_hp, db_unquantized_ref): + assert test_db is not None + assert ref_db is not None + torch.testing.assert_close(test_db, ref_db, rtol=0, atol=0) + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize( + "fused_pattern,expected_fused_op", + ( + ("bias_add", ForwardLinearBiasAdd), + ("scale_add", ForwardLinearScaleAdd), + ), +) +def test_keep_backward_unquantized_fused_linear_paths( + recipe_name: str, + fused_pattern: str, + expected_fused_op: type, +): + # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. 
+ _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + m = 32 + + model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + model_keep_bwd_hp = _make_fused_model(fused_pattern, in_features, out_features, dtype) + model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + + _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) + _copy_named_parameters(model_quantized_ref, model_unquantized_ref) + + x1 = torch.randn(m, in_features, dtype=dtype, device="cuda") + x2 = None + if fused_pattern in ("bias_add", "scale_add"): + x2 = torch.randn(m, out_features, dtype=dtype, device="cuda") + dy = torch.randn(m, out_features, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + fused_pattern, model_quantized_ref, x1, dy, quantized_ref_recipe, x2=x2 + ) + y_keep_bwd_hp, dx1_keep_bwd_hp, dx2_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = ( + _run_fused_single_step( + fused_pattern, + model_keep_bwd_hp, + x1, + dy, + keep_bwd_hp_recipe, + x2=x2, + ) + ) + _, dx1_unquantized_ref, dx2_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = ( + _run_fused_single_step( + fused_pattern, + model_unquantized_ref, + x1, + dy, + None, + x2=x2, + ) + ) + + # Ensure this test executes the fused path changed by the keep-bwd feature. 
+ fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 + assert isinstance(fused_ops[0][0], expected_fused_op) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + torch.testing.assert_close(dx1_keep_bwd_hp, dx1_unquantized_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) + if dx2_keep_bwd_hp is not None and dx2_unquantized_ref is not None: + torch.testing.assert_close(dx2_keep_bwd_hp, dx2_unquantized_ref, rtol=0, atol=0) + if db_keep_bwd_hp is not None and db_unquantized_ref is not None: + torch.testing.assert_close(db_keep_bwd_hp, db_unquantized_ref, rtol=0, atol=0) + + +@pytest.mark.parametrize( + "recipe_name", + _quantized_numerics_recipe_list, +) +@pytest.mark.parametrize("input_shape", _bias_activation_shape_cases) +def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_backward( + recipe_name: str, + input_shape: tuple[int, ...], +): + # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. 
+ _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + + reset_rng_states() + dtype = torch.bfloat16 + in_features = input_shape[-1] + out_features = 64 + + model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, dtype) + model_keep_bwd_hp = _make_fused_model("bias_activation", in_features, out_features, dtype) + linear_unquantized_ref = _make_linear_like_module( + "ops_linear", in_features, out_features, dtype, bias=True + ) + + _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) + _copy_named_parameters(model_keep_bwd_hp[0], linear_unquantized_ref) + + x1 = torch.randn(*input_shape, dtype=dtype, device="cuda") + out_shape = x1.shape[:-1] + (out_features,) + dy = torch.randn(*out_shape, dtype=dtype, device="cuda") + + quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) + keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + "bias_activation", model_quantized_ref, x1, dy, quantized_ref_recipe + ) + y_keep_bwd_hp, dx1_keep_bwd_hp, _, dw_keep_bwd_hp, db_keep_bwd_hp = _run_fused_single_step( + "bias_activation", model_keep_bwd_hp, x1, dy, keep_bwd_hp_recipe + ) + + # Ensure this test executes the fused path changed by the keep-bwd feature. + fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 + assert isinstance(fused_ops[0][0], ForwardLinearBiasActivation) + + # keep-bwd mode should disable backward-activation+bias fusion, while quantized + # reference should still use it. 
+ keep_bwd_backward_ops = model_keep_bwd_hp._module_groups[0]._backward_ops + assert not any( + isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops + ) + quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops + assert any( + isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops + ) + + torch.testing.assert_close( + y_keep_bwd_hp, + y_quantized_ref, + **_fprop_tolerances(recipe_name), + ) + + # In keep-backward-unquantized mode, backward should behave as high-precision linear backward + # given the ReLU mask induced by quantized forward activations. + dy_after_activation = dy * (y_keep_bwd_hp > 0).to(dy.dtype) + _, dx1_expected, dw_expected = _run_single_step(linear_unquantized_ref, x1, dy_after_activation, None) + db_expected = _extract_bias_grad(linear_unquantized_ref) + assert db_keep_bwd_hp is not None + assert db_expected is not None + + torch.testing.assert_close(dx1_keep_bwd_hp, dx1_expected, rtol=0, atol=0) + torch.testing.assert_close(dw_keep_bwd_hp, dw_expected, rtol=0, atol=0) + torch.testing.assert_close(db_keep_bwd_hp, db_expected, rtol=0, atol=0) + + +def test_keep_backward_unquantized_autocast_respects_quantize_forward_flag(): + reset_rng_states() + dtype = torch.bfloat16 + in_features = 64 + out_features = 64 + + module_quantization_disabled = _make_linear_like_module( + "linear", in_features, out_features, dtype, bias=True + ) + module_unquantized_ref = _make_linear_like_module("linear", in_features, out_features, dtype, bias=True) + _copy_named_parameters(module_quantization_disabled, module_unquantized_ref) + + x = torch.randn(32, in_features, dtype=dtype, device="cuda") + dy = torch.randn(32, out_features, dtype=dtype, device="cuda") + + recipe_no_fwd_quant = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.E4M3, + quantize_forward=False, + quantize_backward=False, + ) + + y_test, dx_test, dw_test = _run_single_step( + module_quantization_disabled, x, dy, 
recipe_no_fwd_quant + ) + y_ref, dx_ref, dw_ref = _run_single_step(module_unquantized_ref, x, dy, None) + + torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0) + torch.testing.assert_close(dw_test, dw_ref, rtol=0, atol=0) + bgrad_test = _extract_bias_grad(module_quantization_disabled) + bgrad_ref = _extract_bias_grad(module_unquantized_ref) + assert bgrad_test is not None + assert bgrad_ref is not None + torch.testing.assert_close(bgrad_test, bgrad_ref, rtol=0, atol=0) + + +def test_keep_backward_unquantized_quantize_op_respects_recipe_overrides(): + reset_rng_states() + dtype = torch.bfloat16 + x = torch.randn(32, 64, dtype=dtype, device="cuda") + dy = torch.randn(32, 64, dtype=dtype, device="cuda") + + model_override = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) + model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) + + recipe_no_quant = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.E4M3, + quantize_forward=False, + quantize_backward=False, + ) + y_override, dx_override = _run_quantize_op_single_step(model_override, x, dy, recipe_no_quant) + y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, None) + + torch.testing.assert_close(y_override, y_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_override, dx_ref, rtol=0, atol=0) + + +def test_keep_backward_unquantized_is_invalid_for_delayed_scaling(): + with pytest.raises( + (AssertionError, ValueError), + match="Delayed scaling does not support quantize_backward=False", + ): + _ = recipe.DelayedScaling() + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_keep_backward_unquantized_not_implemented_for_layernorm_mlp(): + reset_rng_states() + layer = te.LayerNormMLP( + hidden_size=64, + ffn_hidden_size=64, + params_dtype=torch.bfloat16, + bias=False, + device="cuda", + ) + x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True) + 
keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe("fp8_current_scaling") + + with pytest.raises( + AssertionError, match="NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" + ): + with te.autocast(enabled=True, recipe=keep_bwd_hp_recipe): + _ = layer(x) From 7d74aa3ec2e59b76a386eac67dcfa3cbdc865019 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 10 Feb 2026 14:03:02 -0800 Subject: [PATCH 41/45] Fix bias errors in unit test Signed-off-by: Ziang Li --- transformer_engine/pytorch/ops/basic/bias.py | 8 +++++++- .../pytorch/ops/fused/backward_activation_bias.py | 5 +++-- .../pytorch/ops/fused/forward_linear_bias_activation.py | 4 +++- .../pytorch/ops/fused/forward_linear_bias_add.py | 4 +++- 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index d580f84866..8bcd84b441 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -10,6 +10,7 @@ import torch import transformer_engine_torch as tex +from ...quantization import FP8GlobalStateManager from ..op import BasicOperation, OperationContext from ...utils import canonicalize_device, canonicalize_dtype from ...tensor import Quantizer @@ -123,7 +124,12 @@ def op_forward( b = self.bias.view([1] * (x.dim() - 1) + [self.local_size]) if ctx.requires_grad: - ctx.grad_input_quantizer = prev_op_grad_output_quantizer + keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( + not FP8GlobalStateManager.get_fp8_recipe().quantize_backward + ) + ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else prev_op_grad_output_quantizer + ) return x + b diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 4ab082d32b..395a9dbd67 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ 
b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -104,8 +104,9 @@ def fuse_backward_ops( """ - # Check if recipe supports bias activation fusion - if recipe is None: + # Check if recipe supports bias activation fusion. + # keep-backward-unquantized mode should use unfused backward ops. + if recipe is None or not recipe.quantize_backward: return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 860407904c..42f459a41e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -138,7 +138,9 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + bias_op_ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() + ) return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 0729291d55..75d58fd5cc 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -135,7 +135,9 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + bias_op_ctx.grad_input_quantizer = ( + None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() + ) return output, [() for _ in range(len(self.basic_ops))] 
From 5bc3a576e8708b4b69b87c1b913c176a60afc09b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 22:03:50 +0000 Subject: [PATCH 42/45] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/test_keep_backward_unquantized.py | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py index a5ef00e34c..fe11bfcd3a 100644 --- a/tests/pytorch/test_keep_backward_unquantized.py +++ b/tests/pytorch/test_keep_backward_unquantized.py @@ -214,7 +214,9 @@ def _run_grouped_linear_single_step( y = module(x_run, m_splits) y.backward(dy) assert x_run.grad is not None - weight_grads = [getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms)] + weight_grads = [ + getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms) + ] bias_grads: list[Optional[torch.Tensor]] = [] for i in range(module.num_gemms): if module.use_bias: @@ -257,7 +259,9 @@ def _run_fused_single_step( dy: torch.Tensor, fp8_recipe: Optional[recipe.Recipe], x2: Optional[torch.Tensor] = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[ + torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor] +]: model.zero_grad(set_to_none=True) x1_run = x1.detach().clone().requires_grad_(True) x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None @@ -276,7 +280,9 @@ def _run_fused_single_step( bias_grad = None if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None: bias_grad = model[0].bias.grad.detach().clone() - x2_grad = x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None + x2_grad = ( + x2_run.grad.detach().clone() if 
x2_run is not None and x2_run.grad is not None else None + ) return y.detach().clone(), x1_run.grad.detach().clone(), x2_grad, weight_grad, bias_grad @@ -355,7 +361,9 @@ def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp = _run_single_step( module_keep_bwd_hp, x, dy, keep_bwd_hp_recipe ) - _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step(module_unquantized_ref, x, dy, None) + _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step( + module_unquantized_ref, x, dy, None + ) # Forward pass should still match quantized reference when only backward is unquantized. torch.testing.assert_close( @@ -458,6 +466,7 @@ def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_un assert ref_db is not None torch.testing.assert_close(test_db, ref_db, rtol=0, atol=0) + @pytest.mark.parametrize( "recipe_name", _quantized_numerics_recipe_list, @@ -589,13 +598,9 @@ def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_b # keep-bwd mode should disable backward-activation+bias fusion, while quantized # reference should still use it. 
keep_bwd_backward_ops = model_keep_bwd_hp._module_groups[0]._backward_ops - assert not any( - isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops - ) + assert not any(isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops) quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops - assert any( - isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops - ) + assert any(isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops) torch.testing.assert_close( y_keep_bwd_hp, @@ -606,7 +611,9 @@ def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_b # In keep-backward-unquantized mode, backward should behave as high-precision linear backward # given the ReLU mask induced by quantized forward activations. dy_after_activation = dy * (y_keep_bwd_hp > 0).to(dy.dtype) - _, dx1_expected, dw_expected = _run_single_step(linear_unquantized_ref, x1, dy_after_activation, None) + _, dx1_expected, dw_expected = _run_single_step( + linear_unquantized_ref, x1, dy_after_activation, None + ) db_expected = _extract_bias_grad(linear_unquantized_ref) assert db_keep_bwd_hp is not None assert db_expected is not None @@ -625,7 +632,9 @@ def test_keep_backward_unquantized_autocast_respects_quantize_forward_flag(): module_quantization_disabled = _make_linear_like_module( "linear", in_features, out_features, dtype, bias=True ) - module_unquantized_ref = _make_linear_like_module("linear", in_features, out_features, dtype, bias=True) + module_unquantized_ref = _make_linear_like_module( + "linear", in_features, out_features, dtype, bias=True + ) _copy_named_parameters(module_quantization_disabled, module_unquantized_ref) x = torch.randn(32, in_features, dtype=dtype, device="cuda") From 255589e0f74ab52c1abe765f675e1631800920b0 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 10 Feb 2026 15:22:40 -0800 Subject: [PATCH 43/45] Add more shapes to unit test Signed-off-by: 
Ziang Li --- .../pytorch/test_keep_backward_unquantized.py | 54 +++++++++++++++++-- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py index fe11bfcd3a..f5c3339a71 100644 --- a/tests/pytorch/test_keep_backward_unquantized.py +++ b/tests/pytorch/test_keep_backward_unquantized.py @@ -5,6 +5,7 @@ from __future__ import annotations from contextlib import nullcontext +import math import os from typing import Optional @@ -64,7 +65,9 @@ ] _shape_test_cases = [ + pytest.param((1, 64), 64, id="2d_m1_k64_n64"), pytest.param((32, 64), 64, id="2d_m32_k64_n64"), + pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), ] @@ -166,6 +169,46 @@ def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") +def _maybe_skip_unsupported_recipe_shape( + recipe_name: str, + input_shape: tuple[int, ...], + module_type: str, +) -> None: + flat_first_dim = math.prod(input_shape[:-1]) + last_dim = input_shape[-1] + + # TE Linear / LayerNormLinear FP8 kernels require FP8-GEMM-compatible dimensions. + if module_type in ("linear", "layernorm_linear"): + if flat_first_dim % 8 != 0 or last_dim % 16 != 0: + pytest.skip( + "Linear/LayerNormLinear FP8 execution requires prod(shape[:-1]) divisible by 8 " + "and shape[-1] divisible by 16." + ) + return + + # te_ops.Linear (fusible ops) has stricter constraints for some block-scaled recipes. + if module_type == "ops_linear": + if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): + pytest.skip( + "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." 
+ ) + if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + pytest.skip( + "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." + ) + + +def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None: + # Grouped GEMM paths enforce additional split-alignment constraints for block-scaled recipes. + non_empty_splits = [m for m in m_splits if m > 0] + if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): + pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") + if recipe_name == "fp8_block_scaling" and any(m % 4 != 0 for m in non_empty_splits): + pytest.skip( + "GroupedLinear + Float8BlockScaling requires each non-empty m_split divisible by 4." + ) + + def _run_single_step( module: torch.nn.Module, x: torch.Tensor, @@ -333,6 +376,7 @@ def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads ): reset_rng_states() _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) dtype = torch.bfloat16 in_features = input_shape[-1] @@ -390,8 +434,8 @@ def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads @pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) @pytest.mark.parametrize( "m_splits", - ([32, 32, 32, 32], [64, 0, 32, 32]), - ids=("uniform_splits", "with_empty_split"), + ([32, 32, 32, 32], [64, 0, 32, 32], [1, 31, 0, 96]), + ids=("uniform_splits", "with_empty_split", "small_and_empty_splits"), ) def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_unquantized_grads( recipe_name: str, @@ -400,6 +444,7 @@ def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_un ): if recipe_name == "nvfp4": pytest.skip("NVFP4 not supported for grouped linear") + _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) reset_rng_states() dtype 
= torch.bfloat16 @@ -478,10 +523,12 @@ def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_un ("scale_add", ForwardLinearScaleAdd), ), ) +@pytest.mark.parametrize("m", (1, 32), ids=("m1", "m32")) def test_keep_backward_unquantized_fused_linear_paths( recipe_name: str, fused_pattern: str, expected_fused_op: type, + m: int, ): # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") @@ -490,8 +537,7 @@ def test_keep_backward_unquantized_fused_linear_paths( dtype = torch.bfloat16 in_features = 64 out_features = 64 - m = 32 - + _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) model_keep_bwd_hp = _make_fused_model(fused_pattern, in_features, out_features, dtype) model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) From c915bc3eadba9f6ce35c2171e688fb80aa14b0da Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Tue, 24 Feb 2026 15:10:49 -0800 Subject: [PATCH 44/45] Refactor interface to `NVTE_BACKWARD_MODE=default|unquant|dequant` Signed-off-by: Ziang Li --- qa/L0_pytorch_unittest/test.sh | 2 +- tests/pytorch/test_backward_mode.py | 1446 +++++++++++++++++ .../pytorch/test_keep_backward_unquantized.py | 756 --------- transformer_engine/common/recipe/__init__.py | 129 +- transformer_engine/pytorch/module/base.py | 2 +- .../pytorch/module/grouped_linear.py | 74 +- .../pytorch/module/layernorm_linear.py | 47 +- .../pytorch/module/layernorm_mlp.py | 13 +- transformer_engine/pytorch/module/linear.py | 51 +- .../pytorch/ops/basic/basic_linear.py | 44 +- transformer_engine/pytorch/ops/basic/bias.py | 11 +- .../pytorch/ops/basic/quantize.py | 12 +- .../ops/fused/backward_activation_bias.py | 4 +- .../fused/forward_linear_bias_activation.py | 23 +- .../ops/fused/forward_linear_bias_add.py | 19 +- 
.../ops/fused/forward_linear_scale_add.py | 17 +- .../ops/fused/userbuffers_forward_linear.py | 13 + transformer_engine/pytorch/ops/fuser.py | 14 +- transformer_engine/pytorch/quantization.py | 7 +- 19 files changed, 1749 insertions(+), 935 deletions(-) create mode 100644 tests/pytorch/test_backward_mode.py delete mode 100644 tests/pytorch/test_keep_backward_unquantized.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 7829620608..faf9a1d03d 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -40,7 +40,7 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -NVTE_KEEP_BACKWARD_UNQUANTIZED=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_keep_backward_unquantized.xml $TE_PATH/tests/pytorch/test_keep_backward_unquantized.py || test_fail "test_keep_backward_unquantized.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_mode.xml $TE_PATH/tests/pytorch/test_backward_mode.py || test_fail "test_backward_mode.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml 
$TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" diff --git a/tests/pytorch/test_backward_mode.py b/tests/pytorch/test_backward_mode.py new file mode 100644 index 0000000000..300d860496 --- /dev/null +++ b/tests/pytorch/test_backward_mode.py @@ -0,0 +1,1446 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from __future__ import annotations + +from contextlib import nullcontext +import math +from typing import Optional + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.common import recipe +from transformer_engine.pytorch.cpp_extensions import general_gemm, layernorm_bwd +from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.ops.fused import ( + BackwardActivationBias, + ForwardLinearBiasActivation, + ForwardLinearBiasAdd, + ForwardLinearScaleAdd, + UserbuffersForwardLinear, +) +from transformer_engine.pytorch.quantized_tensor import restore_from_saved + +from utils import quantization_tols, reset_rng_states + + +# -------------------------- +# Mode and capability config +# -------------------------- + +_NON_QUANT_BACKWARD_MODES = ("unquant", "dequant") + +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) +bf16_available, reason_for_no_bf16 = te.is_bf16_available(return_reason=True) + +# Broad dtype coverage for modules touched by this change. 
+_core_dtypes = [torch.float16, torch.float32] +if bf16_available: + _core_dtypes.insert(1, torch.bfloat16) + +# Fused GEMM+bias+activation requires FP16/BF16 output. +_fused_dtypes = [torch.float16] +if bf16_available: + _fused_dtypes.append(torch.bfloat16) + + +@pytest.fixture(autouse=True) +def _reset_global_fp8_state(): + """Avoid global FP8-state leakage between parametrized cases.""" + yield + FP8GlobalStateManager.reset() + + +@pytest.fixture(params=_NON_QUANT_BACKWARD_MODES, ids=lambda mode: f"mode_{mode}") +def backward_mode(request: pytest.FixtureRequest) -> str: + """Backward mode under test.""" + return request.param + + +# -------------------------- +# Shared helpers +# -------------------------- + + +def _assert_exact(test: torch.Tensor, ref: torch.Tensor) -> None: + torch.testing.assert_close(test, ref, rtol=0, atol=0) + + +def _assert_forward_matches_quantized_ref( + test: torch.Tensor, + ref: torch.Tensor, + recipe_name: str, +) -> None: + torch.testing.assert_close(test, ref, **_fprop_tolerances(recipe_name)) + + +def _restore_saved_operands(output: torch.Tensor) -> list[Optional[torch.Tensor]]: + if output.grad_fn is None: + raise RuntimeError("Output tensor has no grad_fn; cannot inspect saved operands") + if not hasattr(output.grad_fn, "tensor_objects"): + raise RuntimeError("grad_fn does not expose tensor_objects for saved operand restoration") + return restore_from_saved(output.grad_fn.tensor_objects, list(output.grad_fn.saved_tensors)) + + +def _extract_linear_saved_operands( + saved_operands: list[Optional[torch.Tensor]], + *, + context: str, +) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + if len(saved_operands) < 2: + raise RuntimeError( + f"Insufficient saved operands for {context} dequant reference " + f"(got {len(saved_operands)}, expected at least 2)." 
+ ) + return saved_operands[0], saved_operands[1] + + +def _dequantize_saved_operand( + saved_operand: Optional[torch.Tensor], + dtype: torch.dtype, +) -> torch.Tensor: + if saved_operand is None: + raise RuntimeError("Expected saved operand but got None") + # In dequant mode we must consume the fprop-saved quantized payload directly. + # If row-wise payload is missing, the tensor was retargeted to a transpose-only + # layout and no longer represents the original fprop operand. + if ( + not isinstance(saved_operand, torch.Tensor) + and hasattr(saved_operand, "_rowwise_data") + and getattr(saved_operand, "_rowwise_data") is None + ): + raise RuntimeError( + "Saved dequant operand lost row-wise fprop payload (likely usage retarget)." + ) + if isinstance(saved_operand, torch.Tensor): + return saved_operand.to(dtype) + if not hasattr(saved_operand, "dequantize"): + raise RuntimeError(f"Unsupported saved operand type: {type(saved_operand)}") + return saved_operand.dequantize(dtype=dtype) + + +def _assert_saved_quantized_operand_uses_rowwise_only( + saved_operand: Optional[torch.Tensor], + *, + name: str, +) -> None: + if saved_operand is None: + raise RuntimeError(f"Expected quantized saved {name} operand but got None") + if isinstance(saved_operand, torch.Tensor): + raise RuntimeError( + f"Dequant reference expects quantized saved {name} operand, got torch.Tensor." + ) + if not hasattr(saved_operand, "dequantize"): + raise RuntimeError(f"Unsupported saved {name} operand type: {type(saved_operand)}") + if hasattr(saved_operand, "_rowwise_data") and getattr(saved_operand, "_rowwise_data") is None: + raise RuntimeError( + f"Saved dequant {name} operand lost row-wise fprop payload (likely usage retarget)." + ) + if ( + hasattr(saved_operand, "_columnwise_data") + and getattr(saved_operand, "_columnwise_data") is not None + ): + raise RuntimeError( + f"Saved dequant {name} operand unexpectedly carries column-wise payload." 
+ ) + + +def _snapshot_saved_quantized_operand_layout( + saved_operand: Optional[torch.Tensor], + *, + name: str, +) -> dict[str, object]: + _assert_saved_quantized_operand_uses_rowwise_only(saved_operand, name=name) + rowwise_present = None + columnwise_present = None + rowwise_obj_id = None + if hasattr(saved_operand, "_rowwise_data"): + rowwise_data = getattr(saved_operand, "_rowwise_data") + rowwise_present = rowwise_data is not None + if rowwise_data is not None: + rowwise_obj_id = id(rowwise_data) + if hasattr(saved_operand, "_columnwise_data"): + columnwise_present = getattr(saved_operand, "_columnwise_data") is not None + return { + "name": name, + "saved_operand": saved_operand, + "rowwise_present": rowwise_present, + "columnwise_present": columnwise_present, + "rowwise_obj_id": rowwise_obj_id, + } + + +def _assert_saved_quantized_operand_layout_unchanged(snapshot: dict[str, object]) -> None: + name = snapshot.get("name") + if not isinstance(name, str): + raise RuntimeError(f"Invalid saved operand snapshot name: {name!r}") + saved_operand = snapshot.get("saved_operand") + _assert_saved_quantized_operand_uses_rowwise_only(saved_operand, name=name) + + rowwise_present = snapshot.get("rowwise_present") + if isinstance(rowwise_present, bool): + rowwise_data_now = getattr(saved_operand, "_rowwise_data", None) + rowwise_now = rowwise_data_now is not None + if rowwise_now != rowwise_present: + raise RuntimeError( + f"Saved dequant {name} operand row-wise payload presence changed " + f"from {rowwise_present} to {rowwise_now}." + ) + # Guard against hidden requantization that swaps in a new row-wise payload. + rowwise_obj_id = snapshot.get("rowwise_obj_id") + if ( + isinstance(rowwise_obj_id, int) + and rowwise_now + and id(rowwise_data_now) != rowwise_obj_id + ): + raise RuntimeError( + f"Saved dequant {name} operand row-wise payload identity changed " + "(likely rewritten/requantized)." 
+ ) + + columnwise_present = snapshot.get("columnwise_present") + if isinstance(columnwise_present, bool): + columnwise_now = getattr(saved_operand, "_columnwise_data", None) is not None + if columnwise_now != columnwise_present: + raise RuntimeError( + f"Saved dequant {name} operand column-wise payload presence changed " + f"from {columnwise_present} to {columnwise_now}." + ) + + +def _snapshot_layout_invariants( + guard_operands: list[tuple[str, Optional[torch.Tensor]]], +) -> list[dict[str, object]]: + """Capture saved-operand layout invariants before backward runs.""" + return [ + _snapshot_saved_quantized_operand_layout(saved_operand, name=name) + for name, saved_operand in guard_operands + ] + + +def _assert_layout_invariants_unchanged(layout_invariants: list[dict[str, object]]) -> None: + """Validate saved-operand layout invariants after backward runs.""" + for layout_invariant in layout_invariants: + _assert_saved_quantized_operand_layout_unchanged(layout_invariant) + + +def _raise_if_ref_failed(ref_exc: Optional[Exception]) -> None: + """Re-raise deferred reference exceptions after layout checks.""" + if ref_exc is not None: + raise ref_exc + + +def _compute_linear_backward_reference_from_saved_operands( + saved_input: Optional[torch.Tensor], + saved_weight: Optional[torch.Tensor], + dy: torch.Tensor, + *, + dequant_dtype: torch.dtype, + out_dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # Dequant reference path: + # 1) use the exact operands saved by quantized forward, + # 2) dequantize them to the active high-precision compute dtype, + # 3) run backward GEMMs in high precision and compare exactly. + for name, saved_operand in (("input", saved_input), ("weight", saved_weight)): + _assert_saved_quantized_operand_uses_rowwise_only(saved_operand, name=name) + dy_mat = dy.reshape(-1, dy.shape[-1]) + + # Empty-token chunks can happen in grouped/fused paths. Reference should be zeros. 
+ if dy_mat.shape[0] == 0: + out_features = dy_mat.shape[-1] + if saved_input is None: + raise RuntimeError("Expected saved input operand for empty-chunk dequant reference.") + in_features = saved_input.size(-1) + dx_ref = torch.zeros(*dy.shape[:-1], in_features, dtype=out_dtype, device=dy.device) + dw_ref = torch.zeros(out_features, in_features, dtype=out_dtype, device=dy.device) + db_ref = torch.zeros(out_features, dtype=out_dtype, device=dy.device) + return dx_ref, dw_ref, db_ref + + x_ref_full = _dequantize_saved_operand(saved_input, dequant_dtype) + x_ref = x_ref_full.reshape(-1, x_ref_full.shape[-1]) + w_ref = _dequantize_saved_operand(saved_weight, dequant_dtype) + + dx_ref_2d, *_ = general_gemm( + w_ref, + dy_mat, + out_dtype=out_dtype, + layout="NN", + grad=True, + ) + # Derive db from the same GEMM primitive used by runtime wgrad. This avoids + # tiny reduction-order drift vs. a standalone dy.sum() path in FP32 cases. + db_seed = torch.empty(dy_mat.shape[-1], dtype=out_dtype, device=dy_mat.device) + dw_ref, db_ref, *_ = general_gemm( + x_ref, + dy_mat, + out_dtype=out_dtype, + layout="NT", + grad=True, + bias=db_seed, + ) + if db_ref is None: + db_ref = dy_mat.sum(dim=0).to(out_dtype) + dx_ref = dx_ref_2d.view(*dy.shape[:-1], dx_ref_2d.shape[-1]) + return dx_ref, dw_ref, db_ref + + +_quantized_numerics_recipe_list = [ + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + id="Float8CurrentScaling", + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + id="MXFP8BlockScaling", + ), + pytest.param( + "fp8_block_scaling", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, + reason=reason_for_no_fp8_block_scaling, + ), + id="Float8BlockScaling", + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4BlockScaling", + ), +] + +_shape_test_cases = [ + pytest.param((1, 
64), 64, id="2d_m1_k64_n64"), + pytest.param((32, 64), 64, id="2d_m32_k64_n64"), + pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), + pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), + pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), +] + +_bias_activation_shape_cases = [ + pytest.param((32, 64), id="2d_m32_k64"), + pytest.param((8, 4, 64), id="3d_m32_k64"), +] + + +def _make_recipe(recipe_name: str, *, backward_mode: str) -> recipe.Recipe: + kwargs = {"backward_mode": backward_mode} + if recipe_name == "fp8_current_scaling": + return recipe.Float8CurrentScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "mxfp8": + return recipe.MXFP8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "fp8_block_scaling": + return recipe.Float8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) + if recipe_name == "nvfp4": + return recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + **kwargs, + ) + raise ValueError(f"Unsupported recipe for backward-mode test: {recipe_name}") + + +def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None: + src_params = dict(src_module.named_parameters()) + with torch.no_grad(): + for name, dst_param in dst_module.named_parameters(): + if name not in src_params: + raise RuntimeError(f"Parameter {name} missing in source module") + dst_param.copy_(src_params[name]) + + +def _fprop_tolerances(recipe_name: str) -> dict[str, float]: + if recipe_name == "mxfp8": + return quantization_tols("mxfp8") + if recipe_name in ("fp8_current_scaling", "fp8_block_scaling"): + return quantization_tols("fp8_current_scaling") + if recipe_name == "nvfp4": + return quantization_tols("nvfp4") + raise ValueError(f"Unsupported recipe for backward-mode test: {recipe_name}") + + +def _maybe_skip_recipe_dtype(recipe_name: str, dtype: torch.dtype, backward_mode: str) -> None: + if dtype == torch.bfloat16 and not 
bf16_available: + pytest.skip(reason_for_no_bf16) + if recipe_name == "nvfp4" and dtype != torch.bfloat16: + pytest.skip("NVFP4 is only supported with BF16 in this test") + + +def _make_linear_like_module( + module_type: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + *, + bias: bool, +) -> torch.nn.Module: + if module_type == "linear": + return te.Linear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "layernorm_linear": + return te.LayerNormLinear( + in_features, + out_features, + bias=bias, + params_dtype=dtype, + device="cuda", + ) + if module_type == "ops_linear": + return te_ops.Linear( + in_features, + out_features, + bias=bias, + dtype=dtype, + device="cuda", + ) + raise ValueError(f"Unsupported module type: {module_type}") + + +def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: + if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": + pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") + + +def _maybe_skip_unsupported_recipe_shape( + recipe_name: str, + input_shape: tuple[int, ...], + module_type: str, +) -> None: + flat_first_dim = math.prod(input_shape[:-1]) + last_dim = input_shape[-1] + + if module_type in ("linear", "layernorm_linear"): + if flat_first_dim % 8 != 0 or last_dim % 16 != 0: + pytest.skip( + "Linear/LayerNormLinear FP8 execution requires prod(shape[:-1]) divisible by 8 " + "and shape[-1] divisible by 16." + ) + return + + if module_type == "ops_linear": + if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): + pytest.skip( + "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." + ) + if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + pytest.skip( + "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." 
+ ) + + +def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None: + non_empty_splits = [m for m in m_splits if m > 0] + if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): + pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") + if recipe_name == "fp8_block_scaling" and any(m % 4 != 0 for m in non_empty_splits): + pytest.skip( + "GroupedLinear + Float8BlockScaling requires each non-empty m_split divisible by 4." + ) + + +def _run_single_step( + module: torch.nn.Module, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + y.backward(dy) + assert x_run.grad is not None + assert module.weight.grad is not None + bgrad = _extract_bias_grad(module) + return ( + y.detach().clone(), + x_run.grad.detach().clone(), + module.weight.grad.detach().clone(), + bgrad, + ) + + +def _run_single_step_with_saved_operands( + module: torch.nn.Module, + x: torch.Tensor, + fp8_recipe: recipe.Recipe, +) -> tuple[ + torch.Tensor, + torch.Tensor, + list[Optional[torch.Tensor]], +]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + with te.autocast(enabled=True, recipe=fp8_recipe): + y = module(x_run) + if isinstance(y, tuple): + y = y[0] + saved_operands = _restore_saved_operands(y) + return y, x_run, saved_operands + + +def _extract_bias_grad(module: torch.nn.Module) -> Optional[torch.Tensor]: + bias = getattr(module, "bias", None) + if bias is None or bias.grad is None: + return None + return bias.grad.detach().clone() + + +def _run_grouped_linear_single_step( + module: 
te.GroupedLinear, + x: torch.Tensor, + m_splits: list[int], + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[Optional[torch.Tensor]]]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = module(x_run, m_splits) + y.backward(dy) + assert x_run.grad is not None + + dw = [getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms)] + db: list[Optional[torch.Tensor]] = [] + for i in range(module.num_gemms): + if module.use_bias: + db.append(getattr(module, f"bias{i}").grad.detach().clone()) + else: + db.append(None) + return y.detach().clone(), x_run.grad.detach().clone(), dw, db + + +def _run_grouped_linear_step_with_saved_operands( + module: te.GroupedLinear, + x: torch.Tensor, + m_splits: list[int], + fp8_recipe: recipe.Recipe, +) -> tuple[ + torch.Tensor, + torch.Tensor, + list[Optional[torch.Tensor]], +]: + module.zero_grad(set_to_none=True) + x_run = x.detach().clone().requires_grad_(True) + with te.autocast(enabled=True, recipe=fp8_recipe): + y = module(x_run, m_splits) + saved_operands = _restore_saved_operands(y) + return y, x_run, saved_operands + + +def _make_fused_model( + pattern: str, + in_features: int, + out_features: int, + dtype: torch.dtype, + *, + scale: float = 0.5, +) -> te_ops.Sequential: + if pattern == "bias_activation": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.ReLU(), + ) + if pattern == "bias_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), + te_ops.AddExtraInput(in_place=True), + ) + if pattern == "scale_add": + return te_ops.Sequential( + te_ops.Linear(in_features, out_features, bias=False, device="cuda", dtype=dtype), + 
te_ops.ConstantScale(scale), + te_ops.AddExtraInput(in_place=True), + ) + raise ValueError(f"Unsupported fused test pattern: {pattern}") + + +def _run_fused_single_step( + pattern: str, + model: te_ops.Sequential, + x1: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], + *, + x2: Optional[torch.Tensor] = None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + torch.Tensor, + Optional[torch.Tensor], +]: + model.zero_grad(set_to_none=True) + x1_run = x1.detach().clone().requires_grad_(True) + x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + if pattern in ("bias_add", "scale_add"): + assert x2_run is not None + y = model(x1_run, x2_run) + else: + y = model(x1_run) + y.backward(dy) + assert x1_run.grad is not None + + dw = model[0].weight.grad.detach().clone() + db = None + if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None: + db = model[0].bias.grad.detach().clone() + dx2 = x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None + return y.detach().clone(), x1_run.grad.detach().clone(), dx2, dw, db + + +def _run_fused_single_step_with_saved_operands( + pattern: str, + model: te_ops.Sequential, + x1: torch.Tensor, + fp8_recipe: recipe.Recipe, + *, + x2: Optional[torch.Tensor] = None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + Optional[torch.Tensor], + list[Optional[torch.Tensor]], +]: + model.zero_grad(set_to_none=True) + x1_run = x1.detach().clone().requires_grad_(True) + x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None + with te.autocast(enabled=True, recipe=fp8_recipe): + if pattern in ("bias_add", "scale_add"): + assert x2_run is not None + y = model(x1_run, x2_run) + else: + y = model(x1_run) + saved_operands = _restore_saved_operands(y) + return y, x1_run, 
x2_run, saved_operands + + +def _run_quantize_op_single_step( + model: te_ops.Sequential, + x: torch.Tensor, + dy: torch.Tensor, + fp8_recipe: Optional[recipe.Recipe], +) -> tuple[torch.Tensor, torch.Tensor]: + x_run = x.detach().clone().requires_grad_(True) + autocast_ctx = ( + te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() + ) + with autocast_ctx: + y = model(x_run) + y.backward(dy) + assert x_run.grad is not None + return y.detach().clone(), x_run.grad.detach().clone() + + +def _make_userbuffers_fuser_for_mode_switch_test( + *, + dtype: torch.dtype, +) -> tuple[object, torch.Tensor, list[tuple[()]]]: + """Build a Userbuffers-eligible fuser and representative inputs.""" + in_features = 64 + out_features = 64 + linear = te_ops.BasicLinear( + in_features, + out_features, + device="cuda", + dtype=dtype, + userbuffers_options={"comm_name": "qkv"}, + ) + linear.tensor_parallel_mode = "column" + linear.tensor_parallel_size = 2 + linear.sequence_parallel = True + bias = te_ops.Bias(out_features, device="cuda", dtype=dtype) + model = te_ops.Sequential(linear, bias) + model._module_groups = model._make_module_groups( + model._modules.values() + ) # pylint: disable=protected-access + fuser = model._module_groups[0] + x = torch.randn(32, in_features, dtype=dtype, device="cuda", requires_grad=True) + extra_inputs = [() for _ in range(fuser._num_basic_ops)] # pylint: disable=protected-access + return fuser, x, extra_inputs + + +def _has_userbuffers_forward_linear(fuser: object) -> bool: + return any( + isinstance(op, UserbuffersForwardLinear) for op, _ in fuser._forward_ops + ) # pylint: disable=protected-access + + +# -------------------------- +# Tests +# -------------------------- + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +def test_backward_mode_recipe_matches_requested_mode( + recipe_name: str, + backward_mode: str, +) -> None: + mode_recipe = _make_recipe(recipe_name, 
backward_mode=backward_mode) + quant_recipe = _make_recipe(recipe_name, backward_mode="default") + assert mode_recipe.backward_mode == backward_mode + assert quant_recipe.backward_mode == "default" + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("module_type", ("linear", "layernorm_linear", "ops_linear")) +@pytest.mark.parametrize("input_shape,out_features", _shape_test_cases) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +@pytest.mark.parametrize("dtype", _core_dtypes, ids=str) +def test_linear_like_backward_mode_matches_reference( + recipe_name: str, + module_type: str, + input_shape: tuple[int, ...], + out_features: int, + use_bias: bool, + dtype: torch.dtype, + backward_mode: str, +) -> None: + reset_rng_states() + _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) + _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) + + in_features = input_shape[-1] + quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default") + mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + + module_quantized_ref = _make_linear_like_module( + module_type, + in_features, + out_features, + dtype, + bias=use_bias, + ) + module_bwd_mode = _make_linear_like_module( + module_type, + in_features, + out_features, + dtype, + bias=use_bias, + ) + _copy_named_parameters(module_quantized_ref, module_bwd_mode) + + output_shape = input_shape[:-1] + (out_features,) + x = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*output_shape, dtype=dtype, device="cuda") + + y_quantized_ref, _, _, _ = _run_single_step(module_quantized_ref, x, dy, quantized_ref_recipe) + if backward_mode == "unquant": + # Unquant reference path: compare against a plain high-precision backward run + # (no fp8/autocast), starting from the same params and inputs. 
+ module_unquantized_ref = _make_linear_like_module( + module_type, + in_features, + out_features, + dtype, + bias=use_bias, + ) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + y_bwd_mode, dx_bwd_mode, dw_bwd_mode, db_bwd_mode = _run_single_step( + module_bwd_mode, + x, + dy, + mode_recipe, + ) + _, dx_ref, dw_ref, db_ref = _run_single_step( + module_unquantized_ref, + x, + dy, + None, + ) + else: + # Dequant reference path: capture saved forward operands from the real dequant-mode + # execution, then rebuild backward reference from those saved operands. + y_bwd_mode, x_bwd_mode, saved_operands = _run_single_step_with_saved_operands( + module_bwd_mode, x, mode_recipe + ) + y_bwd_mode_detached = y_bwd_mode.detach().clone() + + dx_ref: Optional[torch.Tensor] = None + dw_ref: Optional[torch.Tensor] = None + db_ref: Optional[torch.Tensor] = None + layout_invariants: list[dict[str, object]] = [] + guard_operands: list[tuple[str, Optional[torch.Tensor]]] = [] + ref_exc: Optional[Exception] = None + try: + if module_type == "layernorm_linear": + # LayerNormLinear dequant reference: + # 1) Compute d(ln_out), dw, db from linear backward with saved operands. + # 2) Compute exact dx via layernorm_bwd with saved norm statistics. + # _LayerNormLinear forward saves operands as: + # [inputmat, weightmat, origin_weight, bias, ln_weight, ln_out, mu, rsigma, ...] + if len(saved_operands) < 8: + raise RuntimeError( + "Insufficient saved operands for layernorm_linear dequant reference " + f"(got {len(saved_operands)}, expected at least 8)." 
+ ) + saved_input = saved_operands[0] + saved_weight = saved_operands[1] + saved_ln_weight = saved_operands[4] + saved_ln_out = saved_operands[5] + saved_mu = saved_operands[6] + saved_rsigma = saved_operands[7] + guard_operands.extend( + [ + ("layernorm_linear_ln_out", saved_ln_out), + ("layernorm_linear_weight", saved_weight), + ] + ) + d_ln_out_ref, dw_ref, db_ref = ( + _compute_linear_backward_reference_from_saved_operands( + saved_ln_out, + saved_weight, + dy, + dequant_dtype=dtype, + out_dtype=dtype, + ) + ) + input_ref = _dequantize_saved_operand(saved_input, dtype) + input_ref_2d = input_ref.reshape(-1, input_ref.shape[-1]) + ln_weight_ref = _dequantize_saved_operand(saved_ln_weight, dtype).view(-1) + if saved_mu is None or saved_rsigma is None: + raise RuntimeError("Missing LayerNorm statistics in saved operands") + if not isinstance(saved_mu, torch.Tensor) or not isinstance( + saved_rsigma, torch.Tensor + ): + raise RuntimeError("LayerNorm statistics must be Tensor objects") + dx_ref, *_ = layernorm_bwd( + d_ln_out_ref.reshape(input_ref_2d.shape), + input_ref_2d, + saved_mu, + saved_rsigma, + ln_weight_ref, + module_bwd_mode.bwd_ln_sm_margin, + module_bwd_mode.zero_centered_gamma, + ) + dx_ref = dx_ref.view_as(x_bwd_mode) + else: + saved_input, saved_weight = _extract_linear_saved_operands( + saved_operands, + context=f"{module_type}", + ) + guard_operands.extend( + [ + (f"{module_type}_input", saved_input), + (f"{module_type}_weight", saved_weight), + ] + ) + dx_ref, dw_ref, db_ref = _compute_linear_backward_reference_from_saved_operands( + saved_input, + saved_weight, + dy, + dequant_dtype=dtype, + out_dtype=dtype, + ) + if module_type == "ops_linear" and use_bias: + # te_ops bias grad is reduced by the Bias op from incoming dy. 
+ db_ref = dy.reshape(-1, dy.shape[-1]).sum(dim=0).to(dtype) + except Exception as exc: # pylint: disable=broad-exception-caught + ref_exc = exc + + layout_invariants = _snapshot_layout_invariants(guard_operands) + + y_bwd_mode.backward(dy) + assert x_bwd_mode.grad is not None + assert module_bwd_mode.weight.grad is not None + dx_bwd_mode = x_bwd_mode.grad.detach().clone() + dw_bwd_mode = module_bwd_mode.weight.grad.detach().clone() + db_bwd_mode = _extract_bias_grad(module_bwd_mode) + y_bwd_mode = y_bwd_mode_detached + + _assert_layout_invariants_unchanged(layout_invariants) + _raise_if_ref_failed(ref_exc) + assert dx_ref is not None and dw_ref is not None and db_ref is not None + + _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name) + _assert_exact(dx_bwd_mode, dx_ref) + _assert_exact(dw_bwd_mode, dw_ref) + if use_bias: + assert db_bwd_mode is not None + assert db_ref is not None + _assert_exact(db_bwd_mode, db_ref) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) +@pytest.mark.parametrize( + "m_splits", + ([32, 32, 32, 32], [64, 0, 32, 32], [1, 31, 0, 96]), + ids=("uniform_splits", "with_empty_split", "small_and_empty_splits"), +) +@pytest.mark.parametrize("dtype", _core_dtypes, ids=str) +def test_grouped_linear_backward_mode_matches_reference( + recipe_name: str, + use_bias: bool, + m_splits: list[int], + dtype: torch.dtype, + backward_mode: str, +) -> None: + if recipe_name == "nvfp4": + pytest.skip("NVFP4 not supported for grouped linear") + + reset_rng_states() + _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) + + in_features = 64 + out_features = 64 + num_gemms = len(m_splits) + num_tokens = sum(m_splits) + + quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default") + mode_recipe = _make_recipe(recipe_name, 
backward_mode=backward_mode) + + module_quantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + module_bwd_mode = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + _copy_named_parameters(module_quantized_ref, module_bwd_mode) + + x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") + dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") + + y_quantized_ref, _, _, _ = _run_grouped_linear_single_step( + module_quantized_ref, + x, + m_splits, + dy, + quantized_ref_recipe, + ) + if backward_mode == "unquant": + # Unquant reference path: grouped module in plain high precision. + module_unquantized_ref = te.GroupedLinear( + num_gemms, + in_features, + out_features, + bias=use_bias, + params_dtype=dtype, + device="cuda", + ) + _copy_named_parameters(module_quantized_ref, module_unquantized_ref) + y_bwd_mode, dx_bwd_mode, dw_bwd_mode, db_bwd_mode = _run_grouped_linear_single_step( + module_bwd_mode, + x, + m_splits, + dy, + mode_recipe, + ) + _, dx_ref, dw_ref, db_ref = _run_grouped_linear_single_step( + module_unquantized_ref, + x, + m_splits, + dy, + None, + ) + else: + # Dequant reference path for grouped GEMMs: + # each GEMM restores its own saved input/weight pair and computes its own ref grads. 
+ y_bwd_mode, x_bwd_mode, saved_operands = _run_grouped_linear_step_with_saved_operands( + module_bwd_mode, x, m_splits, mode_recipe + ) + y_bwd_mode_detached = y_bwd_mode.detach().clone() + + dx_ref: Optional[torch.Tensor] = None + dw_ref: list[torch.Tensor] = [] + db_ref: list[Optional[torch.Tensor]] = [] + layout_invariants: list[dict[str, object]] = [] + guard_operands: list[tuple[str, Optional[torch.Tensor]]] = [] + ref_exc: Optional[Exception] = None + try: + if len(saved_operands) < 2 * num_gemms: + raise RuntimeError( + "Insufficient saved operands for GroupedLinear dequant reference " + f"(got {len(saved_operands)}, expected at least {2 * num_gemms})." + ) + + saved_inputs = saved_operands[:num_gemms] + saved_weights = saved_operands[num_gemms : 2 * num_gemms] + for i, (saved_input, saved_weight) in enumerate(zip(saved_inputs, saved_weights)): + guard_operands.extend( + [ + (f"grouped_input{i}", saved_input), + (f"grouped_weight{i}", saved_weight), + ] + ) + dy_chunks = torch.split(dy, m_splits) + + dx_chunks = [] + dw_ref = [] + db_ref = [] + for dy_chunk, saved_input, saved_weight in zip(dy_chunks, saved_inputs, saved_weights): + dx_i, dw_i, db_i = _compute_linear_backward_reference_from_saved_operands( + saved_input, + saved_weight, + dy_chunk, + dequant_dtype=dtype, + out_dtype=dtype, + ) + dx_chunks.append(dx_i) + dw_ref.append(dw_i) + db_ref.append(db_i if use_bias else None) + dx_ref = torch.cat(dx_chunks, dim=0) + except Exception as exc: # pylint: disable=broad-exception-caught + ref_exc = exc + + layout_invariants = _snapshot_layout_invariants(guard_operands) + + y_bwd_mode.backward(dy) + assert x_bwd_mode.grad is not None + dx_bwd_mode = x_bwd_mode.grad.detach().clone() + dw_bwd_mode = [ + getattr(module_bwd_mode, f"weight{i}").grad.detach().clone() + for i in range(module_bwd_mode.num_gemms) + ] + db_bwd_mode = [] + for i in range(module_bwd_mode.num_gemms): + if module_bwd_mode.use_bias: + db_bwd_mode.append(getattr(module_bwd_mode, 
f"bias{i}").grad.detach().clone()) + else: + db_bwd_mode.append(None) + y_bwd_mode = y_bwd_mode_detached + + _assert_layout_invariants_unchanged(layout_invariants) + _raise_if_ref_failed(ref_exc) + assert dx_ref is not None + + _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name) + _assert_exact(dx_bwd_mode, dx_ref) + for test_dw, ref_dw in zip(dw_bwd_mode, dw_ref): + _assert_exact(test_dw, ref_dw) + if use_bias: + for test_db, ref_db_i in zip(db_bwd_mode, db_ref): + assert test_db is not None + assert ref_db_i is not None + _assert_exact(test_db, ref_db_i) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize( + "fused_pattern,expected_fused_op", + ( + ("bias_add", ForwardLinearBiasAdd), + ("scale_add", ForwardLinearScaleAdd), + ), +) +@pytest.mark.parametrize("m", (1, 32), ids=("m1", "m32")) +@pytest.mark.parametrize("dtype", _fused_dtypes, ids=str) +def test_fused_linear_paths_match_backward_mode_reference( + recipe_name: str, + fused_pattern: str, + expected_fused_op: type, + m: int, + dtype: torch.dtype, + backward_mode: str, +) -> None: + _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + _maybe_skip_unsupported_recipe_shape(recipe_name, (m, 64), "ops_linear") + + reset_rng_states() + in_features = 64 + out_features = 64 + + quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default") + mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + + model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + model_bwd_mode = _make_fused_model(fused_pattern, in_features, out_features, dtype) + _copy_named_parameters(model_quantized_ref, model_bwd_mode) + + x1 = torch.randn(m, in_features, dtype=dtype, device="cuda") + x2 = None + if fused_pattern in ("bias_add", "scale_add"): + x2 = torch.randn(m, out_features, dtype=dtype, device="cuda") + dy = 
torch.randn(m, out_features, dtype=dtype, device="cuda") + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + fused_pattern, + model_quantized_ref, + x1, + dy, + quantized_ref_recipe, + x2=x2, + ) + + if backward_mode == "unquant": + # Unquant reference path: replay the same fused model structure in plain + # high precision and compare backward outputs exactly. + model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) + _copy_named_parameters(model_quantized_ref, model_unquantized_ref) + + y_bwd_mode, dx1_bwd_mode, dx2_bwd_mode, dw_bwd_mode, db_bwd_mode = _run_fused_single_step( + fused_pattern, + model_bwd_mode, + x1, + dy, + mode_recipe, + x2=x2, + ) + _, dx1_ref, dx2_ref, dw_ref, db_ref = _run_fused_single_step( + fused_pattern, + model_unquantized_ref, + x1, + dy, + None, + x2=x2, + ) + else: + # Dequant reference path: compute backward reference from saved quantized + # linear operands (with branch-specific dy handling for fused epilogues). 
+ y_bwd_mode, x1_bwd_mode, x2_bwd_mode_ref, saved_operands = ( + _run_fused_single_step_with_saved_operands( + fused_pattern, + model_bwd_mode, + x1, + mode_recipe, + x2=x2, + ) + ) + y_bwd_mode_detached = y_bwd_mode.detach().clone() + dx1_ref: Optional[torch.Tensor] = None + dx2_ref: Optional[torch.Tensor] = None + dw_ref: Optional[torch.Tensor] = None + db_ref: Optional[torch.Tensor] = None + layout_invariants: list[dict[str, object]] = [] + guard_operands: list[tuple[str, Optional[torch.Tensor]]] = [] + ref_exc: Optional[Exception] = None + try: + saved_input, saved_weight = _extract_linear_saved_operands( + saved_operands, + context=f"fused_{fused_pattern}", + ) + guard_operands.extend( + [ + (f"fused_{fused_pattern}_input", saved_input), + (f"fused_{fused_pattern}_weight", saved_weight), + ] + ) + dy_for_linear = dy * 0.5 if fused_pattern == "scale_add" else dy + dx1_ref, dw_ref, db_ref = _compute_linear_backward_reference_from_saved_operands( + saved_input, + saved_weight, + dy_for_linear, + dequant_dtype=dtype, + out_dtype=dtype, + ) + dx2_ref = dy if x2 is not None else None + except Exception as exc: # pylint: disable=broad-exception-caught + ref_exc = exc + + layout_invariants = _snapshot_layout_invariants(guard_operands) + + y_bwd_mode.backward(dy) + assert x1_bwd_mode.grad is not None + dx1_bwd_mode = x1_bwd_mode.grad.detach().clone() + dx2_bwd_mode = ( + x2_bwd_mode_ref.grad.detach().clone() + if x2_bwd_mode_ref is not None and x2_bwd_mode_ref.grad is not None + else None + ) + dw_bwd_mode = model_bwd_mode[0].weight.grad.detach().clone() + db_bwd_mode = None + if ( + getattr(model_bwd_mode[0], "bias", None) is not None + and model_bwd_mode[0].bias.grad is not None + ): + db_bwd_mode = model_bwd_mode[0].bias.grad.detach().clone() + y_bwd_mode = y_bwd_mode_detached + + _assert_layout_invariants_unchanged(layout_invariants) + _raise_if_ref_failed(ref_exc) + assert dx1_ref is not None and dw_ref is not None + + fused_ops = 
model_bwd_mode._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 + assert isinstance(fused_ops[0][0], expected_fused_op) + + _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name) + _assert_exact(dx1_bwd_mode, dx1_ref) + _assert_exact(dw_bwd_mode, dw_ref) + if dx2_bwd_mode is not None and dx2_ref is not None: + _assert_exact(dx2_bwd_mode, dx2_ref) + if db_bwd_mode is not None and db_ref is not None: + _assert_exact(db_bwd_mode, db_ref) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("input_shape", _bias_activation_shape_cases) +@pytest.mark.parametrize("dtype", _fused_dtypes, ids=str) +def test_fused_bias_activation_matches_masked_linear_backward( + recipe_name: str, + input_shape: tuple[int, ...], + dtype: torch.dtype, + backward_mode: str, +) -> None: + _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + + reset_rng_states() + in_features = input_shape[-1] + out_features = 64 + + quantized_ref_recipe = _make_recipe(recipe_name, backward_mode="default") + mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + + model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, dtype) + model_bwd_mode = _make_fused_model("bias_activation", in_features, out_features, dtype) + _copy_named_parameters(model_quantized_ref, model_bwd_mode) + + x1 = torch.randn(*input_shape, dtype=dtype, device="cuda") + dy = torch.randn(*((*x1.shape[:-1], out_features)), dtype=dtype, device="cuda") + + y_quantized_ref, _, _, _, _ = _run_fused_single_step( + "bias_activation", + model_quantized_ref, + x1, + dy, + quantized_ref_recipe, + ) + + if backward_mode == "unquant": + # Unquant reference path: build a plain linear reference and apply the + # same activation mask (from quantized forward output) before backward. 
+ linear_unquantized_ref = _make_linear_like_module( + "ops_linear", + in_features, + out_features, + dtype, + bias=True, + ) + _copy_named_parameters(model_bwd_mode[0], linear_unquantized_ref) + + y_bwd_mode, dx1_bwd_mode, _, dw_bwd_mode, db_bwd_mode = _run_fused_single_step( + "bias_activation", + model_bwd_mode, + x1, + dy, + mode_recipe, + ) + dy_after_activation = dy * (y_bwd_mode > 0).to(dy.dtype) + _, dx1_ref, dw_ref, db_ref = _run_single_step( + linear_unquantized_ref, + x1, + dy_after_activation, + None, + ) + else: + # Dequant reference path: restore saved linear operands from fused forward, + # apply the same activation mask, then run linear backward reference. + y_bwd_mode, x1_bwd_mode, _, saved_operands = _run_fused_single_step_with_saved_operands( + "bias_activation", + model_bwd_mode, + x1, + mode_recipe, + ) + y_bwd_mode_detached = y_bwd_mode.detach().clone() + dy_after_activation = dy * (y_bwd_mode > 0).to(dy.dtype) + dx1_ref: Optional[torch.Tensor] = None + dw_ref: Optional[torch.Tensor] = None + db_ref: Optional[torch.Tensor] = None + layout_invariants: list[dict[str, object]] = [] + guard_operands: list[tuple[str, Optional[torch.Tensor]]] = [] + ref_exc: Optional[Exception] = None + try: + saved_input, saved_weight = _extract_linear_saved_operands( + saved_operands, + context="fused_bias_activation", + ) + guard_operands.extend( + [ + ("fused_bias_activation_input", saved_input), + ("fused_bias_activation_weight", saved_weight), + ] + ) + dx1_ref, dw_ref, db_ref = _compute_linear_backward_reference_from_saved_operands( + saved_input, + saved_weight, + dy_after_activation, + dequant_dtype=dtype, + out_dtype=dtype, + ) + except Exception as exc: # pylint: disable=broad-exception-caught + ref_exc = exc + + layout_invariants = _snapshot_layout_invariants(guard_operands) + + y_bwd_mode.backward(dy) + assert x1_bwd_mode.grad is not None + dx1_bwd_mode = x1_bwd_mode.grad.detach().clone() + dw_bwd_mode = model_bwd_mode[0].weight.grad.detach().clone() + 
db_bwd_mode = ( + model_bwd_mode[0].bias.grad.detach().clone() + if model_bwd_mode[0].bias.grad is not None + else None + ) + y_bwd_mode = y_bwd_mode_detached + + _assert_layout_invariants_unchanged(layout_invariants) + _raise_if_ref_failed(ref_exc) + assert dx1_ref is not None and dw_ref is not None and db_ref is not None + + fused_ops = model_bwd_mode._module_groups[0]._forward_ops + assert len(fused_ops) >= 1 + assert isinstance(fused_ops[0][0], ForwardLinearBiasActivation) + + # In unquant/dequant modes, backward-activation+bias fusion should be disabled. + bwd_mode_backward_ops = model_bwd_mode._module_groups[0]._backward_ops + assert not any(isinstance(op, BackwardActivationBias) for op, _ in bwd_mode_backward_ops) + + # Quantized reference should still use fused backward path. + quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops + assert any(isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops) + + _assert_forward_matches_quantized_ref(y_bwd_mode, y_quantized_ref, recipe_name) + _assert_exact(dx1_bwd_mode, dx1_ref) + _assert_exact(dw_bwd_mode, dw_ref) + assert db_bwd_mode is not None + assert db_ref is not None + _assert_exact(db_bwd_mode, db_ref) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("dtype", _core_dtypes, ids=str) +def test_operation_fuser_rebuilds_userbuffers_fusion_on_backward_mode_switch( + recipe_name: str, + dtype: torch.dtype, + backward_mode: str, + monkeypatch: pytest.MonkeyPatch, +) -> None: + # Simulate a distributed setup to exercise Userbuffers fusion eligibility + # without launching a multi-rank job. 
+ monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True) + monkeypatch.setattr(torch.distributed, "get_world_size", lambda *_args, **_kwargs: 2) + + # Use a mutable recipe holder so we can switch fusion behavior on the same + # fuser object and verify that the cached fusion plan is refreshed. + current_recipe = {"value": _make_recipe(recipe_name, backward_mode="default")} + monkeypatch.setattr(FP8GlobalStateManager, "get_fp8_recipe", lambda: current_recipe["value"]) + + reset_rng_states() + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + fuser, x, extra_inputs = _make_userbuffers_fuser_for_mode_switch_test(dtype=dtype) + + quant_recipe = _make_recipe(recipe_name, backward_mode="default") + fuser.maybe_fuse_ops( + is_grad_enabled=True, + recipe=quant_recipe, + input_=x, + extra_inputs=extra_inputs, + ) + assert _has_userbuffers_forward_linear(fuser) + + non_quant_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + current_recipe["value"] = non_quant_recipe + fuser.maybe_fuse_ops( + is_grad_enabled=True, + recipe=non_quant_recipe, + input_=x, + extra_inputs=extra_inputs, + ) + assert not _has_userbuffers_forward_linear(fuser) + + +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("dtype", _core_dtypes, ids=str) +def test_quantize_op_respects_backward_mode( + recipe_name: str, + dtype: torch.dtype, + backward_mode: str, +) -> None: + _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") + reset_rng_states() + + x = torch.randn(32, 64, dtype=dtype, device="cuda") + dy = torch.randn(32, 64, dtype=dtype, device="cuda") + + model_override = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) + model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=False)) + + mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + + y_override, dx_override = 
_run_quantize_op_single_step(model_override, x, dy, mode_recipe) + y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, mode_recipe) + + _assert_exact(y_override, y_ref) + _assert_exact(dx_override, dx_ref) + + +def test_delayed_scaling_rejects_non_quant_backward_mode(backward_mode: str) -> None: + with pytest.raises( + (AssertionError, ValueError), + match="Delayed scaling only supports backward_mode=default", + ): + _ = recipe.DelayedScaling(backward_mode=backward_mode) + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.parametrize("recipe_name", _quantized_numerics_recipe_list) +@pytest.mark.parametrize("dtype", _fused_dtypes, ids=str) +def test_layernorm_mlp_not_implemented_for_unquantized_backward_mode( + recipe_name: str, + dtype: torch.dtype, + backward_mode: str, +) -> None: + _maybe_skip_recipe_dtype(recipe_name, dtype, backward_mode) + reset_rng_states() + + layer = te.LayerNormMLP( + hidden_size=64, + ffn_hidden_size=64, + params_dtype=dtype, + bias=False, + device="cuda", + ) + x = torch.randn(32, 64, dtype=dtype, device="cuda", requires_grad=True) + mode_recipe = _make_recipe(recipe_name, backward_mode=backward_mode) + + with pytest.raises( + AssertionError, + match="NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP", + ): + with te.autocast(enabled=True, recipe=mode_recipe): + _ = layer(x) diff --git a/tests/pytorch/test_keep_backward_unquantized.py b/tests/pytorch/test_keep_backward_unquantized.py deleted file mode 100644 index f5c3339a71..0000000000 --- a/tests/pytorch/test_keep_backward_unquantized.py +++ /dev/null @@ -1,756 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -from __future__ import annotations - -from contextlib import nullcontext -import math -import os -from typing import Optional - -import pytest -import torch - -import transformer_engine.pytorch as te -import transformer_engine.pytorch.ops as te_ops -from transformer_engine.common import recipe -from transformer_engine.pytorch.ops.fused import ( - BackwardActivationBias, - ForwardLinearBiasActivation, - ForwardLinearBiasAdd, - ForwardLinearScaleAdd, -) - -from utils import quantization_tols, reset_rng_states - - -fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) -mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( - return_reason=True -) -nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) - -# This file is intended to run in dedicated keep-backward-unquantized mode. -pytestmark = pytest.mark.skipif( - os.environ.get("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") != "1", - reason="Requires NVTE_KEEP_BACKWARD_UNQUANTIZED=1", -) - - -_quantized_numerics_recipe_list = [ - pytest.param( - "fp8_current_scaling", - marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), - id="Float8CurrentScaling", - ), - pytest.param( - "mxfp8", - marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), - id="MXFP8BlockScaling", - ), - pytest.param( - "fp8_block_scaling", - marks=pytest.mark.skipif( - not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling - ), - id="Float8BlockScaling", - ), - pytest.param( - "nvfp4", - marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), - id="NVFP4BlockScaling", - ), -] - -_shape_test_cases = [ - pytest.param((1, 64), 64, id="2d_m1_k64_n64"), - pytest.param((32, 64), 64, id="2d_m32_k64_n64"), - pytest.param((32, 1, 64), 64, id="3d_m32_s1_k64_n64"), - pytest.param((8, 4, 64), 128, id="3d_m32_k64_n128"), - 
pytest.param((16, 2, 128), 64, id="3d_m32_k128_n64"), -] - -_bias_activation_shape_cases = [ - pytest.param((32, 64), id="2d_m32_k64"), - pytest.param((8, 4, 64), id="3d_m32_k64"), -] - - -def _make_recipe(recipe_name: str, quantize_backward: Optional[bool]) -> recipe.Recipe: - kwargs = {} - if quantize_backward is not None: - kwargs = {"quantize_forward": True, "quantize_backward": quantize_backward} - - if recipe_name == "fp8_current_scaling": - return recipe.Float8CurrentScaling(fp8_format=recipe.Format.E4M3, **kwargs) - if recipe_name == "mxfp8": - return recipe.MXFP8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) - if recipe_name == "fp8_block_scaling": - return recipe.Float8BlockScaling(fp8_format=recipe.Format.E4M3, **kwargs) - if recipe_name == "nvfp4": - return recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - **kwargs, - ) - - raise ValueError(f"Unsupported recipe for keep-backward-unquantized test: {recipe_name}") - - -def _build_keep_backward_unquantized_recipe(recipe_name: str) -> recipe.Recipe: - fp8_recipe = _make_recipe(recipe_name, quantize_backward=None) - assert fp8_recipe.quantize_forward - assert not fp8_recipe.quantize_backward - return fp8_recipe - - -def _build_quantized_reference_recipe(recipe_name: str) -> recipe.Recipe: - return _make_recipe(recipe_name, quantize_backward=True) - - -def _copy_named_parameters(src_module: torch.nn.Module, dst_module: torch.nn.Module) -> None: - src_params = dict(src_module.named_parameters()) - with torch.no_grad(): - for name, dst_param in dst_module.named_parameters(): - if name not in src_params: - raise RuntimeError(f"Parameter {name} missing in source module") - dst_param.copy_(src_params[name]) - - -def _fprop_tolerances(recipe_name: str) -> dict[str, float]: - if recipe_name == "mxfp8": - return quantization_tols("mxfp8") - if recipe_name in ("fp8_current_scaling", "fp8_block_scaling"): - return 
quantization_tols("fp8_current_scaling") - if recipe_name == "nvfp4": - return quantization_tols("nvfp4") - raise ValueError(f"Unsupported recipe for keep-backward-unquantized test: {recipe_name}") - - -def _make_linear_like_module( - module_type: str, - in_features: int, - out_features: int, - dtype: torch.dtype, - bias: bool = False, -) -> torch.nn.Module: - if module_type == "linear": - return te.Linear( - in_features, - out_features, - bias=bias, - params_dtype=dtype, - device="cuda", - ) - if module_type == "layernorm_linear": - return te.LayerNormLinear( - in_features, - out_features, - bias=bias, - params_dtype=dtype, - device="cuda", - ) - if module_type == "ops_linear": - return te_ops.Linear( - in_features, - out_features, - bias=bias, - dtype=dtype, - device="cuda", - ) - raise ValueError(f"Unsupported module type: {module_type}") - - -def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: - if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": - pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") - - -def _maybe_skip_unsupported_recipe_shape( - recipe_name: str, - input_shape: tuple[int, ...], - module_type: str, -) -> None: - flat_first_dim = math.prod(input_shape[:-1]) - last_dim = input_shape[-1] - - # TE Linear / LayerNormLinear FP8 kernels require FP8-GEMM-compatible dimensions. - if module_type in ("linear", "layernorm_linear"): - if flat_first_dim % 8 != 0 or last_dim % 16 != 0: - pytest.skip( - "Linear/LayerNormLinear FP8 execution requires prod(shape[:-1]) divisible by 8 " - "and shape[-1] divisible by 16." - ) - return - - # te_ops.Linear (fusible ops) has stricter constraints for some block-scaled recipes. - if module_type == "ops_linear": - if recipe_name == "mxfp8" and (flat_first_dim % 32 != 0 or last_dim % 32 != 0): - pytest.skip( - "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." 
- ) - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): - pytest.skip( - "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." - ) - - -def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int]) -> None: - # Grouped GEMM paths enforce additional split-alignment constraints for block-scaled recipes. - non_empty_splits = [m for m in m_splits if m > 0] - if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): - pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name == "fp8_block_scaling" and any(m % 4 != 0 for m in non_empty_splits): - pytest.skip( - "GroupedLinear + Float8BlockScaling requires each non-empty m_split divisible by 4." - ) - - -def _run_single_step( - module: torch.nn.Module, - x: torch.Tensor, - dy: torch.Tensor, - fp8_recipe: Optional[recipe.Recipe], -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - module.zero_grad(set_to_none=True) - x_run = x.detach().clone().requires_grad_(True) - autocast_ctx = ( - te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() - ) - with autocast_ctx: - y = module(x_run) - if isinstance(y, tuple): - y = y[0] - y.backward(dy) - assert x_run.grad is not None - assert module.weight.grad is not None - return ( - y.detach().clone(), - x_run.grad.detach().clone(), - module.weight.grad.detach().clone(), - ) - - -def _extract_bias_grad(module: torch.nn.Module) -> Optional[torch.Tensor]: - bias = getattr(module, "bias", None) - if bias is None or bias.grad is None: - return None - return bias.grad.detach().clone() - - -def _run_grouped_linear_single_step( - module: te.GroupedLinear, - x: torch.Tensor, - m_splits: list[int], - dy: torch.Tensor, - fp8_recipe: Optional[recipe.Recipe], -) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor], list[Optional[torch.Tensor]]]: - module.zero_grad(set_to_none=True) - x_run = 
x.detach().clone().requires_grad_(True) - autocast_ctx = ( - te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() - ) - with autocast_ctx: - y = module(x_run, m_splits) - y.backward(dy) - assert x_run.grad is not None - weight_grads = [ - getattr(module, f"weight{i}").grad.detach().clone() for i in range(module.num_gemms) - ] - bias_grads: list[Optional[torch.Tensor]] = [] - for i in range(module.num_gemms): - if module.use_bias: - bias_grads.append(getattr(module, f"bias{i}").grad.detach().clone()) - else: - bias_grads.append(None) - return y.detach().clone(), x_run.grad.detach().clone(), weight_grads, bias_grads - - -def _make_fused_model( - pattern: str, - in_features: int, - out_features: int, - dtype: torch.dtype, - scale: float = 0.5, -) -> te_ops.Sequential: - if pattern == "bias_activation": - return te_ops.Sequential( - te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), - te_ops.ReLU(), - ) - if pattern == "bias_add": - return te_ops.Sequential( - te_ops.Linear(in_features, out_features, bias=True, device="cuda", dtype=dtype), - te_ops.AddExtraInput(in_place=True), - ) - if pattern == "scale_add": - return te_ops.Sequential( - te_ops.Linear(in_features, out_features, bias=False, device="cuda", dtype=dtype), - te_ops.ConstantScale(scale), - te_ops.AddExtraInput(in_place=True), - ) - raise ValueError(f"Unsupported fused test pattern: {pattern}") - - -def _run_fused_single_step( - pattern: str, - model: te_ops.Sequential, - x1: torch.Tensor, - dy: torch.Tensor, - fp8_recipe: Optional[recipe.Recipe], - x2: Optional[torch.Tensor] = None, -) -> tuple[ - torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor, Optional[torch.Tensor] -]: - model.zero_grad(set_to_none=True) - x1_run = x1.detach().clone().requires_grad_(True) - x2_run = x2.detach().clone().requires_grad_(True) if x2 is not None else None - autocast_ctx = ( - te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not 
None else nullcontext() - ) - with autocast_ctx: - if pattern in ("bias_add", "scale_add"): - assert x2_run is not None - y = model(x1_run, x2_run) - else: - y = model(x1_run) - y.backward(dy) - assert x1_run.grad is not None - weight_grad = model[0].weight.grad.detach().clone() - bias_grad = None - if getattr(model[0], "bias", None) is not None and model[0].bias.grad is not None: - bias_grad = model[0].bias.grad.detach().clone() - x2_grad = ( - x2_run.grad.detach().clone() if x2_run is not None and x2_run.grad is not None else None - ) - return y.detach().clone(), x1_run.grad.detach().clone(), x2_grad, weight_grad, bias_grad - - -def _run_quantize_op_single_step( - model: te_ops.Sequential, - x: torch.Tensor, - dy: torch.Tensor, - fp8_recipe: Optional[recipe.Recipe], -) -> tuple[torch.Tensor, torch.Tensor]: - x_run = x.detach().clone().requires_grad_(True) - autocast_ctx = ( - te.autocast(enabled=True, recipe=fp8_recipe) if fp8_recipe is not None else nullcontext() - ) - with autocast_ctx: - y = model(x_run) - y.backward(dy) - assert x_run.grad is not None - return y.detach().clone(), x_run.grad.detach().clone() - - -@pytest.mark.parametrize( - "recipe_name", - _quantized_numerics_recipe_list, -) -def test_keep_backward_unquantized_recipe_defaults(recipe_name: str): - _ = _build_keep_backward_unquantized_recipe(recipe_name) - - -@pytest.mark.parametrize( - "recipe_name", - _quantized_numerics_recipe_list, -) -@pytest.mark.parametrize( - "module_type", - ("linear", "layernorm_linear", "ops_linear"), -) -@pytest.mark.parametrize( - "input_shape,out_features", - _shape_test_cases, -) -@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) -def test_keep_backward_unquantized_matches_quantized_fprop_and_unquantized_grads( - recipe_name: str, - module_type: str, - input_shape: tuple[int, ...], - out_features: int, - use_bias: bool, -): - reset_rng_states() - _maybe_skip_unsupported_recipe_module_combo(recipe_name, module_type) - 
_maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) - dtype = torch.bfloat16 - in_features = input_shape[-1] - - module_quantized_ref = _make_linear_like_module( - module_type, in_features, out_features, dtype, bias=use_bias - ) - module_keep_bwd_hp = _make_linear_like_module( - module_type, in_features, out_features, dtype, bias=use_bias - ) - module_unquantized_ref = _make_linear_like_module( - module_type, in_features, out_features, dtype, bias=use_bias - ) - - # Start all runs from identical parameters. - _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) - _copy_named_parameters(module_quantized_ref, module_unquantized_ref) - - output_shape = input_shape[:-1] + (out_features,) - x = torch.randn(*input_shape, dtype=dtype, device="cuda") - dy = torch.randn(*output_shape, dtype=dtype, device="cuda") - - quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) - keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) - - y_quantized_ref, _, _ = _run_single_step(module_quantized_ref, x, dy, quantized_ref_recipe) - y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp = _run_single_step( - module_keep_bwd_hp, x, dy, keep_bwd_hp_recipe - ) - _, dx_unquantized_ref, dw_unquantized_ref = _run_single_step( - module_unquantized_ref, x, dy, None - ) - - # Forward pass should still match quantized reference when only backward is unquantized. - torch.testing.assert_close( - y_keep_bwd_hp, - y_quantized_ref, - **_fprop_tolerances(recipe_name), - ) - - # Backward pass should match unquantized reference for dgrad and wgrad. 
- torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) - torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) - if use_bias: - bgrad_keep = _extract_bias_grad(module_keep_bwd_hp) - bgrad_unquantized = _extract_bias_grad(module_unquantized_ref) - assert bgrad_keep is not None - assert bgrad_unquantized is not None - torch.testing.assert_close(bgrad_keep, bgrad_unquantized, rtol=0, atol=0) - - -@pytest.mark.parametrize( - "recipe_name", - _quantized_numerics_recipe_list, -) -@pytest.mark.parametrize("use_bias", (False, True), ids=("no_bias", "bias")) -@pytest.mark.parametrize( - "m_splits", - ([32, 32, 32, 32], [64, 0, 32, 32], [1, 31, 0, 96]), - ids=("uniform_splits", "with_empty_split", "small_and_empty_splits"), -) -def test_keep_backward_unquantized_grouped_linear_matches_quantized_fprop_and_unquantized_grads( - recipe_name: str, - use_bias: bool, - m_splits: list[int], -): - if recipe_name == "nvfp4": - pytest.skip("NVFP4 not supported for grouped linear") - _maybe_skip_unsupported_grouped_splits(recipe_name, m_splits) - - reset_rng_states() - dtype = torch.bfloat16 - in_features = 64 - out_features = 64 - num_gemms = len(m_splits) - num_tokens = sum(m_splits) - - module_quantized_ref = te.GroupedLinear( - num_gemms, - in_features, - out_features, - bias=use_bias, - params_dtype=dtype, - device="cuda", - ) - module_keep_bwd_hp = te.GroupedLinear( - num_gemms, - in_features, - out_features, - bias=use_bias, - params_dtype=dtype, - device="cuda", - ) - module_unquantized_ref = te.GroupedLinear( - num_gemms, - in_features, - out_features, - bias=use_bias, - params_dtype=dtype, - device="cuda", - ) - - _copy_named_parameters(module_quantized_ref, module_keep_bwd_hp) - _copy_named_parameters(module_quantized_ref, module_unquantized_ref) - - x = torch.randn(num_tokens, in_features, dtype=dtype, device="cuda") - dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") - - quantized_ref_recipe = 
_build_quantized_reference_recipe(recipe_name) - keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) - - y_quantized_ref, _, _, _ = _run_grouped_linear_single_step( - module_quantized_ref, x, m_splits, dy, quantized_ref_recipe - ) - y_keep_bwd_hp, dx_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = _run_grouped_linear_single_step( - module_keep_bwd_hp, x, m_splits, dy, keep_bwd_hp_recipe - ) - _, dx_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = _run_grouped_linear_single_step( - module_unquantized_ref, x, m_splits, dy, None - ) - - torch.testing.assert_close( - y_keep_bwd_hp, - y_quantized_ref, - **_fprop_tolerances(recipe_name), - ) - torch.testing.assert_close(dx_keep_bwd_hp, dx_unquantized_ref, rtol=0, atol=0) - for test_dw, ref_dw in zip(dw_keep_bwd_hp, dw_unquantized_ref): - torch.testing.assert_close(test_dw, ref_dw, rtol=0, atol=0) - if use_bias: - for test_db, ref_db in zip(db_keep_bwd_hp, db_unquantized_ref): - assert test_db is not None - assert ref_db is not None - torch.testing.assert_close(test_db, ref_db, rtol=0, atol=0) - - -@pytest.mark.parametrize( - "recipe_name", - _quantized_numerics_recipe_list, -) -@pytest.mark.parametrize( - "fused_pattern,expected_fused_op", - ( - ("bias_add", ForwardLinearBiasAdd), - ("scale_add", ForwardLinearScaleAdd), - ), -) -@pytest.mark.parametrize("m", (1, 32), ids=("m1", "m32")) -def test_keep_backward_unquantized_fused_linear_paths( - recipe_name: str, - fused_pattern: str, - expected_fused_op: type, - m: int, -): - # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. 
- _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") - - reset_rng_states() - dtype = torch.bfloat16 - in_features = 64 - out_features = 64 - _maybe_skip_unsupported_recipe_shape(recipe_name, (m, in_features), "ops_linear") - model_quantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) - model_keep_bwd_hp = _make_fused_model(fused_pattern, in_features, out_features, dtype) - model_unquantized_ref = _make_fused_model(fused_pattern, in_features, out_features, dtype) - - _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) - _copy_named_parameters(model_quantized_ref, model_unquantized_ref) - - x1 = torch.randn(m, in_features, dtype=dtype, device="cuda") - x2 = None - if fused_pattern in ("bias_add", "scale_add"): - x2 = torch.randn(m, out_features, dtype=dtype, device="cuda") - dy = torch.randn(m, out_features, dtype=dtype, device="cuda") - - quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) - keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) - - y_quantized_ref, _, _, _, _ = _run_fused_single_step( - fused_pattern, model_quantized_ref, x1, dy, quantized_ref_recipe, x2=x2 - ) - y_keep_bwd_hp, dx1_keep_bwd_hp, dx2_keep_bwd_hp, dw_keep_bwd_hp, db_keep_bwd_hp = ( - _run_fused_single_step( - fused_pattern, - model_keep_bwd_hp, - x1, - dy, - keep_bwd_hp_recipe, - x2=x2, - ) - ) - _, dx1_unquantized_ref, dx2_unquantized_ref, dw_unquantized_ref, db_unquantized_ref = ( - _run_fused_single_step( - fused_pattern, - model_unquantized_ref, - x1, - dy, - None, - x2=x2, - ) - ) - - # Ensure this test executes the fused path changed by the keep-bwd feature. 
- fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops - assert len(fused_ops) >= 1 - assert isinstance(fused_ops[0][0], expected_fused_op) - - torch.testing.assert_close( - y_keep_bwd_hp, - y_quantized_ref, - **_fprop_tolerances(recipe_name), - ) - torch.testing.assert_close(dx1_keep_bwd_hp, dx1_unquantized_ref, rtol=0, atol=0) - torch.testing.assert_close(dw_keep_bwd_hp, dw_unquantized_ref, rtol=0, atol=0) - if dx2_keep_bwd_hp is not None and dx2_unquantized_ref is not None: - torch.testing.assert_close(dx2_keep_bwd_hp, dx2_unquantized_ref, rtol=0, atol=0) - if db_keep_bwd_hp is not None and db_unquantized_ref is not None: - torch.testing.assert_close(db_keep_bwd_hp, db_unquantized_ref, rtol=0, atol=0) - - -@pytest.mark.parametrize( - "recipe_name", - _quantized_numerics_recipe_list, -) -@pytest.mark.parametrize("input_shape", _bias_activation_shape_cases) -def test_keep_backward_unquantized_fused_bias_activation_matches_masked_linear_backward( - recipe_name: str, - input_shape: tuple[int, ...], -): - # Fused linear op path is based on te_ops.Linear and shares its recipe constraints. 
- _maybe_skip_unsupported_recipe_module_combo(recipe_name, "ops_linear") - - reset_rng_states() - dtype = torch.bfloat16 - in_features = input_shape[-1] - out_features = 64 - - model_quantized_ref = _make_fused_model("bias_activation", in_features, out_features, dtype) - model_keep_bwd_hp = _make_fused_model("bias_activation", in_features, out_features, dtype) - linear_unquantized_ref = _make_linear_like_module( - "ops_linear", in_features, out_features, dtype, bias=True - ) - - _copy_named_parameters(model_quantized_ref, model_keep_bwd_hp) - _copy_named_parameters(model_keep_bwd_hp[0], linear_unquantized_ref) - - x1 = torch.randn(*input_shape, dtype=dtype, device="cuda") - out_shape = x1.shape[:-1] + (out_features,) - dy = torch.randn(*out_shape, dtype=dtype, device="cuda") - - quantized_ref_recipe = _build_quantized_reference_recipe(recipe_name) - keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe(recipe_name) - - y_quantized_ref, _, _, _, _ = _run_fused_single_step( - "bias_activation", model_quantized_ref, x1, dy, quantized_ref_recipe - ) - y_keep_bwd_hp, dx1_keep_bwd_hp, _, dw_keep_bwd_hp, db_keep_bwd_hp = _run_fused_single_step( - "bias_activation", model_keep_bwd_hp, x1, dy, keep_bwd_hp_recipe - ) - - # Ensure this test executes the fused path changed by the keep-bwd feature. - fused_ops = model_keep_bwd_hp._module_groups[0]._forward_ops - assert len(fused_ops) >= 1 - assert isinstance(fused_ops[0][0], ForwardLinearBiasActivation) - - # keep-bwd mode should disable backward-activation+bias fusion, while quantized - # reference should still use it. 
- keep_bwd_backward_ops = model_keep_bwd_hp._module_groups[0]._backward_ops - assert not any(isinstance(op, BackwardActivationBias) for op, _ in keep_bwd_backward_ops) - quantized_ref_backward_ops = model_quantized_ref._module_groups[0]._backward_ops - assert any(isinstance(op, BackwardActivationBias) for op, _ in quantized_ref_backward_ops) - - torch.testing.assert_close( - y_keep_bwd_hp, - y_quantized_ref, - **_fprop_tolerances(recipe_name), - ) - - # In keep-backward-unquantized mode, backward should behave as high-precision linear backward - # given the ReLU mask induced by quantized forward activations. - dy_after_activation = dy * (y_keep_bwd_hp > 0).to(dy.dtype) - _, dx1_expected, dw_expected = _run_single_step( - linear_unquantized_ref, x1, dy_after_activation, None - ) - db_expected = _extract_bias_grad(linear_unquantized_ref) - assert db_keep_bwd_hp is not None - assert db_expected is not None - - torch.testing.assert_close(dx1_keep_bwd_hp, dx1_expected, rtol=0, atol=0) - torch.testing.assert_close(dw_keep_bwd_hp, dw_expected, rtol=0, atol=0) - torch.testing.assert_close(db_keep_bwd_hp, db_expected, rtol=0, atol=0) - - -def test_keep_backward_unquantized_autocast_respects_quantize_forward_flag(): - reset_rng_states() - dtype = torch.bfloat16 - in_features = 64 - out_features = 64 - - module_quantization_disabled = _make_linear_like_module( - "linear", in_features, out_features, dtype, bias=True - ) - module_unquantized_ref = _make_linear_like_module( - "linear", in_features, out_features, dtype, bias=True - ) - _copy_named_parameters(module_quantization_disabled, module_unquantized_ref) - - x = torch.randn(32, in_features, dtype=dtype, device="cuda") - dy = torch.randn(32, out_features, dtype=dtype, device="cuda") - - recipe_no_fwd_quant = recipe.Float8CurrentScaling( - fp8_format=recipe.Format.E4M3, - quantize_forward=False, - quantize_backward=False, - ) - - y_test, dx_test, dw_test = _run_single_step( - module_quantization_disabled, x, dy, 
recipe_no_fwd_quant - ) - y_ref, dx_ref, dw_ref = _run_single_step(module_unquantized_ref, x, dy, None) - - torch.testing.assert_close(y_test, y_ref, rtol=0, atol=0) - torch.testing.assert_close(dx_test, dx_ref, rtol=0, atol=0) - torch.testing.assert_close(dw_test, dw_ref, rtol=0, atol=0) - bgrad_test = _extract_bias_grad(module_quantization_disabled) - bgrad_ref = _extract_bias_grad(module_unquantized_ref) - assert bgrad_test is not None - assert bgrad_ref is not None - torch.testing.assert_close(bgrad_test, bgrad_ref, rtol=0, atol=0) - - -def test_keep_backward_unquantized_quantize_op_respects_recipe_overrides(): - reset_rng_states() - dtype = torch.bfloat16 - x = torch.randn(32, 64, dtype=dtype, device="cuda") - dy = torch.randn(32, 64, dtype=dtype, device="cuda") - - model_override = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) - model_ref = te_ops.Sequential(te_ops.Quantize(forward=True, backward=True)) - - recipe_no_quant = recipe.Float8CurrentScaling( - fp8_format=recipe.Format.E4M3, - quantize_forward=False, - quantize_backward=False, - ) - y_override, dx_override = _run_quantize_op_single_step(model_override, x, dy, recipe_no_quant) - y_ref, dx_ref = _run_quantize_op_single_step(model_ref, x, dy, None) - - torch.testing.assert_close(y_override, y_ref, rtol=0, atol=0) - torch.testing.assert_close(dx_override, dx_ref, rtol=0, atol=0) - - -def test_keep_backward_unquantized_is_invalid_for_delayed_scaling(): - with pytest.raises( - (AssertionError, ValueError), - match="Delayed scaling does not support quantize_backward=False", - ): - _ = recipe.DelayedScaling() - - -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -def test_keep_backward_unquantized_not_implemented_for_layernorm_mlp(): - reset_rng_states() - layer = te.LayerNormMLP( - hidden_size=64, - ffn_hidden_size=64, - params_dtype=torch.bfloat16, - bias=False, - device="cuda", - ) - x = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True) - 
keep_bwd_hp_recipe = _build_keep_backward_unquantized_recipe("fp8_current_scaling") - - with pytest.raises( - AssertionError, match="NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" - ): - with te.autocast(enabled=True, recipe=keep_bwd_hp_recipe): - _ = layer(x) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 46a19652f1..9058f155c4 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -11,6 +11,20 @@ from pydantic.dataclasses import dataclass +_BACKWARD_MODES = ("default", "unquant", "dequant") + + +def _resolve_backward_mode(mode: Optional[str] = None) -> str: + """Return validated backward mode from argument or NVTE_BACKWARD_MODE env.""" + if mode is None: + mode = os.getenv("NVTE_BACKWARD_MODE", "default") + mode = mode.lower() + assert ( + mode in _BACKWARD_MODES + ), f"Invalid NVTE_BACKWARD_MODE value {mode!r}. Supported values are: default|unquant|dequant." + return mode + + class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. @@ -188,11 +202,8 @@ def scaling_factor_compute(amax: Tensor, `LayerNormLinear (BF16 output) -> (cast to FP8 ) FP8 DPA (cast to BF16) -> Linear`. When `fp8_mha = True, fp8_dpa = True`, it becomes `LayerNormLinear (FP8 output) -> FP8 DPA -> Linear`. - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. - quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. Delayed scaling - always quantizes backward; setting this to False is not supported. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. Delayed scaling only supports `default`. 
Notes ----- @@ -216,15 +227,14 @@ def scaling_factor_compute(amax: Tensor, reduce_amax: bool = True fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert not ( - not self.quantize_forward and self.quantize_backward - ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." - assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False." + assert ( + self.backward_mode == "default" + ), "Delayed scaling only supports backward_mode=default." def __repr__(self) -> str: return ( @@ -235,8 +245,7 @@ def __repr__(self) -> str: f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}" + f"backward_mode={self.backward_mode}" ) @@ -250,10 +259,11 @@ class Float8CurrentScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.HYBRID Controls the FP8 data format used during forward and backward pass. - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. - quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" @@ -266,14 +276,11 @@ class Float8CurrentScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert not ( - not self.quantize_forward and self.quantize_backward - ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -287,8 +294,7 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}" + f"backward_mode={self.backward_mode}" ) @@ -315,32 +321,29 @@ class MXFP8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. - quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" margin: int = 0 fp8_format: Format = Format.E4M3 fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert not ( - not self.quantize_forward and self.quantize_backward - ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}" + f"backward_mode={self.backward_mode}" ) @@ -369,10 +372,11 @@ class Float8BlockScaling(Recipe): fp8_format : {Format.E4M3, Format.HYBRID}, default = Format.E4M3 Controls the FP8 data format used during forward and backward pass. - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. - quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. 
""" use_f32_scales: bool = os.getenv("NVTE_FP8_BLOCK_SCALING_FP32_SCALES", "0") == "1" @@ -389,10 +393,10 @@ class Float8BlockScaling(Recipe): fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.x_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for x" assert self.w_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for w" assert self.grad_block_scaling_dim in [1, 2], "Only 1D or 2D blocks supported for grad" @@ -412,9 +416,6 @@ def __post_init__(self) -> None: not self.fp8_dpa and not self.fp8_mha ), "FP8 attention is not supported for Float8BlockScaling." assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert not ( - not self.quantize_forward and self.quantize_backward - ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." def __repr__(self) -> str: return ( @@ -431,8 +432,7 @@ def __repr__(self) -> str: f"fp8_gemm_wgrad={self.fp8_gemm_wgrad}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}" + f"backward_mode={self.backward_mode}" ) @@ -481,10 +481,11 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. - quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. 
+ backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. """ # Configuration envvars @@ -500,15 +501,12 @@ class NVFP4BlockScaling(Recipe): # Not applying quantization to attention for now fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" - assert not ( - not self.quantize_forward and self.quantize_backward - ), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True." # Quantization params # Note: RHT is currently only applied to column-wise usage so that @@ -536,8 +534,7 @@ def __repr__(self) -> str: f"fp8_format={str(self.fp8_format).split('.')[1]}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}, " + f"backward_mode={self.backward_mode}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " @@ -569,23 +566,25 @@ class CustomRecipe(Recipe): - forward: "linear_input", "linear_weight", "linear_output" - backward: "linear_grad_output", "linear_grad_input" - quantize_forward : bool, default = True - Whether to quantize tensors in the forward pass. 
- quantize_backward : bool, default = True - Whether to quantize tensors in the backward pass. + backward_mode : {'default', 'unquant', 'dequant'}, default = 'default' + Backward precision mode. `default` performs quantized backward, + `unquant` keeps original high-precision operands for backward, + and `dequant` dequantizes saved operands to the active high-precision + compute dtype (e.g. BF16/FP16/FP32) for backward. """ qfactory: Callable[..., Any] fp8_dpa: bool = False fp8_mha: bool = False - quantize_forward: bool = True - quantize_backward: bool = not (os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1") + backward_mode: str = field(default_factory=_resolve_backward_mode) + + def __post_init__(self) -> None: + self.backward_mode = _resolve_backward_mode(self.backward_mode) def __repr__(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"qfactory={self.qfactory}, " - f"quantize_forward={self.quantize_forward}, " - f"quantize_backward={self.quantize_backward}" + f"backward_mode={self.backward_mode}" ) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index fe5be68034..1d88575a7d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1135,7 +1135,7 @@ def grad_output_preprocess( grad_output = grad_output.reshape((-1, grad_output.shape[-1])) grad_output = grad_output.contiguous() gather_grad_output = row_parallel_mode and ctx.sequence_parallel - use_fp8_bwd = ctx.fp8 and not ctx.keep_backward_unquantized + use_fp8_bwd = ctx.fp8 and ctx.backward_mode == "default" # Non-FP8 case: bgrad is fused with wgrad for this case. 
if not use_fp8_bwd and not ctx.debug: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 530e8c2075..95eeee7e88 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -98,11 +98,12 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) - if keep_backward_unquantized: - # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" + if backward_mode == "unquant": + # Note, NVTE_BACKWARD_MODE=unquant is ignored when delayed scaling is used. save_original_input = True num_gemms = len(m_splits) @@ -119,10 +120,15 @@ def forward( input_quantizer.set_usage( rowwise=True, columnwise=( - is_grad_enabled and weight_requires_grad and not save_original_input + is_grad_enabled + and weight_requires_grad + and not save_original_input + and backward_mode == "default" ), ) columnwise_usage = is_grad_enabled and inp.requires_grad + if backward_mode in ("unquant", "dequant"): + columnwise_usage = False if not columnwise_usage: columnwise_usage = ( is_fp8_activation_recompute_enabled() @@ -246,7 +252,12 @@ def forward( else: for inputmat in inputmats: if isinstance(inputmat, QuantizedTensorStorage): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + if backward_mode in ("unquant", "dequant"): + # In dequant mode we should dequantize directly from + # fprop quantized layouts without retargeting usage. 
+ inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + else: + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) else: inputmats = [None] * num_gemms @@ -297,7 +308,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = keep_backward_unquantized + ctx.backward_mode = backward_mode ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -316,9 +327,9 @@ def forward( ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers - # keep_backward_unquantized overrides - if keep_backward_unquantized: - ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + # Non-quantized backward mode overrides + if backward_mode in ("unquant", "dequant"): + ctx.fp8 = False ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -422,7 +433,16 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], device=ctx.device, ) weights_for_dgrad = weights - if ctx.keep_backward_unquantized: + if ctx.backward_mode == "dequant": + weights_for_dgrad = [ + ( + weight.dequantize(dtype=ctx.activation_dtype) + if isinstance(weight, QuantizedTensorStorage) + else cast_if_needed(weight, ctx.activation_dtype) + ) + for weight in weights + ] + elif ctx.backward_mode == "unquant": weights_for_dgrad = origin_weights # Make sure weights are available in column-wise format # for dgrad computation. 
@@ -485,6 +505,30 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmats = torch.split( cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits ) + elif ctx.backward_mode == "dequant": + inputmats_dequant = [] + for m_split, inputmat in zip(ctx.m_splits, inputmats): + if isinstance(inputmat, QuantizedTensorStorage): + if m_split == 0: + # Dequant kernels for some quantized storage formats + # (e.g. MXFP8/Float8BlockScaling) do not accept empty + # M-dimension inputs. For empty grouped splits, materialize + # an explicit empty high-precision matrix instead of invoking + # dequantize(). + inputmats_dequant.append( + torch.empty( + (0, ctx.weights_shape_1), + dtype=ctx.activation_dtype, + device=ctx.device, + ) + ) + else: + inputmats_dequant.append( + inputmat.dequantize(dtype=ctx.activation_dtype) + ) + else: + inputmats_dequant.append(cast_if_needed(inputmat, ctx.activation_dtype)) + inputmats = inputmats_dequant grouped_gemm_wgrad = functools.partial( general_grouped_gemm, quantization_params=ctx.grad_weight_quantizers, @@ -1063,6 +1107,12 @@ def _get_quantizers(self): for i in range(self.num_gemms): grad_output_quantizers[i].internal = True grad_output_quantizers[i].optimize_for_gemm = True + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + for input_quantizer in input_quantizers: + input_quantizer.optimize_for_gemm = False + for grad_output_quantizer in grad_output_quantizers: + grad_output_quantizer.optimize_for_gemm = False return ( input_quantizers, weight_quantizers, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 187fd70f92..0c6e960d22 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -141,9 +141,10 @@ def forward( symmetric_ar_type, debug, ) = non_tensor_args - 
keep_backward_unquantized = fp8 and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # NVTX label for profiling nvtx_label = "transformer_engine._LayerNormLinear.forward" @@ -205,7 +206,7 @@ def forward( raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage( rowwise=True, - columnwise=backward_needs_input and not keep_backward_unquantized, + columnwise=backward_needs_input and backward_mode == "default", ) if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data @@ -219,7 +220,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not keep_backward_unquantized + and backward_mode == "default" and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) @@ -243,7 +244,7 @@ def forward( ln_out_return = None if return_layernorm_output or return_layernorm_output_gathered: ln_out_return = ln_out - ln_out_hp = ln_out if keep_backward_unquantized else None + ln_out_hp = ln_out if backward_mode == "unquant" else None # ------------------------------------------------------ # Prepare GEMM input tensor @@ -304,7 +305,10 @@ def forward( if is_weight_param_quantized: weight_quantizer = weight._quantizer elif weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + weight_quantizer.set_usage( + rowwise=True, + columnwise=is_grad_enabled and backward_mode == "default", + ) # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch @@ -418,7 +422,7 @@ def forward( if is_grad_enabled: ln_out_to_save = ln_out - if keep_backward_unquantized: + if backward_mode == "unquant": ln_out_to_save = ln_out_hp ctx.weight_quantizer = weight_quantizer ctx.ln_out_needs_gather = ( @@ -426,7 
+430,7 @@ def forward( ) # Input with column-wise usage is needed for wgrad GEMM. - if backward_needs_input and not keep_backward_unquantized: + if backward_needs_input and backward_mode == "default": if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data @@ -504,7 +508,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = keep_backward_unquantized + ctx.backward_mode = backward_mode ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch @@ -535,9 +539,9 @@ def forward( ctx.wgrad_store = wgrad_store ctx.debug = debug - # keep_backward_unquantized overrides - if keep_backward_unquantized: - ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + # Non-quantized backward mode overrides + if backward_mode in ("unquant", "dequant"): + ctx.fp8 = False ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -686,6 +690,11 @@ def backward( # -------------------------------------------------- ln_out_total = None ln_out_total_work = None + if ctx.backward_mode == "dequant": + if isinstance(ln_out, QuantizedTensorStorage): + ln_out = ln_out.dequantize(dtype=ctx.activation_dtype) + else: + ln_out = cast_if_needed(ln_out, ctx.activation_dtype) if ctx.ln_out_needs_gather: quantizer = None if ctx.input_quantizer is not None and ctx.fp8: @@ -758,7 +767,12 @@ def backward( # Note: dx = dy * w nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight - if ctx.keep_backward_unquantized: + if ctx.backward_mode == "dequant": + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + else: + weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) 
+ elif ctx.backward_mode == "unquant": weight_for_dgrad = origin_weight gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, @@ -1658,6 +1672,11 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): grad_output_quantizer.optimize_for_gemm = True if fp8_grad: grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + input_quantizer.optimize_for_gemm = False + if grad_output_quantizer is not None: + grad_output_quantizer.optimize_for_gemm = False return ( input_quantizer, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ae80694587..d158ecadb4 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -235,12 +235,13 @@ def _forward( debug, recompute_for_bwd, ) = non_tensor_args - keep_backward_unquantized = fp8 and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" assert ( - not keep_backward_unquantized - ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP" + backward_mode == "default" + ), "NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP" # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: @@ -787,7 +788,7 @@ def _forward( ctx.fc2_main_grad_func = lambda: fc2_weight.main_grad ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = keep_backward_unquantized + ctx.backward_mode = backward_mode ctx.fc1_grad_input_quantizer = fc1_grad_input_quantizer ctx.fc1_grad_weight_quantizer = fc1_grad_weight_quantizer 
ctx.fc1_grad_output_quantizer = fc1_grad_output_quantizer diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 7d960102ec..34342f011f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -129,11 +129,12 @@ def forward( save_original_input, debug, ) = non_tensor_args - keep_backward_unquantized = fp8 and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) - if keep_backward_unquantized: - # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used + if fp8: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" + if backward_mode == "unquant": + # Note, NVTE_BACKWARD_MODE=unquant is ignored when delayed scaling is used. save_original_input = True # NVTX label for profiling @@ -195,7 +196,10 @@ def forward( raise ValueError("Missing quantizer for input tensor") if not isinstance(inputmat, QuantizedTensorStorage) and not custom: own_quantized_input = True - input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) + input_quantizer.set_usage( + rowwise=True, + columnwise=backward_needs_input and backward_mode == "default", + ) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): @@ -237,7 +241,12 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage( - rowwise=True, columnwise=backward_needs_input and not save_original_input + rowwise=True, + columnwise=( + backward_needs_input + and not save_original_input + and backward_mode == "default" + ), ) inputmat = input_quantizer(inputmat) own_quantized_input = True @@ -261,6 +270,8 @@ def forward( # No need to set the quantizer states if weight is already quantized if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): columnwise_usage = is_grad_enabled and inp.requires_grad + if 
backward_mode in ("unquant", "dequant"): + columnwise_usage = False if not columnwise_usage: columnwise_usage = ( is_fp8_activation_recompute_enabled() @@ -394,7 +405,11 @@ def forward( and own_quantized_input and isinstance(inputmat, QuantizedTensorStorage) ): - if ( + if backward_mode in ("unquant", "dequant"): + # In dequant mode we should dequantize directly from the + # fprop quantized tensor layout without retargeting usage. + inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + elif ( ctx.backward_input_needs_gather and weight_quantizer.supports_only_rowwise_all_gather() ): @@ -449,7 +464,7 @@ def forward( ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.keep_backward_unquantized = keep_backward_unquantized + ctx.backward_mode = backward_mode ctx.input_quantizer = input_quantizer ctx.grad_input_quantizer = grad_input_quantizer ctx.grad_weight_quantizer = grad_weight_quantizer @@ -493,9 +508,9 @@ def forward( FP8GlobalStateManager.IS_FIRST_FP8_MODULE = _first_fp8_module ctx.wgrad_store = wgrad_store - # keep_backward_unquantized overrides - if keep_backward_unquantized: - ctx.fp8 = ctx.fp8 and not keep_backward_unquantized + # Non-quantized backward mode overrides + if backward_mode in ("unquant", "dequant"): + ctx.fp8 = False ctx.ub_overlap_ag = False ctx.ub_overlap_rs_dgrad = False ctx.ub_bulk_dgrad = False @@ -741,7 +756,12 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight_fp8 - if ctx.keep_backward_unquantized: + if ctx.backward_mode == "dequant": + if isinstance(weight_for_dgrad, QuantizedTensorStorage): + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + else: + weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) + elif ctx.backward_mode == "unquant": weight_for_dgrad = weight gemm_out, *_, 
reduce_scatter_out = general_gemm( weight_for_dgrad, @@ -1519,6 +1539,11 @@ def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): grad_output_quantizer.optimize_for_gemm = True if fp8_grad: grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + input_quantizer.optimize_for_gemm = False + if grad_output_quantizer is not None: + grad_output_quantizer.optimize_for_gemm = False return ( input_quantizer, weight_quantizer, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 15a6815d2e..7f21cd9331 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -332,10 +332,9 @@ def pre_fuser_forward(self, *, requires_grad: bool) -> None: # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
weight_requires_grad = requires_grad and self.weight.requires_grad - keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) - columnwise_usage = weight_requires_grad and not keep_backward_unquantized + columnwise_usage = weight_requires_grad + if FP8GlobalStateManager.get_fp8_recipe().backward_mode in ("unquant", "dequant"): + columnwise_usage = False input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) grad_output_quantizer = self.get_quantizer("backward", 0) @@ -359,6 +358,13 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: grad_output_quantizer.internal = True if not (self.tensor_parallel_mode == "row" and self.sequence_parallel): grad_output_quantizer.optimize_for_gemm = True + if FP8GlobalStateManager.is_fp8_enabled(): + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): + if input_quantizer is not None: + input_quantizer.optimize_for_gemm = False + if grad_output_quantizer is not None: + grad_output_quantizer.optimize_for_gemm = False # Configure weight quantizer # Note: This function may be called in base class constructor, @@ -424,7 +430,7 @@ def _functional_forward( tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, sequence_parallel: bool = False, with_quantized_compute: bool = False, - keep_backward_unquantized: bool = False, + backward_mode: str = "default", input_quantizer: Optional[Quantizer] = None, weight_quantizer: Optional[Quantizer] = None, output_quantizer: Optional[Quantizer] = None, @@ -464,8 +470,8 @@ def _functional_forward( distributing along inner dimension (embedding dim) with_quantized_compute: bool, default = False Whether to perform compute with quantized data. 
- keep_backward_unquantized: bool, default = `False` - Whether to skip quantized backward and use high precision. + backward_mode: {`"default"`, `"unquant"`, `"dequant"`}, default = `"default"` + Backward-mode policy for quantized compute. input_quantizer: Quantizer, optional Builder class for quantized input tensor. weight_quantizer: Quantizer, optional @@ -519,7 +525,7 @@ def _functional_forward( raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage( rowwise=True, - columnwise=weight_requires_grad and not keep_backward_unquantized, + columnwise=weight_requires_grad and backward_mode == "default", ) if with_x_all_gather: input_quantizer.set_usage(columnwise=False) @@ -554,7 +560,7 @@ def _functional_forward( raise ValueError("Missing quantizer for weight tensor") weight_quantizer.set_usage( rowwise=True, - columnwise=input_requires_grad and not keep_backward_unquantized, + columnwise=input_requires_grad and backward_mode == "default", ) w = weight_quantizer(w) @@ -628,7 +634,7 @@ def _functional_forward( w is not weight and with_quantized_compute and is_quantized_tensor(w) - and not keep_backward_unquantized + and backward_mode == "default" ): w.update_usage(rowwise_usage=False, columnwise_usage=True) else: @@ -639,7 +645,7 @@ def _functional_forward( if ( with_quantized_compute and is_quantized_tensor(x_local) - and not keep_backward_unquantized + and backward_mode == "default" ): if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data @@ -990,9 +996,10 @@ def op_forward( grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = with_quantized_compute and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + 
else: + backward_mode = "default" # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -1009,7 +1016,7 @@ def op_forward( tensor_parallel_group=self.tensor_parallel_group, sequence_parallel=self.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, + backward_mode=backward_mode, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -1019,12 +1026,13 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - saved_weight = self.weight if keep_backward_unquantized else w + saved_input = input_ if backward_mode == "unquant" else x_local + saved_weight = self.weight if backward_mode == "unquant" else w if is_cpu_offload_enabled(): mark_activation_offload(saved_input) ctx.save_for_backward(saved_input, saved_weight) - ctx.with_quantized_compute = with_quantized_compute and not keep_backward_unquantized + ctx.with_quantized_compute = with_quantized_compute and backward_mode == "default" + ctx.backward_mode = backward_mode ctx.input_quantizer = input_quantizer ctx.weight_quantizer = weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index 8bcd84b441..ad147a8d85 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -124,12 +124,11 @@ def op_forward( b = self.bias.view([1] * (x.dim() - 1) + [self.local_size]) if ctx.requires_grad: - keep_backward_unquantized = FP8GlobalStateManager.is_fp8_enabled() and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) - ctx.grad_input_quantizer = ( - None if keep_backward_unquantized else prev_op_grad_output_quantizer - ) + ctx.grad_input_quantizer = prev_op_grad_output_quantizer + if FP8GlobalStateManager.is_fp8_enabled(): + fp8_recipe = 
FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.backward_mode in ("unquant", "dequant"): + ctx.grad_input_quantizer = None return x + b diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index b2a36d1daa..c5474c18a0 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -59,14 +59,10 @@ def op_forward( quantize_forward = fp8_enabled and self._quantize_forward quantize_backward = fp8_enabled and self._quantize_backward - # Recipe quantize overrides - if FP8GlobalStateManager.get_fp8_recipe() is not None: - quantize_forward = ( - quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward - ) - quantize_backward = ( - quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + # Backward quantization is controlled by recipe backward mode. + if fp8_enabled: + recipe = FP8GlobalStateManager.get_fp8_recipe() + quantize_backward = quantize_backward and recipe.backward_mode == "default" # Quantize if needed out = input_ diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 395a9dbd67..7b3025c03e 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -105,8 +105,8 @@ def fuse_backward_ops( """ # Check if recipe supports bias activation fusion. - # keep-backward-unquantized mode should use unfused backward ops. - if recipe is None or not recipe.quantize_backward: + # unquant/dequant backward modes should use unfused backward ops. 
+ if recipe is None or recipe.backward_mode in ("unquant", "dequant"): return ops # Scan through ops, fusing if possible diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 42f459a41e..7584891384 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -92,9 +92,10 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = with_quantized_compute and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -112,7 +113,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, + backward_mode=backward_mode, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -122,14 +123,16 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - saved_weight = linear_op.weight if keep_backward_unquantized else w + saved_input, saved_weight = x_local, w + if backward_mode == "unquant": + saved_input, saved_weight = input_, linear_op.weight if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and not keep_backward_unquantized + 
with_quantized_compute and backward_mode == "default" ) + linear_op_ctx.backward_mode = backward_mode linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer @@ -138,9 +141,9 @@ def fuser_forward( linear_op_ctx.input_requires_grad = input_requires_grad linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: - bias_op_ctx.grad_input_quantizer = ( - None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() - ) + bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() + if backward_mode in ("unquant", "dequant"): + bias_op_ctx.grad_input_quantizer = None return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 75d58fd5cc..6935330f4e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -86,9 +86,10 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = with_quantized_compute and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -109,7 +110,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, + backward_mode=backward_mode, input_quantizer=input_quantizer, 
weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -119,14 +120,16 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - saved_weight = linear_op.weight if keep_backward_unquantized else w + saved_input, saved_weight = x_local, w + if backward_mode == "unquant": + saved_input, saved_weight = input_, linear_op.weight if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and not keep_backward_unquantized + with_quantized_compute and backward_mode == "default" ) + linear_op_ctx.backward_mode = backward_mode linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer @@ -136,7 +139,7 @@ def fuser_forward( linear_op_ctx.weight_requires_grad = weight_requires_grad if bias_op is not None and bias_op_ctx.requires_grad: bias_op_ctx.grad_input_quantizer = ( - None if keep_backward_unquantized else linear_op.get_grad_output_quantizer() + None if backward_mode != "default" else linear_op.get_grad_output_quantizer() ) return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index dfdd11a231..2358140c88 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -65,9 +65,10 @@ def fuser_forward( grad_output_quantizer = linear_op.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - keep_backward_unquantized = with_quantized_compute and ( - not FP8GlobalStateManager.get_fp8_recipe().quantize_backward - ) + 
if with_quantized_compute: + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" # Get extra input tensor for add operation extra_input = basic_op_extra_inputs[2][0] @@ -90,7 +91,7 @@ def fuser_forward( tensor_parallel_group=linear_op.tensor_parallel_group, sequence_parallel=linear_op.sequence_parallel, with_quantized_compute=with_quantized_compute, - keep_backward_unquantized=keep_backward_unquantized, + backward_mode=backward_mode, input_quantizer=input_quantizer, weight_quantizer=weight_quantizer, output_quantizer=output_quantizer, @@ -100,14 +101,16 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: - saved_input = input_ if keep_backward_unquantized else x_local - saved_weight = linear_op.weight if keep_backward_unquantized else w + saved_input, saved_weight = x_local, w + if backward_mode == "unquant": + saved_input, saved_weight = input_, linear_op.weight if is_cpu_offload_enabled(): mark_activation_offload(saved_input) linear_op_ctx.save_for_backward(saved_input, saved_weight) linear_op_ctx.with_quantized_compute = ( - with_quantized_compute and not keep_backward_unquantized + with_quantized_compute and backward_mode == "default" ) + linear_op_ctx.backward_mode = backward_mode linear_op_ctx.input_quantizer = input_quantizer linear_op_ctx.weight_quantizer = weight_quantizer linear_op_ctx.grad_output_quantizer = grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 0d3e1d0416..54411f650d 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -388,6 +388,19 @@ def fuse_forward_ops( """ + # Disable Userbuffers for non-quantized backward modes. 
+ # In unquant/dequant modes we want to avoid all UB-specific overlap + # paths and run through the standard non-UB operator sequence instead. + recipe = unused.get("recipe", None) + if recipe is not None: + backward_mode = recipe.backward_mode + elif FP8GlobalStateManager.is_fp8_enabled(): + backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode + else: + backward_mode = "default" + if backward_mode in ("unquant", "dequant"): + return ops + # Return immediately if environment is not distributed if not torch.distributed.is_initialized() or torch.distributed.get_world_size() == 1: return ops diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index bd3bc94b60..45d0d68684 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -338,6 +338,7 @@ def __init__( # Cache and detect change of state relevant for fusing operations self.recipe_type = None self.first_op_requiring_backward = 0 + self.backward_mode = "default" self._last_amax_history_len = 0 # Flatten list of parameters @@ -414,9 +415,14 @@ def maybe_fuse_ops( # Early exit if fusion parameters haven't changed need_reset = False recipe_type = type(recipe) - fusion_params = (recipe_type, first_op_requiring_backward) - if fusion_params != (self.recipe_type, self.first_op_requiring_backward): - # Recipe type or grad requirmenets have changed + backward_mode = recipe.backward_mode if recipe is not None else "default" + fusion_params = (recipe_type, first_op_requiring_backward, backward_mode) + if fusion_params != ( + self.recipe_type, + self.first_op_requiring_backward, + self.backward_mode, + ): + # Recipe type, backward mode, or grad requirements have changed need_reset = True elif ( recipe is not None @@ -450,7 +456,7 @@ def maybe_fuse_ops( ) # Save current fusion params - self.recipe_type, self.first_op_requiring_backward = fusion_params + self.recipe_type, self.first_op_requiring_backward, self.backward_mode = 
fusion_params # Save amax history length if isinstance(recipe, DelayedScaling): diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 00196c584f..eba547afb0 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -842,15 +842,14 @@ def autocast( are reduced at the end of each training step. """ - effective_enabled = enabled and getattr(recipe, "quantize_forward", True) - if effective_enabled: + if enabled: check_recipe_support(recipe) # Save current state so we always restore it on exit. fp8_state = FP8GlobalStateManager.get_autocast_state() FP8GlobalStateManager.autocast_enter( - enabled=effective_enabled, + enabled=enabled, calibrating=calibrating, fp8_recipe=recipe, fp8_group=amax_reduction_group, @@ -860,7 +859,7 @@ def autocast( yield finally: FP8GlobalStateManager.set_autocast_state(fp8_state) - FP8GlobalStateManager.autocast_exit(effective_enabled, _graph=_graph) + FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: From 0dee809e9a5cfb9e94f3ed93c933fa3557b7188e Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Wed, 25 Feb 2026 12:59:01 -0800 Subject: [PATCH 45/45] Fix override and clean up Signed-off-by: Ziang Li --- transformer_engine/pytorch/module/grouped_linear.py | 6 +++--- transformer_engine/pytorch/module/layernorm_mlp.py | 7 ++++--- transformer_engine/pytorch/module/linear.py | 1 - 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 95eeee7e88..615ad9df56 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -103,7 +103,6 @@ def forward( else: backward_mode = "default" if backward_mode == "unquant": - # Note, NVTE_BACKWARD_MODE=unquant is ignored when delayed scaling is used. 
save_original_input = True num_gemms = len(m_splits) @@ -1111,8 +1110,9 @@ def _get_quantizers(self): if fp8_recipe.backward_mode == "dequant" and (fp8_recipe.mxfp8() or fp8_recipe.nvfp4()): for input_quantizer in input_quantizers: input_quantizer.optimize_for_gemm = False - for grad_output_quantizer in grad_output_quantizers: - grad_output_quantizer.optimize_for_gemm = False + if torch.is_grad_enabled(): + for grad_output_quantizer in grad_output_quantizers: + grad_output_quantizer.optimize_for_gemm = False return ( input_quantizers, weight_quantizers, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d158ecadb4..0ba5a016cb 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -239,9 +239,10 @@ def _forward( backward_mode = FP8GlobalStateManager.get_fp8_recipe().backward_mode else: backward_mode = "default" - assert ( - backward_mode == "default" - ), "NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP" + assert backward_mode == "default", ( + "NVTE_BACKWARD_MODE=unquant/dequant is not implemented in LayerNormMLP. " + "Replace LayerNormMLP with LayerNormLinear + Linear to enable unquant/dequant backward." + ) # if grad is enabled and this is not the bwd stage, we must save this so bwd knows which path to take if is_grad_enabled and not recompute_for_bwd: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 34342f011f..47bd633fe5 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -134,7 +134,6 @@ def forward( else: backward_mode = "default" if backward_mode == "unquant": - # Note, NVTE_BACKWARD_MODE=unquant is ignored when delayed scaling is used. save_original_input = True # NVTX label for profiling