diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index b444b176..0f3daa50 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -21,7 +21,7 @@ jobs:
         misspell -error .
   cpp:
     name: CPP code lint
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-22.04
     steps:
       - name: Checkout
         uses: actions/checkout@v2
@@ -33,7 +33,7 @@ jobs:
         make cpplint
   md:
     name: Markdown code lint
-    runs-on: ubuntu-20.04
+    runs-on: ubuntu-22.04
    steps:
       - name: Checkout
         uses: actions/checkout@v2
diff --git a/Makefile b/Makefile
index a37f66fb..3565b35c 100644
--- a/Makefile
+++ b/Makefile
@@ -24,4 +24,5 @@ lint: cpplint mdlint
 postinstall:
 	cd msamp/operators/dist_op && bash build.sh && cd -
 	cd msamp/operators/arithmetic && pip install -v . && cd -
+	cd msamp/operators/fp4_quantize && pip install -v . && cd -
 	cd msamp/optim && pip install -v . && cd -
diff --git a/msamp/megatron/layers.py b/msamp/megatron/layers.py
index d28d0c2a..0c7ac343 100644
--- a/msamp/megatron/layers.py
+++ b/msamp/megatron/layers.py
@@ -13,6 +13,19 @@
 from msamp.common.tensor import ScalingTensor
 from msamp.operators.gemm import Gemm
 
+import os
+
+"""Below are the environment variables that control the FP4 quantization behavior, based on https://arxiv.org/abs/2501.17116.
+'MSAMP_USE_WEIGHT_SIMULATE_FP4' controls whether weight quantization is used.
+'MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR' controls whether DGE (Differentiable Gradient Estimator) is used.
+'MSAMP_USE_ACTIVATION_SIMULATE_FP4' controls whether activation quantization is used.
+"""
+MSAMP_USE_WEIGHT_SIMULATE_FP4 = bool(int(os.getenv('MSAMP_USE_WEIGHT_SIMULATE_FP4', 0)))
+MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR = bool(int(os.getenv('MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR', 0)))
+MSAMP_USE_ACTIVATION_SIMULATE_FP4 = bool(int(os.getenv('MSAMP_USE_ACTIVATION_SIMULATE_FP4', 0)))
+
+from msamp.operators.fp4_quantize import FP4_QUANTIZER
+
 
 class FP8LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     """A linear function with FP8 support, grad accumulation and async communication."""
@@ -50,19 +63,32 @@ def forward(ctx, input, weight, bias, gradient_accumulation_fusion, async_grad_a
             old_meta_group = input_meta.group
             input_meta.group = tp_group
-            input_fp8 = input.cast(Dtypes.kfloat8_e4m3, meta=input_meta, sync=sequence_parallel)
+            if MSAMP_USE_ACTIVATION_SIMULATE_FP4:
+                fp4_input_in_float = FP4_QUANTIZER.quantize_simulate_fp4_in_bf16(input.bfloat16(), format='e2m1', nan_existed=False, token_wise=True, outlier_clip=True, clip_threshold=0.99)
+                input_fp8 = fp4_input_in_float.cast(Dtypes.kfloat8_e4m3, meta=input_meta, sync=sequence_parallel)
+            else:
+                input_fp8 = input.cast(Dtypes.kfloat8_e4m3, meta=input_meta, sync=sequence_parallel)
             input_meta.group = old_meta_group
             input_fp8.requires_grad = input.requires_grad
             input = input_fp8.value
 
-        weight_fp8 = weight.cast(Dtypes.kfloat8_e4m3)
+        if MSAMP_USE_WEIGHT_SIMULATE_FP4:
+            if MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR:
+                fp4_weight_in_float, scaled_w = FP4_QUANTIZER.quantize_simulate_fp4_in_bf16(weight.bfloat16(), format='e2m1', nan_existed=False, channel_wise=True, return_scaled_input_for_bwd=True)
+            else:
+                fp4_weight_in_float = FP4_QUANTIZER.quantize_simulate_fp4_in_bf16(weight.bfloat16(), format='e2m1', nan_existed=False, channel_wise=True)
+            weight_fp8 = fp4_weight_in_float.cast(Dtypes.kfloat8_e4m3)
+        else:
+            weight_fp8 = weight.cast(Dtypes.kfloat8_e4m3)
         weight_fp8.requires_grad = weight.requires_grad
 
         # save tensors
         ctx.input_fp8 = input_fp8
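+        # When DGE is enabled, the scaled (pre-rounding) weight returned by the FP4 simulation above
+        # is also stashed below via save_for_backward, so that backward can rescale the weight gradient.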
         ctx.weight_fp8 = weight_fp8
         ctx.weight = weight
+        if MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR:
+            ctx.save_for_backward(scaled_w)
 
         dim_size = list(input.size())
         if sequence_parallel:
@@ -175,6 +201,9 @@ def backward(ctx, grad_output):
                 wgrad_qtype,
                 use_split_accumulator=True,
             )
+            if MSAMP_USE_WEIGHT_DIFFERENTIABLE_GRADIENT_ESTIMATOR:
+                scaled_w = ctx.saved_tensors[0]
+                grad_weight.mul_(FP4_QUANTIZER.apply_DGE_item(scaled_w))
 
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 
diff --git a/msamp/operators/fp4_quantize/__init__.py b/msamp/operators/fp4_quantize/__init__.py
new file mode 100644
index 00000000..569239f9
--- /dev/null
+++ b/msamp/operators/fp4_quantize/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""Exposes the interface of MS-AMP FP4 Quantize module."""
+
+from msamp.operators.fp4_quantize.fp4_quantize import FP4_QUANTIZER
+
+__all__ = ['FP4_QUANTIZER']
diff --git a/msamp/operators/fp4_quantize/fp4_quantize.py b/msamp/operators/fp4_quantize/fp4_quantize.py
new file mode 100644
index 00000000..42e22547
--- /dev/null
+++ b/msamp/operators/fp4_quantize/fp4_quantize.py
@@ -0,0 +1,186 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""fp4_quantize module.
+
+Algorithm based on https://arxiv.org/abs/2501.17116. Provides a basic Python interface for FP4 quantization
+and DGE (Differentiable Gradient Estimator) for more accurate gradient updates in FP4 training.
+"""
+
+import torch
+from typing import Literal
+
+from msamp.common.tensor import ScalingMeta
+
+import msamp_quantize
+
+
+class FP4_QUANTIZER:
+    """FP4 quantization operator. Algorithm based on https://arxiv.org/abs/2501.17116."""
+    @staticmethod
+    def apply_DGE_item(
+        input_tensor: torch.Tensor,
+        k: float = 5.0,
+        power_clamp_max: float = 3.0
+    ) -> torch.Tensor:
+        """Apply the DGE derivative term to the input tensor.
+
+        Note that this function is fixed to the E2M1 format with no NaN. DGE is short for
+        'Differentiable Gradient Estimator', a method for more accurate gradient updates in FP4 training.
+
+        Args:
+            input_tensor (torch.Tensor): input tensor.
+            k (float): determines the sharpness of the differentiable quantization estimator.
+            power_clamp_max (float): restricts the amplitude of the estimated gradient.
+
+        Returns:
+            torch.Tensor: output tensor containing the estimated derivative.
+        """
+        if not (input_tensor.is_cuda and input_tensor.is_contiguous()):
+            raise ValueError('The input tensor is not in cuda memory or not contiguous.')
+        if input_tensor.dtype != torch.bfloat16:
+            raise ValueError('The input tensor is not in bfloat16.')
+
+        output_tensor = torch.zeros_like(input_tensor)
+        msamp_quantize.launch_differentiable_quantize_derivative(
+            input_tensor, output_tensor, k, power_clamp_max, torch.numel(input_tensor)
+        )
+        return output_tensor
+
+    @staticmethod
+    def _apply_quantile_clipping(
+        input: torch.Tensor,
+        clip_threshold: float = 0.99,
+        channel_wise: bool = False,
+        token_wise: bool = False,
+        return_residual: bool = False,
+    ) -> tuple:
+        """Apply quantile clipping to the input tensor.
+
+        Args:
+            input (torch.Tensor): input tensor.
+            clip_threshold (float): threshold for quantile clipping. Default is 0.99.
+            channel_wise (bool): whether to apply clipping along the channel dimension. Default is False.
+            token_wise (bool): whether to apply clipping along the token dimension. Default is False.
+            return_residual (bool): whether to return the clipping residual. Default is False.
+
+        Returns:
+            tuple: output tensor and residual tensor (if return_residual is True, otherwise None).
+        """
+        float_input = input.float() if input.dtype != torch.float32 else input
+
+        if channel_wise:
+            sorted_tensor = torch.sort(input, dim=0).values
+            lower_index = int((1 - clip_threshold) * sorted_tensor.size(0))
+            upper_index = int(clip_threshold * sorted_tensor.size(0))
+
+            lower_bound = sorted_tensor[lower_index:lower_index + 1, :]
+            upper_bound = sorted_tensor[upper_index:upper_index + 1, :]
+
+            output = torch.clamp(input, min=lower_bound, max=upper_bound)
+
+        elif token_wise:
+            sorted_tensor = torch.sort(input, dim=1).values
+            lower_index = int((1 - clip_threshold) * sorted_tensor.size(1))
+            upper_index = int(clip_threshold * sorted_tensor.size(1))
+
+            lower_bound = sorted_tensor[:, lower_index:lower_index + 1]
+            upper_bound = sorted_tensor[:, upper_index:upper_index + 1]
+
+            output = torch.clamp(input, min=lower_bound, max=upper_bound)
+
+        else:
+            sorted_tensor = torch.sort(float_input.view(-1))[0]
+            lower_index = int((1 - clip_threshold) * sorted_tensor.size(0))
+            upper_index = int(clip_threshold * sorted_tensor.size(0))
+
+            lower_bound = sorted_tensor[lower_index:lower_index + 1]
+            upper_bound = sorted_tensor[upper_index:upper_index + 1]
+
+            output = torch.clamp(input, min=lower_bound, max=upper_bound)
+
+        output = output.to(input.dtype)
+        if return_residual:
+            return output, input - output
+        else:
+            return output, None
+
+    @staticmethod
+    def quantize_simulate_fp4_in_bf16(
+        input_tensor: torch.Tensor,
+        format: Literal['e2m1', 'e1m2'] = 'e1m2',
+        nan_existed: bool = False,
+        channel_wise: bool = False,
+        token_wise: bool = False,
+        outlier_clip: bool = False,
+        clip_threshold: float = 0.99,
+        residual_compensation: bool = False,
+        return_scaled_input_for_bwd: bool = False,
+    ) -> torch.Tensor:
+        """Quantize a high-precision tensor to a simulated FP4 tensor.
+
+        Args:
+            input_tensor (torch.Tensor): high-precision tensor to quantize. The tensor must be in cuda memory, contiguous and of bfloat16 dtype.
+            format (Literal['e2m1', 'e1m2']): format of the quantized tensor. Default is 'e1m2'.
+            nan_existed (bool): whether NaN values exist in the input tensor. Default is False.
+            channel_wise (bool): whether to quantize the input tensor along the channel dimension. Default is False.
+            token_wise (bool): whether to quantize the input tensor along the token dimension. Default is False.
+            outlier_clip (bool): whether to apply outlier clipping to the input tensor. Default is False.
+            clip_threshold (float): threshold for outlier clipping. Default is 0.99.
+            residual_compensation (bool): whether to add the clipping residual back to the quantized tensor. Default is False.
+            return_scaled_input_for_bwd (bool): whether to also return the scaled input tensor for the backward computation. Default is False.
+
+        Note: the 'nan_existed' parameter is accepted but not used; it is kept for API consistency with the other operators.
+
+        Returns:
+            torch.Tensor: the simulated FP4-quantized tensor, still in bfloat16 dtype. If return_scaled_input_for_bwd
+            is True, a tuple of (quantized tensor, scaled input) is returned instead.
+        """
+        if not (input_tensor.is_cuda and input_tensor.is_contiguous()):
+            raise ValueError('The input tensor is not in cuda memory or not contiguous.')
+        if input_tensor.dtype != torch.bfloat16:
+            raise ValueError('The input tensor is not in bfloat16.')
+
+        # handle tensor shape for channel_wise or token_wise quantization
+        shape = input_tensor.shape
+        assert not (channel_wise and token_wise), 'channel_wise and token_wise cannot be True at the same time.'
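+        # The channel-/token-wise paths below operate on a 2-D view of the input: any leading
+        # dimensions are flattened so that dim 0 indexes tokens and the last dim indexes channels.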
+        if (channel_wise or token_wise) and len(shape) != 2:
+            dim = shape[-1]
+            input_tensor = input_tensor.reshape(-1, dim)
+
+        # handle outlier clipping
+        if outlier_clip:
+            input_tensor, residual = FP4_QUANTIZER._apply_quantile_clipping(
+                input_tensor, clip_threshold, channel_wise, token_wise, return_residual=residual_compensation
+            )
+
+        # get amax
+        if channel_wise:
+            amax = input_tensor.abs().max(dim=0, keepdim=True)[0]    # channel-wise max value
+            scale = torch.ones((1, 1), dtype=input_tensor.dtype, device='cuda')    # 2-D tensor shape
+        elif token_wise:
+            amax = input_tensor.abs().max(dim=1, keepdim=True)[0]    # token-wise max value
+            scale = torch.ones((1, 1), dtype=input_tensor.dtype, device='cuda')    # 2-D tensor shape
+        else:
+            amax = input_tensor.abs().max()
+            scale = torch.ones((), dtype=input_tensor.dtype, device='cuda')
+        # compute scaling factor
+        fp_max = 6.0 if format == 'e2m1' else 7.0    # fixed per format; for e1m2 the true max is 3.5, doubled here so round() can be applied directly
+        margin = 0
+        sf = ScalingMeta.compute_scaling_factor(amax, scale, fp_max, margin)
+
+        # quantize
+        scaled_input = input_tensor * sf    # '*' broadcasts, e.g. (3, 4) * (4,) -> (3, 4)
+        if format == 'e2m1':
+            output_tensor = torch.zeros_like(scaled_input)
+            msamp_quantize.quantize_bf16(scaled_input, output_tensor, torch.numel(scaled_input))
+        else:
+            output_tensor = torch.round(scaled_input)
+        output_tensor.div_(sf)    # '.div_()' broadcasts as well
+        if residual_compensation:
+            output_tensor = output_tensor + residual
+        output_tensor.requires_grad = input_tensor.requires_grad
+
+        # reshape output tensor to original shape
+        if (channel_wise or token_wise) and len(shape) != 2:
+            output_tensor = output_tensor.view(shape[:-1] + (-1, ))
+            if return_scaled_input_for_bwd:
+                scaled_input = scaled_input.view(shape[:-1] + (-1, ))
+
+        if return_scaled_input_for_bwd:
+            return output_tensor, scaled_input
+        return output_tensor
diff --git a/msamp/operators/fp4_quantize/quantize.cu b/msamp/operators/fp4_quantize/quantize.cu
new file mode 100644
index 00000000..18f3f6f3
--- /dev/null
+++ b/msamp/operators/fp4_quantize/quantize.cu
@@ -0,0 +1,140 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
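+//
+// The E2M1 (no NaN/Inf) grid used below contains only the 15 values
+// [-6, -4, -3, -2, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, 2, 3, 4, 6]. quantize_bf16_kernel rounds each
+// element to the nearest grid value by comparing against the midpoints of adjacent values
+// (e.g. anything below -5.0 maps to -6.0), and differentiable_quantize_derivative evaluates the
+// DGE derivative piecewise on the same grid.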
+
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+
+#include "../../common/include/common.h"
+
+// Algorithm based on https://arxiv.org/abs/2501.17116
+// Provides the CUDA kernel implementation of simulated FP4 quantization and DGE (Differentiable Gradient Estimator).
+
+__global__ void quantize_bf16_kernel(const __nv_bfloat16* x, __nv_bfloat16* output, int x_size) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < x_size) {
+        __nv_bfloat16 value = x[idx];
+        __nv_bfloat16 closest;
+
+        if (__hlt(value, __float2bfloat16(-5.0f))) {
+            closest = __float2bfloat16(-6.0f);
+        } else if (__hlt(value, __float2bfloat16(-3.5f))) {
+            closest = __float2bfloat16(-4.0f);
+        } else if (__hlt(value, __float2bfloat16(-2.5f))) {
+            closest = __float2bfloat16(-3.0f);
+        } else if (__hlt(value, __float2bfloat16(-1.75f))) {
+            closest = __float2bfloat16(-2.0f);
+        } else if (__hlt(value, __float2bfloat16(-1.25f))) {
+            closest = __float2bfloat16(-1.5f);
+        } else if (__hlt(value, __float2bfloat16(-0.75f))) {
+            closest = __float2bfloat16(-1.0f);
+        } else if (__hlt(value, __float2bfloat16(-0.25f))) {
+            closest = __float2bfloat16(-0.5f);
+        } else if (__hlt(value, __float2bfloat16(0.25f))) {
+            closest = __float2bfloat16(0.0f);
+        } else if (__hlt(value, __float2bfloat16(0.75f))) {
+            closest = __float2bfloat16(0.5f);
+        } else if (__hlt(value, __float2bfloat16(1.25f))) {
+            closest = __float2bfloat16(1.0f);
+        } else if (__hlt(value, __float2bfloat16(1.75f))) {
+            closest = __float2bfloat16(1.5f);
+        } else if (__hlt(value, __float2bfloat16(2.5f))) {
+            closest = __float2bfloat16(2.0f);
+        } else if (__hlt(value, __float2bfloat16(3.5f))) {
+            closest = __float2bfloat16(3.0f);
+        } else if (__hlt(value, __float2bfloat16(5.0f))) {
+            closest = __float2bfloat16(4.0f);
+        } else {
+            closest = __float2bfloat16(6.0f);
+        }
+
+        output[idx] = closest;
+    }
+}
+
+void quantize_bf16(at::Tensor input, at::Tensor output, int size) {
+    const __nv_bfloat16* input_data = reinterpret_cast<const __nv_bfloat16*>(input.data_ptr());
+    __nv_bfloat16* output_data = reinterpret_cast<__nv_bfloat16*>(output.data_ptr());
+
+    const int threadsPerBlock = HIP_GET_NUM_THREADS(size);
+    const int blocks = (size + threadsPerBlock - 1) / threadsPerBlock;
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    quantize_bf16_kernel<<<blocks, threadsPerBlock, 0, stream>>>(input_data, output_data, size);
+}
+
+__device__ float power_derivative(float x, float delta, float k, float power_clamp_max) {
+    float abs_term = fabsf(2.0f * x / delta - 1.0f);
+    return fminf(powf(abs_term, 1.0f / k - 1.0f) / k, power_clamp_max);
+}
+
+// for the fixed E2M1_no_NaN set: [-6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
+__global__ void differentiable_quantize_derivative(
+    const __nv_bfloat16* input, __nv_bfloat16* output,
+    float k, float power_clamp_max, int n
+) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx >= n) return;
+
+    float x = __bfloat162float(input[idx]);
+    float dy = 0.0f;
+
+    if (x < -4.0f) {
+        dy = power_derivative(x + 6.0f, 2.0f, k, power_clamp_max);
+    } else if (x >= -4.0f && x < -3.0f) {
+        dy = power_derivative(x + 4.0f, 1.0f, k, power_clamp_max);
+    } else if (x >= -3.0f && x < -2.0f) {
+        dy = power_derivative(x + 3.0f, 1.0f, k, power_clamp_max);
+    } else if (x >= -2.0f && x < -1.5f) {
+        dy = power_derivative(x + 2.0f, 0.5f, k, power_clamp_max);
+    } else if (x >= -1.5f && x < -1.0f) {
+        dy = power_derivative(x + 1.5f, 0.5f, k, power_clamp_max);
+    } else if (x >= -1.0f && x < -0.5f) {
+        dy = power_derivative(x + 1.0f, 0.5f, k, power_clamp_max);
+    } else if (x >= -0.5f && x < 0.0f) {
+        dy = power_derivative(x + 0.5f, 0.5f, k, power_clamp_max);
+    } else if (x >= 0.0f && x < 0.5f) {
+        dy = power_derivative(x, 0.5f, k, power_clamp_max);
+    } else if (x >= 0.5f && x < 1.0f) {
+        dy = power_derivative(x - 0.5f, 0.5f, k, power_clamp_max);
+    } else if (x >= 1.0f && x < 1.5f) {
+        dy = power_derivative(x - 1.0f, 0.5f, k, power_clamp_max);
+    } else if (x >= 1.5f && x < 2.0f) {
+        dy = power_derivative(x - 1.5f, 0.5f, k, power_clamp_max);
+    } else if (x >= 2.0f && x < 3.0f) {
+        dy = power_derivative(x - 2.0f, 1.0f, k, power_clamp_max);
+    } else if (x >= 3.0f && x < 4.0f) {
+        dy = power_derivative(x - 3.0f, 1.0f, k, power_clamp_max);
+    } else if (x >= 4.0f && x <= 6.0f) {
+        dy = power_derivative(x - 4.0f, 2.0f, k, power_clamp_max);
+    }
+
+    output[idx] = __float2bfloat16(dy);
+}
+
+void launch_differentiable_quantize_derivative(
+    at::Tensor input, at::Tensor output,
+    float k, float power_clamp_max, int size
+) {
+    const __nv_bfloat16* input_data = reinterpret_cast<const __nv_bfloat16*>(input.data_ptr());
+    __nv_bfloat16* output_data = reinterpret_cast<__nv_bfloat16*>(output.data_ptr());
+
+    const int threadsPerBlock = HIP_GET_NUM_THREADS(size);
+    const int blocks = (size + threadsPerBlock - 1) / threadsPerBlock;
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    differentiable_quantize_derivative<<<blocks, threadsPerBlock, 0, stream>>>(input_data, output_data, k, power_clamp_max, size);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("quantize_bf16", &quantize_bf16, "Simulated FP4 quantize function in BF16 format");
+    m.def("launch_differentiable_quantize_derivative", &launch_differentiable_quantize_derivative, "Differentiable quantize derivative function");
+}
diff --git a/msamp/operators/fp4_quantize/setup.py b/msamp/operators/fp4_quantize/setup.py
new file mode 100644
index 00000000..6028ceee
--- /dev/null
+++ b/msamp/operators/fp4_quantize/setup.py
@@ -0,0 +1,28 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""The setuptools based setup module."""
+
+from setuptools import setup
+from torch.utils import cpp_extension
+
+ext_t = cpp_extension.CUDAExtension
+ext_fnames = ['quantize.cu']
+define_macros = []
+extra_compile_args = dict(cxx=['-fopenmp', '-O3'], nvcc=['-O3'])
+
+define_macros.append(('WITH_CUDA', None))
+
+setup(
+    name='msamp_quantize',
+    version='0.0.1',
+    ext_modules=[
+        ext_t(
+            'msamp_quantize',
+            ext_fnames,
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+        )
+    ],
+    cmdclass={'build_ext': cpp_extension.BuildExtension}
+)
diff --git a/tests/operators/test_fp4_quantize.py b/tests/operators/test_fp4_quantize.py
new file mode 100644
index 00000000..2ba7b81a
--- /dev/null
+++ b/tests/operators/test_fp4_quantize.py
@@ -0,0 +1,54 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+ +"""Tests for FP4 quantization operator.""" + +import itertools +import unittest + +import torch + +from tests.helper import decorator +from msamp.operators.fp4_quantize import FP4_QUANTIZER + + +class FP4QuantTestCase(unittest.TestCase): + '''A class for FP4 quantization test cases.''' + @decorator.cuda_test + def test_DGE(self): + '''Check the DGE item.''' + total_points = 20 + x_values = torch.linspace(-6.0, 6.0, total_points).to(torch.bfloat16).cuda() + excepted_y_values = torch.tensor( + [0.2002, 0.4375, 0.6055, 0.2168, 1.8359, 0.2695, 0.3027, 0.2695, 0.2402, 0.5781, + 0.5781, 0.2402, 0.2695, 0.3027, 0.2695, 1.8359, 0.2168, 0.6055, 0.4375, 0.2002], dtype=torch.bfloat16).cuda() + differentiable_quantized_y_derivative = FP4_QUANTIZER.apply_DGE_item(x_values) + self.assertTrue(torch.allclose(differentiable_quantized_y_derivative, excepted_y_values)) + + + @decorator.cuda_test + def test_fp4_quant(self): + '''Check the quantization of input tensor.''' + input_tensor = torch.tensor([[[0.001, 0.048, 0.0997], [0.1503, 0.2002, 0.2497], [0.2974, 0.30699, 0.4001]]], dtype=torch.bfloat16).cuda() + target_tensor = torch.tensor([[[0.0, 0.0625, 0.125], [0.125, 0.1875, 0.25], [0.25, 0.25, 0.375]]], dtype=torch.bfloat16).cuda() + output_tensor = FP4_QUANTIZER.quantize_simulate_fp4_in_bf16(input_tensor, format='e2m1', nan_existed=False) + self.assertTrue(torch.allclose(output_tensor, target_tensor)) + + input_tensor = torch.tensor( + [ [ [-0.01, 0.48, -9.67], + [1.623, -2.222, 24.67], ], + [ [-2.874, 3.699, -34.57], + [0.85, -1.343, 18.88], ] + ], dtype=torch.bfloat16).cuda() # channel-wise outlier. shape: (2, 2, 3) + target_tensor = torch.tensor( + [ [ [ 0.0, 0.5, -8.0], + [1.5, -2.0, 24.0], ], + [ [-3.0, 4.0, -32.0], + [0.75, -1.5, 16.0], ] + ], dtype=torch.bfloat16).cuda() + output_tensor = FP4_QUANTIZER.quantize_simulate_fp4_in_bf16(input_tensor, format='e2m1', nan_existed=False, channel_wise=True) + self.assertTrue(torch.allclose(output_tensor, target_tensor)) + + output_tensor = FP4_QUANTIZER.quantize_simulate_fp4_in_bf16(input_tensor.view(-1, 3).T, format='e2m1', nan_existed=False, token_wise=True) # token-wise outlier. + self.assertTrue(torch.allclose(output_tensor, target_tensor.view(-1, 3).T)) +