diff --git a/gpt_builders.py b/gpt_builders.py index dfe41f7b88e..f3ecdc2a834 100644 --- a/gpt_builders.py +++ b/gpt_builders.py @@ -81,7 +81,12 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_ else: # Define the decoder block spec decoder_layer_specs = get_gpt_decoder_layer_specs( - config, use_transformer_engine=use_te, normalization=args.normalization, qk_l2_norm=args.qk_l2_norm, vp_stage=vp_stage + config, + use_transformer_engine=use_te, + normalization=args.normalization, + qk_l2_norm=args.qk_l2_norm, + use_bitnet=args.use_bitnet, + vp_stage=vp_stage, ) transformer_layer_spec_for_mtp = decoder_layer_specs[-1] # Use spec of the last layer in decoder block as spec of the transformer layer in MTP @@ -116,12 +121,12 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_ def _get_transformer_layer_spec(use_te, config): """Get transformer layer specification based on configuration. - + Args: use_te (bool): Whether to use Transformer Engine args: Training arguments config: Model configuration - + Returns: transformer_layer_spec: The transformer layer specification """ @@ -151,6 +156,7 @@ def _get_transformer_layer_spec(use_te, config): args.moe_grouped_gemm, args.qk_layernorm, args.multi_latent_attention, + args.use_bitnet, args.experimental_attention_variant, moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm, normalization=args.normalization, diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 974e33f88e8..1ea838d9f22 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -4,6 +4,11 @@ from typing import Optional, Union from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.tensor_parallel.layers import ( + BitNetColumnParallelLinear, + BitNetRowParallelLinear, +) +from megatron.core.transformer.dot_product_attention import DotProductAttention 
from megatron.core.models.backends import ( BackendSpecProvider, InferenceSpecProvider, @@ -311,6 +316,7 @@ def get_gpt_layer_local_spec( moe_grouped_gemm: Optional[bool] = False, qk_layernorm: Optional[bool] = False, multi_latent_attention: Optional[bool] = False, + use_bitnet: Optional[bool] = False, fp8: Optional[str] = None, # pylint: disable=unused-argument moe_use_legacy_grouped_gemm: Optional[bool] = False, normalization: Optional[str] = None, @@ -327,6 +333,7 @@ def get_gpt_layer_local_spec( moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. multi_latent_attention (bool, optional): To use MLA. Defaults to False. + use_bitnet (bool, optional): To use BitNet linear layers. Defaults to False. fp8 (str, optional): Deprecated. For temporary Nemo compatibility. moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. Defaults to False. @@ -361,6 +368,7 @@ def get_gpt_layer_local_spec( mlp = get_mlp_module_spec_for_backend( backend=backend, + use_bitnet=use_bitnet, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, @@ -393,6 +401,36 @@ def get_gpt_layer_local_spec( mlp_bda=get_bias_dropout_add, ), ) + elif use_bitnet: + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=LNImpl, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=BitNetColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=BitNetRowParallelLinear, + q_layernorm=( + L2Norm if qk_l2_norm else (LNImpl if qk_layernorm else IdentityOp) + ), + k_layernorm=( + L2Norm if qk_l2_norm else (LNImpl if qk_layernorm else IdentityOp) + ), + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=LNImpl, + mlp=mlp, + 
mlp_bda=get_bias_dropout_add, + sharded_state_dict_keys_map={ + 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', + 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + }, + ), + ) else: return ModuleSpec( module=TransformerLayer, @@ -427,6 +465,7 @@ def get_gpt_layer_local_spec( def _get_mlp_module_spec( use_te: Optional[bool] = True, + use_bitnet: Optional[bool] = False, num_experts: Optional[int] = None, moe_grouped_gemm: Optional[bool] = False, fp8: Optional[str] = None, # pylint: disable=unused-argument @@ -439,6 +478,7 @@ def _get_mlp_module_spec( return get_mlp_module_spec( use_te=use_te, + use_bitnet=use_bitnet, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=fp8, @@ -448,6 +488,7 @@ def _get_mlp_module_spec( def get_mlp_module_spec( use_te: Optional[bool] = True, + use_bitnet: Optional[bool] = False, num_experts: Optional[int] = None, moe_grouped_gemm: Optional[bool] = False, fp8: Optional[str] = None, # pylint: disable=unused-argument @@ -472,6 +513,7 @@ def get_mlp_module_spec( return get_mlp_module_spec_for_backend( backend=TESpecProvider() if use_te else LocalSpecProvider(), + use_bitnet=use_bitnet, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, @@ -481,6 +523,7 @@ def get_mlp_module_spec( def get_mlp_module_spec_for_backend( backend: BackendSpecProvider, + use_bitnet: Optional[bool] = False, num_experts: Optional[int] = None, moe_grouped_gemm: Optional[bool] = False, moe_use_legacy_grouped_gemm: Optional[bool] = False, @@ -491,8 +534,17 @@ def get_mlp_module_spec_for_backend( linear_fc2 = backend.row_parallel_linear() activation_func = backend.activation_func() if use_te_activation_func else None - - if num_experts is None: + + if use_bitnet: + # Bitnet MLP only w/o TE modules. 
+ return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=BitNetColumnParallelLinear, + linear_fc2= BitNetRowParallelLinear, + ), + ) + elif num_experts is None: # Dense MLP w/ or w/o TE modules. module = TEFusedMLP if use_te_op_fuser else MLP if backend.fuse_layernorm_and_linear(): @@ -522,6 +574,7 @@ def get_gpt_decoder_layer_specs( use_transformer_engine: bool, normalization: Optional[str] = None, qk_l2_norm: Optional[bool] = False, + use_bitnet: Optional[bool] = False, vp_stage: Optional[int] = None, pp_rank: Optional[int] = None, ) -> TransformerBlockSubmodules: @@ -559,6 +612,7 @@ def get_gpt_decoder_layer_specs( moe_grouped_gemm=False, qk_layernorm=config.qk_layernorm, multi_latent_attention=config.multi_latent_attention, + use_bitnet=use_bitnet, moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, normalization=normalization, qk_l2_norm=qk_l2_norm, @@ -571,6 +625,7 @@ def get_gpt_decoder_layer_specs( moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm, multi_latent_attention=config.multi_latent_attention, + use_bitnet=use_bitnet, moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, normalization=normalization, qk_l2_norm=qk_l2_norm, diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index 8cd43dd9b6c..4ecf8eb3806 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -11,7 +11,15 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter - +# Import BitNet quantization kernels +try: + from onebitllms import activation_quant_triton, weight_quant_triton + HAVE_ONEBIT = True +except ImportError: + HAVE_ONEBIT = False + activation_quant_triton = None + weight_quant_triton = None + from megatron.core.model_parallel_config import ModelParallelConfig from megatron.core.parallel_state import ( get_global_memory_buffer, @@ -1319,3 +1327,141 @@ def __repr__(self): 
f"{type(self).__name__}(in_features={self.input_size}, " f"out_features={self.output_size}, bias={use_bias}, TP={tp})" ) + +class BitNetColumnParallelLinear(ColumnParallelLinear): + """ + BitNet-enabled Column Parallel Linear Layer. + + Extends ColumnParallelLinear with BitNet quantization while maintaining + all tensor parallelism functionality. BitNet quantization is applied + during training by overriding _forward_impl to quantize activations and weights. + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool = True, + gather_output: bool = False, + stride: int = 1, + keep_master_weight_for_test: bool = False, + skip_bias_add: bool = False, + skip_weight_param_allocation: bool = False, + embedding_activation_buffer: Optional[List[torch.Tensor]] = None, + grad_output_buffer: Optional[List[torch.Tensor]] = None, + is_expert: bool = False, + tp_comm_buffer_name: str = None, + disable_grad_reduce: bool = False, + tp_group: Optional[torch.distributed.ProcessGroup] = None, + ): + """ + Initialize BitNet Column Parallel Linear layer. + + BitNet quantization is automatically enabled when onebitllms is installed. + """ + super().__init__( + input_size=input_size, + output_size=output_size, + config=config, + init_method=init_method, + bias=bias, + gather_output=gather_output, + stride=stride, + keep_master_weight_for_test=keep_master_weight_for_test, + skip_bias_add=skip_bias_add, + skip_weight_param_allocation=skip_weight_param_allocation, + embedding_activation_buffer=embedding_activation_buffer, + grad_output_buffer=grad_output_buffer, + is_expert=is_expert, + tp_comm_buffer_name=tp_comm_buffer_name, + disable_grad_reduce=disable_grad_reduce, + tp_group=tp_group, + ) + + if not HAVE_ONEBIT: + raise ImportError( + "BitNet requires onebitllms to be installed. 
" + "Install with: pip install onebitllms" + ) + + def _forward_impl(self, input, weight, *args, **kwargs): + """ + Override parent's _forward_impl to apply BitNet quantization. + reference: https://github.com/tiiuae/onebitllms/blob/main/src/onebitllms/layers/bitnet.py + + Uses the Straight-Through Estimator (STE) pattern from onebitllms: + x_quant = x + (quant(x) - x).detach() + """ + # Apply STE quantization pattern (matches onebitllms BitNetLinear.forward) + input_quantized = input + (activation_quant_triton(input) - input).detach() + weight_quantized = weight + (weight_quant_triton(weight) - weight).detach() + + return super()._forward_impl(input_quantized, weight_quantized, *args, **kwargs) + + +class BitNetRowParallelLinear(RowParallelLinear): + """ + BitNet-enabled Row Parallel Linear Layer. + + Extends RowParallelLinear with BitNet quantization while maintaining + all tensor parallelism functionality. BitNet quantization is applied + during training by overriding _forward_impl to quantize inputs and weights. + """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: ModelParallelConfig, + init_method: Callable, + bias: bool, + input_is_parallel: bool, + skip_bias_add: bool, + stride: int = 1, + keep_master_weight_for_test: bool = False, + is_expert: bool = False, + tp_comm_buffer_name: str = None, + tp_group: Optional[torch.distributed.ProcessGroup] = None, + ): + """ + Initialize BitNet Row Parallel Linear layer. + BitNet quantization is automatically enabled when onebitllms is installed. 
+ """ + super().__init__( + input_size=input_size, + output_size=output_size, + config=config, + init_method=init_method, + bias=bias, + input_is_parallel=input_is_parallel, + skip_bias_add=skip_bias_add, + stride=stride, + keep_master_weight_for_test=keep_master_weight_for_test, + is_expert=is_expert, + tp_comm_buffer_name=tp_comm_buffer_name, + tp_group=tp_group, + ) + + if not HAVE_ONEBIT: + raise ImportError( + "BitNet requires onebitllms to be installed. " + "Install with: pip install onebitllms" + ) + + def _forward_impl(self, input, weight, *args, **kwargs): + """ + Override parent's _forward_impl to apply BitNet quantization. + reference: https://github.com/tiiuae/onebitllms/blob/main/src/onebitllms/layers/bitnet.py + + Uses the Straight-Through Estimator (STE) pattern from onebitllms: + x_quant = x + (quant(x) - x).detach() + """ + # Apply STE quantization pattern (matches onebitllms BitNetLinear.forward) + input_quantized = input + (activation_quant_triton(input) - input).detach() + weight_quantized = weight + (weight_quant_triton(weight) - weight).detach() + + return super()._forward_impl(input_quantized, weight_quantized, *args, **kwargs) \ No newline at end of file diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index ae3b95985c4..64ac32a85d0 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1105,7 +1105,12 @@ def validate_args(args, defaults={}): args.moe_ffn_hidden_size = args.ffn_hidden_size if args.rank == 0: print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.") - + + # BitNet validation + if args.use_bitnet: + assert args.transformer_impl == 'local', \ + 'BitNet training requires --transformer-impl local , we will use onebitllm triton kernels for BitNet fwd and bwd pass ' \ + # Context parallel if args.context_parallel_size > 1: assert not args.use_legacy_models, "Context parallelism is not supported in legacy models." 
@@ -2490,6 +2495,8 @@ def _add_training_args(parser):
                        'This will significantly affect speed of training and inference as the kernels are not full optimized.')
     group.add_argument('--disable-jit-fuser', action='store_true',
                        help='Disable the JIT fuser.')
+    group.add_argument('--use-bitnet', action='store_true',
+                       help='Enable BitNet quantized linear layers for pretraining (requires the onebitllms package and --transformer-impl local).')
     return parser