diff --git a/megatron/core/fusions/fused_bias_swiglu.py b/megatron/core/fusions/fused_bias_swiglu.py
index 632470876c9..c080a2d0db3 100644
--- a/megatron/core/fusions/fused_bias_swiglu.py
+++ b/megatron/core/fusions/fused_bias_swiglu.py
@@ -8,10 +8,36 @@
 from megatron.core.jit import jit_fuser
 from megatron.core.utils import nvtx_decorator
+from megatron.core.utils import is_te_min_version
+
+try:
+    import transformer_engine  # pylint: disable=unused-import
+    import transformer_engine_torch as tex
+    from transformer_engine.pytorch import Float8CurrentScalingQuantizer
+
+    HAVE_TE = is_te_min_version("2.2.0.dev0")
+
+except ModuleNotFoundError:
+
+    HAVE_TE = False
 
 ###### BIAS SWIGLU FUSION/ NO AUTOGRAD ################
 
 
+def quantize(input):
+    if HAVE_TE:
+        quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3, device=input.device, columnwise=False
+        )
+        return quantizer(input)
+    else:
+        return input.to(torch.float8_e4m3fn)
+
+
+def dequantize(input, ori_input_dtype):
+    return input.dequantize(dtype=ori_input_dtype) if HAVE_TE else input.to(ori_input_dtype)
+
+
 @jit_fuser
 def swiglu(y):
     """Performs SwiGLU (Swish-Gated Linear Unit) activation function.
@@ -114,7 +140,7 @@ def forward(ctx, input, bias, fp8_input_store, cpu_offload_input):
         Returns:
             torch.Tensor: Result of applying bias addition followed by SwiGLU activation.
         """
-        input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
+        input_for_backward = quantize(input) if fp8_input_store else input
         if cpu_offload_input:
             input_for_backward.activation_offloading = True
             bias.activation_offloading = True
@@ -139,7 +165,7 @@ def backward(ctx, grad_output):
                 - None for fp8_input_store parameter
         """
         input, bias = ctx.saved_tensors
-        input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
+        input = dequantize(input, ctx.ori_input_dtype) if ctx.fp8_input_store else input
         tmp = bias_swiglu_back(grad_output, input, bias)
         return tmp, tmp, None, None
@@ -160,7 +186,7 @@ def forward(ctx, input, fp8_input_store, cpu_offload_input):
         Returns:
             torch.Tensor: Result of applying SwiGLU activation.
         """
-        input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
+        input_for_backward = quantize(input) if fp8_input_store else input
         if cpu_offload_input:
             input_for_backward.activation_offloading = True
         ctx.save_for_backward(input_for_backward)
@@ -183,7 +209,7 @@ def backward(ctx, grad_output):
                 - None for fp8_input_store parameter
         """
         input = ctx.saved_tensors[0]
-        input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
+        input = dequantize(input, ctx.ori_input_dtype) if ctx.fp8_input_store else input
         tmp = swiglu_back(grad_output, input)
         return tmp, None, None
@@ -192,7 +218,7 @@ class WeightedSwiGLUFunction(torch.autograd.Function):
     @staticmethod
     # bias is an optional argument
     def forward(ctx, input, weights, fp8_input_store):
-        input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
+        input_for_backward = quantize(input) if fp8_input_store else input
         ctx.save_for_backward(input_for_backward, weights)
         ctx.ori_input_dtype = input.dtype
         ctx.fp8_input_store = fp8_input_store
@@ -201,7 +227,7 @@ def forward(ctx, input, weights, fp8_input_store):
     @staticmethod
     def backward(ctx, grad_output):
         input, weights = ctx.saved_tensors
-        input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
+        input = dequantize(input, ctx.ori_input_dtype) if ctx.fp8_input_store else input
         tmp, wgrad = weighted_swiglu_back(grad_output, input, weights)
         return tmp, wgrad, None
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 5749d20a4ca..6ee96495382 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -2098,6 +2098,8 @@ def _add_training_args(parser):
                        help='The communicator group names to use high priority streams.')
     group.add_argument('--disable-jit-fuser', action='store_true',
                        help='Disable the JIT fuser.')
+    group.add_argument('--activation-func-fp8-input-store', action='store_true',
+                       help='Store swiglu inputs in fp8 to save activation memory.')
     return parser