38 changes: 32 additions & 6 deletions megatron/core/fusions/fused_bias_swiglu.py
@@ -8,10 +8,36 @@
 
 from megatron.core.jit import jit_fuser
 from megatron.core.utils import nvtx_decorator
+from megatron.core.utils import is_te_min_version
 
+try:
+    import transformer_engine  # pylint: disable=unused-import
+    import transformer_engine_torch as tex
+    from transformer_engine.pytorch import Float8CurrentScalingQuantizer
+
+    HAVE_TE = is_te_min_version("2.2.0.dev0")
+
+except ModuleNotFoundError:
+
+    HAVE_TE = False
+
 ###### BIAS SWIGLU FUSION/ NO AUTOGRAD ################
 
 
+def quantize(input):
+    if HAVE_TE:
+        quantizer = Float8CurrentScalingQuantizer(
+            fp8_dtype=tex.DType.kFloat8E4M3, device=input.device, columnwise=False
+        )
+        return quantizer(input)
+    else:
+        return input.to(torch.float8_e4m3fn)
+
+
+def dequantize(input, ori_input_dtype):
+    return input.dequantize(dtype=ori_input_dtype) if HAVE_TE else input.to(ori_input_dtype)
+
+
 @jit_fuser
 def swiglu(y):
     """Performs SwiGLU (Swish-Gated Linear Unit) activation function.
@@ -114,7 +140,7 @@ def forward(ctx, input, bias, fp8_input_store, cpu_offload_input):
         Returns:
             torch.Tensor: Result of applying bias addition followed by SwiGLU activation.
         """
-        input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
+        input_for_backward = quantize(input) if fp8_input_store else input
         if cpu_offload_input:
             input_for_backward.activation_offloading = True
             bias.activation_offloading = True
@@ -139,7 +165,7 @@ def backward(ctx, grad_output):
             - None for fp8_input_store parameter
         """
         input, bias = ctx.saved_tensors
-        input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
+        input = dequantize(input, ctx.ori_input_dtype) if ctx.fp8_input_store else input
         tmp = bias_swiglu_back(grad_output, input, bias)
         return tmp, tmp, None, None

@@ -160,7 +186,7 @@ def forward(ctx, input, fp8_input_store, cpu_offload_input):
         Returns:
             torch.Tensor: Result of applying SwiGLU activation.
         """
-        input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
+        input_for_backward = quantize(input) if fp8_input_store else input
         if cpu_offload_input:
             input_for_backward.activation_offloading = True
         ctx.save_for_backward(input_for_backward)
@@ -183,7 +209,7 @@ def backward(ctx, grad_output):
             - None for fp8_input_store parameter
         """
         input = ctx.saved_tensors[0]
-        input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
+        input = dequantize(input, ctx.ori_input_dtype) if ctx.fp8_input_store else input
         tmp = swiglu_back(grad_output, input)
         return tmp, None, None

@@ -192,7 +218,7 @@ class WeightedSwiGLUFunction(torch.autograd.Function):
     @staticmethod
     # bias is an optional argument
     def forward(ctx, input, weights, fp8_input_store):
-        input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input
+        input_for_backward = quantize(input) if fp8_input_store else input
         ctx.save_for_backward(input_for_backward, weights)
         ctx.ori_input_dtype = input.dtype
         ctx.fp8_input_store = fp8_input_store
@@ -201,7 +227,7 @@ def forward(ctx, input, weights, fp8_input_store):
     @staticmethod
     def backward(ctx, grad_output):
         input, weights = ctx.saved_tensors
-        input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
+        input = dequantize(input, ctx.ori_input_dtype) if ctx.fp8_input_store else input
         tmp, wgrad = weighted_swiglu_back(grad_output, input, weights)
         return tmp, wgrad, None

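Reviewer note: a minimal sketch of the fallback round-trip exercised when Transformer Engine is unavailable (HAVE_TE is False), matching the else branches of quantize()/dequantize() above. It assumes a PyTorch build with float8_e4m3fn support (>= 2.1); the tensor shape and dtype are illustrative only. On the TE path, quantize() instead returns a Float8Tensor from Float8CurrentScalingQuantizer that carries its own scale.

import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)    # stand-in for a SwiGLU input activation

stored = x.to(torch.float8_e4m3fn)             # what forward() saves when fp8_input_store is set
restored = stored.to(torch.bfloat16)           # what backward() recovers via dequantize()

assert stored.element_size() == 1              # 1 byte/element vs. 2 for bf16: ~2x saving on this tensor
print((x - restored).abs().max())              # nonzero: raw E4M3 storage is lossy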
2 changes: 2 additions & 0 deletions megatron/training/arguments.py
@@ -2098,6 +2098,8 @@ def _add_training_args(parser):
                        help='The communicator group names to use high priority streams.')
     group.add_argument('--disable-jit-fuser', action='store_true',
                        help='Disable the JIT fuser.')
+    group.add_argument('--activation-func-fp8-input-store', action='store_true',
+                       help='Store swiglu inputs in fp8 to save activation memory.')

     return parser

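Reviewer note: a quick sanity check of how the new flag parses, sketched with a standalone argparse parser rather than Megatron's full argument setup (the wiring from the parsed namespace to the autograd functions is assumed to live elsewhere in the training code).

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--activation-func-fp8-input-store', action='store_true',
                   help='Store swiglu inputs in fp8 to save activation memory.')

args = parser.parse_args(['--activation-func-fp8-input-store'])
assert args.activation_func_fp8_input_store    # dashes map to underscores on the parsed namespace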