4 changes: 2 additions & 2 deletions .github/workflows/lint.yaml
@@ -21,7 +21,7 @@ jobs:
misspell -error .
cpp:
name: CPP code lint
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v2
@@ -33,7 +33,7 @@ jobs:
make cpplint
md:
name: Markdown code lint
runs-on: ubuntu-20.04
runs-on: ubuntu-22.04
steps:
- name: Checkout
uses: actions/checkout@v2
1 change: 1 addition & 0 deletions Makefile
@@ -24,4 +24,5 @@ lint: cpplint mdlint
postinstall:
cd msamp/operators/dist_op && bash build.sh && cd -
cd msamp/operators/arithmetic && pip install -v . && cd -
cd msamp/operators/fp4_quantize && pip install -v . && cd -
cd msamp/optim && pip install -v . && cd -
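A quick sanity check (a sketch, not part of this diff) that the operator installed by the new Makefile line is importable; it assumes the extension built there is the msamp_quantize module that fp4_quantize.py imports:

# Sanity check after running `make postinstall`.
import msamp_quantize  # compiled extension; name assumed from the import in msamp/operators/fp4_quantize/fp4_quantize.py
from msamp.operators.fp4_quantize import FP4_QUANTIZER
print(FP4_QUANTIZER.__name__)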
33 changes: 31 additions & 2 deletions msamp/megatron/layers.py
@@ -13,6 +13,19 @@
from msamp.common.tensor import ScalingTensor
from msamp.operators.gemm import Gemm

import os

""" Below are the environment variables to control the FP4 quantization behavior, based on https://arxiv.org/abs/2501.17116
Using 'MSAMP_USE_WEIGHT_SIMULATE_FP4' to control if weight quantization is used.
Using 'MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR' to control if DGE (Differentiable Gradient Estimator) is used.
Using 'MSAMP_USE_ACTIVATION_SIMULATE_FP4' to control if activation quantization is used.
"""
MSAMP_USE_WEIGHT_SIMULATE_FP4 = bool(int(os.getenv('MSAMP_USE_WEIGHT_SIMULATE_FP4', 0)))
MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR = bool(int(os.getenv('MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR', 0)))
MSAMP_USE_ACTIVATION_SIMULATE_FP4 = bool(int(os.getenv('MSAMP_USE_ACTIVATION_SIMULATE_FP4', 0)))

from msamp.operators.fp4_quantize import FP4_QUANTIZER
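As a usage sketch (not part of the diff): the three switches are read once at import time, so they must be set in the environment before this module is imported, for example:

import os

# Enable simulated FP4 for weights and activations, plus the DGE gradient correction.
os.environ['MSAMP_USE_WEIGHT_SIMULATE_FP4'] = '1'
os.environ['MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR'] = '1'
os.environ['MSAMP_USE_ACTIVATION_SIMULATE_FP4'] = '1'

import msamp.megatron.layers  # noqa: F401  (the flags above are evaluated during this import)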


class FP8LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
"""A linear function with FP8 support, grad accumulation and async communication."""
@@ -50,19 +63,32 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_a

old_meta_group = input_meta.group
input_meta.group = tp_group
input_fp8 = input.cast(Dtypes.kfloat8_e4m3, meta=input_meta, sync=sequence_parallel)
if MSAMP_USE_ACTIVATION_SIMULATE_FP4:
fp4_input_in_float = FP4_QUANTIZER.quantize_simulate_fp4_in_bf16(input.bfloat16(), format='e2m1', nan_existed=False, token_wise=True, outlier_clip=True, clip_threshold=0.99)
input_fp8 = fp4_input_in_float.cast(Dtypes.kfloat8_e4m3, meta=input_meta, sync=sequence_parallel)
else:
input_fp8 = input.cast(Dtypes.kfloat8_e4m3, meta=input_meta, sync=sequence_parallel)
input_meta.group = old_meta_group

input_fp8.requires_grad = input.requires_grad
input = input_fp8.value

weight_fp8 = weight.cast(Dtypes.kfloat8_e4m3)
if MSAMP_USE_WEIGHT_SIMULATE_FP4:
if MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR:
fp4_weight_in_float, scaled_w = FP4_QUANTIZER.quantize_simulate_fp4_in_bf16(weight.bfloat16(), format='e2m1', nan_existed=False, channel_wise=True, return_scaled_input_for_bwd=True)
else:
fp4_weight_in_float = FP4_QUANTIZER.quantize_simulate_fp4_in_bf16(weight.bfloat16(), format='e2m1', nan_existed=False, channel_wise=True)
weight_fp8 = fp4_weight_in_float.cast(Dtypes.kfloat8_e4m3)
else:
weight_fp8 = weight.cast(Dtypes.kfloat8_e4m3)
weight_fp8.requires_grad = weight.requires_grad

# save tensors
ctx.input_fp8 = input_fp8
ctx.weight_fp8 = weight_fp8
ctx.weight = weight
if MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR:
ctx.save_for_backward(scaled_w)

dim_size = list(input.size())
if sequence_parallel:
@@ -175,6 +201,9 @@ def backward(ctx, grad_output):
wgrad_qtype,
use_split_accumulator=True,
)
if MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR:
scaled_w = ctx.saved_tensors[0]
grad_weight.mul_(FP4_QUANTIZER.apply_DGE_item(scaled_w))

grad_bias = grad_output.sum(dim=0) if use_bias else None

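Taken together, the forward and backward changes above follow one pattern: simulate FP4 on the weight while keeping the scaled weight, then multiply the weight gradient by the DGE derivative in backward. A standalone sketch of that pattern (not part of the diff), assuming a CUDA device and the compiled msamp_quantize extension:

import torch
from msamp.operators.fp4_quantize import FP4_QUANTIZER

weight = torch.randn(1024, 1024, device='cuda', dtype=torch.bfloat16)

# Forward: channel-wise simulated FP4 weight; also keep the scaled weight for backward.
fp4_weight, scaled_w = FP4_QUANTIZER.quantize_simulate_fp4_in_bf16(
    weight, format='e2m1', nan_existed=False, channel_wise=True,
    return_scaled_input_for_bwd=True)

# Backward: modulate the computed weight gradient with the element-wise DGE derivative.
grad_weight = torch.randn_like(weight)  # stand-in for the real wgrad
grad_weight.mul_(FP4_QUANTIZER.apply_DGE_item(scaled_w))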
8 changes: 8 additions & 0 deletions msamp/operators/fp4_quantize/__init__.py
@@ -0,0 +1,8 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""Exposes the interface of MS-AMP FP4 Quantize module."""

from msamp.operators.fp4_quantize.fp4_quantize import FP4_QUANTIZER

__all__ = ['FP4_QUANTIZER']
186 changes: 186 additions & 0 deletions msamp/operators/fp4_quantize/fp4_quantize.py
@@ -0,0 +1,186 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

"""fp4_quantize module.
Algorithm based on https://arxiv.org/abs/2501.17116. Provides a basic Python interface for FP4 quantization
and DGE (Differentiable Gradient Estimator), which enables more accurate gradient updates in FP4 training.
"""

import torch
from typing import Literal

from msamp.common.tensor import ScalingMeta

import msamp_quantize


class FP4_QUANTIZER:
"""FP4 Quantization operator. Algorithm based on https://arxiv.org/abs/2501.17116."""
@staticmethod
def apply_DGE_item(
input_tensor: torch.Tensor,
k: float = 5.0,
power_clamp_max: float = 3.0
) -> torch.Tensor:
"""
Apply the DGE derivative term to the input tensor. Note that this function is fixed to the E2M1 format with no NaN.
DGE is short for 'Differentiable Gradient Estimator', a method for more accurate gradient updates in FP4 training.

Args:
input_tensor (torch.Tensor): input tensor.
k (float): controls the sharpness of the differentiable quantization estimator.
power_clamp_max (float): limits the amplitude of the estimated gradient.
Returns:
torch.Tensor: the element-wise DGE derivative, with the same shape and dtype as the input.
"""
if not (input_tensor.is_cuda and input_tensor.is_contiguous()):
raise ValueError('The input tensor must be a contiguous CUDA tensor.')
if input_tensor.dtype != torch.bfloat16:
raise ValueError('The input tensor must be in bfloat16.')

output_tensor = torch.zeros_like(input_tensor)
msamp_quantize.launch_differentiable_quantize_derivative(input_tensor, output_tensor, k, power_clamp_max, torch.numel(input_tensor))
return output_tensor
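A toy call for reference (a sketch; assumes a CUDA device and the compiled msamp_quantize extension). In this PR the result is multiplied into the weight gradient in msamp/megatron/layers.py:

import torch
from msamp.operators.fp4_quantize import FP4_QUANTIZER

scaled_w = torch.randn(8, 8, device='cuda', dtype=torch.bfloat16)
dge = FP4_QUANTIZER.apply_DGE_item(scaled_w, k=5.0, power_clamp_max=3.0)
assert dge.shape == scaled_w.shape and dge.dtype == torch.bfloat16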


@staticmethod
def _apply_quantile_clipping(
input: torch.Tensor,
clip_threshold: float = 0.99,
channel_wise: bool = False,
token_wise: bool = False,
return_residual: bool = False,
) -> tuple:
"""
Apply quantile clipping to the input tensor.

Args:
input (torch.Tensor): input tensor.
clip_threshold (float): threshold for quantile clipping. Default is 0.99.
channel_wise (bool): whether to apply clipping along the channel dimension. Default is False.
token_wise (bool): whether to apply clipping along the token dimension. Default is False.
return_residual (bool): whether to return the residual. Default is False.
Returns:
tuple: the clipped tensor and the residual tensor (the residual is None unless return_residual is True).
"""
float_input = input.float() if input.dtype != torch.float32 else input

if channel_wise:
sorted_tensor = torch.sort(input, dim=0).values
lower_index = int((1 - clip_threshold) * sorted_tensor.size(0))
upper_index = int(clip_threshold * sorted_tensor.size(0))

lower_bound = sorted_tensor[lower_index:lower_index+1, :]
upper_bound = sorted_tensor[upper_index:upper_index+1, :]

output = torch.clamp(input, min=lower_bound, max=upper_bound)

elif token_wise:
sorted_tensor = torch.sort(input, dim=1).values
lower_index = int((1 - clip_threshold) * sorted_tensor.size(1))
upper_index = int(clip_threshold * sorted_tensor.size(1))

lower_bound = sorted_tensor[:, lower_index:lower_index+1]
upper_bound = sorted_tensor[:, upper_index:upper_index+1]

output = torch.clamp(input, min=lower_bound, max=upper_bound)

else:
sorted_tensor = torch.sort(float_input.view(-1))[0]
lower_index = int((1 - clip_threshold) * sorted_tensor.size(0))
upper_index = int(clip_threshold * sorted_tensor.size(0))

lower_bound = sorted_tensor[lower_index:lower_index+1]
upper_bound = sorted_tensor[upper_index:upper_index+1]

output = torch.clamp(input, min=lower_bound, max=upper_bound)

output = output.to(input.dtype)
if return_residual:
return output, input - output
else:
return output, None
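For illustration, a small sketch of the clipping behaviour (the method itself is device-agnostic; the CUDA/bfloat16 checks live in its callers). The tensor names here are only for the example:

import torch
from msamp.operators.fp4_quantize import FP4_QUANTIZER

x = torch.randn(4, 1024)
x[0, 0] = 50.0  # inject an outlier into the first token
clipped, residual = FP4_QUANTIZER._apply_quantile_clipping(
    x, clip_threshold=0.99, token_wise=True, return_residual=True)
# The outlier is clamped to roughly the per-token 99th-percentile value,
# and residual holds the clipped-away part, so x == clipped + residual.
assert torch.allclose(x, clipped + residual)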


@staticmethod
def quantize_simulate_fp4_in_bf16(
input_tensor: torch.Tensor,
format: Literal['e2m1', 'e1m2'] = 'e1m2',
nan_existed: bool = False,
channel_wise: bool = False,
token_wise: bool = False,
outlier_clip: bool = False,
clip_threshold: float = 0.99,
residual_compensation: bool = False,
return_scaled_input_for_bwd: bool = False,
) -> torch.Tensor:
"""
Quantize a high-precision tensor to a simulated FP4 tensor.

Args:
input_tensor (torch.Tensor): high precision tensor to quantize. Note that the input tensor should be in cuda memory and bfloat16 dtype.
format (Literal['e2m1', 'e1m2']): format of the quantized tensor. Default is 'e1m2'.
nan_existed (bool): whether NaN value exists in the input tensor. Default is False.
channel_wise (bool): whether to quantize the input tensor along the channel dimension. Default is False.
token_wise (bool): whether to quantize the input tensor along the token dimension. Default is False.
outlier_clip (bool): whether to apply outlier clipping to the input tensor. Default is False.
clip_threshold (float): threshold for outlier clipping. Default is 0.99.
residual_compensation (bool): whether to add residual back to the quantized tensor. Default is False.
return_scaled_input_for_bwd (bool): whether to return scaled input tensor for backward computation. Default is False.

Note: the 'nan_existed' parameter is accepted but not used (kept for API consistency with other functions).
Returns:
torch.Tensor: the simulated FP4-quantized tensor, still in bfloat16 dtype. If return_scaled_input_for_bwd is True, a tuple of (quantized tensor, scaled input) is returned instead.
"""
if not (input_tensor.is_cuda and input_tensor.is_contiguous()):
raise ValueError('The input tensor must be a contiguous CUDA tensor.')
if input_tensor.dtype != torch.bfloat16:
raise ValueError('The input tensor must be in bfloat16.')

# handle tensor shape for channel_wise or token_wise quantization
shape = input_tensor.shape
assert not (channel_wise and token_wise), 'channel_wise and token_wise cannot both be True.'
if (channel_wise or token_wise) and len(shape) != 2:
dim = shape[-1]
input_tensor = input_tensor.reshape(-1, dim)

# handle outlier clipping
if outlier_clip:
input_tensor, residual = FP4_QUANTIZER._apply_quantile_clipping(input_tensor, clip_threshold, channel_wise, token_wise, return_residual=residual_compensation)

# get amax
if channel_wise:
amax = input_tensor.abs().max(dim=0, keepdim=True)[0] # channel-wise max value
scale = torch.ones((1, 1), dtype=input_tensor.dtype, device='cuda') # 2-D tensor shape
elif token_wise:
amax = input_tensor.abs().max(dim=1, keepdim=True)[0] # token-wise max value
scale = torch.ones((1, 1), dtype=input_tensor.dtype, device='cuda') # 2-D tensor shape
else:
amax = input_tensor.abs().max()
scale = torch.ones((), dtype=input_tensor.dtype, device='cuda')
# compute scaling factor
fp_max = 6.0 if format == 'e2m1' else 7.0 # fixed per format. For e1m2 the true max is 3.5, but we use 7.0 (x2) so that round() can be applied directly
margin = 0
sf = ScalingMeta.compute_scaling_factor(amax, scale, fp_max, margin)

# quantize
scaled_input = input_tensor * sf # this * operation can handle matrix-tensor broadcasting. For example, (3, 4) * (4,) -> (3, 4)
if format == 'e2m1':
output_tensor = torch.zeros_like(scaled_input)
msamp_quantize.quantize_bf16(scaled_input, output_tensor, torch.numel(scaled_input))
else:
output_tensor = torch.round(scaled_input)
output_tensor.div_(sf) # this .div_() method can also handle matrix-tensor broadcasting
if residual_compensation:
output_tensor = output_tensor + residual
output_tensor.requires_grad = input_tensor.requires_grad

# reshape output tensor to original shape
if (channel_wise or token_wise) and len(shape) != 2:
output_tensor = output_tensor.view(shape[:-1] + (-1, ))
if return_scaled_input_for_bwd:
scaled_input = scaled_input.view(shape[:-1] + (-1, ))

if return_scaled_input_for_bwd:
return output_tensor, scaled_input
return output_tensor
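A minimal usage sketch for the activation path wired up in msamp/megatron/layers.py (token-wise quantization with outlier clipping); it assumes a CUDA device and the compiled msamp_quantize extension:

import torch
from msamp.operators.fp4_quantize import FP4_QUANTIZER

act = torch.randn(16, 4096, device='cuda', dtype=torch.bfloat16)
act_fp4 = FP4_QUANTIZER.quantize_simulate_fp4_in_bf16(
    act, format='e2m1', nan_existed=False,
    token_wise=True, outlier_clip=True, clip_threshold=0.99)

# Values are restricted to a per-token scaled E2M1 grid, but the storage dtype stays bfloat16,
# so the result can be cast onward to FP8 exactly as the unquantized activation would be.
assert act_fp4.dtype == torch.bfloat16 and act_fp4.shape == act.shape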