Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions gpt_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_
else:
# Define the decoder block spec
decoder_layer_specs = get_gpt_decoder_layer_specs(
config, use_transformer_engine=use_te, normalization=args.normalization, qk_l2_norm=args.qk_l2_norm, vp_stage=vp_stage
config,
use_transformer_engine=use_te,
normalization=args.normalization,
qk_l2_norm=args.qk_l2_norm,
use_bitnet=args.use_bitnet,
vp_stage=vp_stage,
)
transformer_layer_spec_for_mtp = decoder_layer_specs[-1]
# Use spec of the last layer in decoder block as spec of the transformer layer in MTP
Expand Down Expand Up @@ -116,12 +121,12 @@ def gpt_builder(args, pre_process, post_process, vp_stage=None, config=None, pg_

def _get_transformer_layer_spec(use_te, config):
"""Get transformer layer specification based on configuration.

Args:
use_te (bool): Whether to use Transformer Engine
args: Training arguments
config: Model configuration

Returns:
transformer_layer_spec: The transformer layer specification
"""
Expand Down Expand Up @@ -151,6 +156,7 @@ def _get_transformer_layer_spec(use_te, config):
args.moe_grouped_gemm,
args.qk_layernorm,
args.multi_latent_attention,
args.use_bitnet,
args.experimental_attention_variant,
moe_use_legacy_grouped_gemm=args.moe_use_legacy_grouped_gemm,
normalization=args.normalization,
Expand Down
59 changes: 57 additions & 2 deletions megatron/core/models/gpt/gpt_layer_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from typing import Optional, Union

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.tensor_parallel.layers import (
BitNetColumnParallelLinear,
BitNetRowParallelLinear,
)
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.models.backends import (
BackendSpecProvider,
InferenceSpecProvider,
Expand Down Expand Up @@ -311,6 +316,7 @@ def get_gpt_layer_local_spec(
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
use_bitnet: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-argument
moe_use_legacy_grouped_gemm: Optional[bool] = False,
normalization: Optional[str] = None,
Expand All @@ -327,6 +333,7 @@ def get_gpt_layer_local_spec(
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
multi_latent_attention (bool, optional): To use MLA. Defaults to False.
use_bitnet (bool, optional): To use BitNet linear layers. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Expand Down Expand Up @@ -361,6 +368,7 @@ def get_gpt_layer_local_spec(

mlp = get_mlp_module_spec_for_backend(
backend=backend,
use_bitnet=use_bitnet,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
Expand Down Expand Up @@ -393,6 +401,36 @@ def get_gpt_layer_local_spec(
mlp_bda=get_bias_dropout_add,
),
)
elif use_bitnet:
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
input_layernorm=LNImpl,
self_attention=ModuleSpec(
module=SelfAttention,
params={"attn_mask_type": AttnMaskType.causal},
submodules=SelfAttentionSubmodules(
linear_qkv=BitNetColumnParallelLinear,
core_attention=DotProductAttention,
linear_proj=BitNetRowParallelLinear,
q_layernorm=(
L2Norm if qk_l2_norm else (LNImpl if qk_layernorm else IdentityOp)
),
k_layernorm=(
L2Norm if qk_l2_norm else (LNImpl if qk_layernorm else IdentityOp)
),
),
),
self_attn_bda=get_bias_dropout_add,
pre_mlp_layernorm=LNImpl,
mlp=mlp,
mlp_bda=get_bias_dropout_add,
sharded_state_dict_keys_map={
'input_layernorm.': 'self_attention.linear_qkv.layer_norm_',
'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_',
},
),
)
else:
return ModuleSpec(
module=TransformerLayer,
Expand Down Expand Up @@ -427,6 +465,7 @@ def get_gpt_layer_local_spec(

def _get_mlp_module_spec(
use_te: Optional[bool] = True,
use_bitnet: Optional[bool] = False,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-argument
Expand All @@ -439,6 +478,7 @@ def _get_mlp_module_spec(

return get_mlp_module_spec(
use_te=use_te,
use_bitnet=use_bitnet,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
fp8=fp8,
Expand All @@ -448,6 +488,7 @@ def _get_mlp_module_spec(

def get_mlp_module_spec(
use_te: Optional[bool] = True,
use_bitnet: Optional[bool] = False,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-argument
Expand All @@ -472,6 +513,7 @@ def get_mlp_module_spec(

return get_mlp_module_spec_for_backend(
backend=TESpecProvider() if use_te else LocalSpecProvider(),
use_bitnet=use_bitnet,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
Expand All @@ -481,6 +523,7 @@ def get_mlp_module_spec(

def get_mlp_module_spec_for_backend(
backend: BackendSpecProvider,
use_bitnet: Optional[bool] = False,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
moe_use_legacy_grouped_gemm: Optional[bool] = False,
Expand All @@ -491,8 +534,17 @@ def get_mlp_module_spec_for_backend(

linear_fc2 = backend.row_parallel_linear()
activation_func = backend.activation_func() if use_te_activation_func else None

if num_experts is None:

if use_bitnet:
# Bitnet MLP only w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=BitNetColumnParallelLinear,
linear_fc2= BitNetRowParallelLinear,
),
)
elif num_experts is None:
# Dense MLP w/ or w/o TE modules.
module = TEFusedMLP if use_te_op_fuser else MLP
if backend.fuse_layernorm_and_linear():
Expand Down Expand Up @@ -522,6 +574,7 @@ def get_gpt_decoder_layer_specs(
use_transformer_engine: bool,
normalization: Optional[str] = None,
qk_l2_norm: Optional[bool] = False,
use_bitnet: Optional[bool] = False,
vp_stage: Optional[int] = None,
pp_rank: Optional[int] = None,
) -> TransformerBlockSubmodules:
Expand Down Expand Up @@ -559,6 +612,7 @@ def get_gpt_decoder_layer_specs(
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
use_bitnet=use_bitnet,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
normalization=normalization,
qk_l2_norm=qk_l2_norm,
Expand All @@ -571,6 +625,7 @@ def get_gpt_decoder_layer_specs(
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
use_bitnet=use_bitnet,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
normalization=normalization,
qk_l2_norm=qk_l2_norm,
Expand Down
148 changes: 147 additions & 1 deletion megatron/core/tensor_parallel/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

# Import BitNet quantization kernels
try:
from onebitllms import activation_quant_triton, weight_quant_triton
HAVE_ONEBIT = True
except ImportError:
HAVE_ONEBIT = False
activation_quant_triton = None
weight_quant_triton = None

from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.parallel_state import (
get_global_memory_buffer,
Expand Down Expand Up @@ -1319,3 +1327,141 @@ def __repr__(self):
f"{type(self).__name__}(in_features={self.input_size}, "
f"out_features={self.output_size}, bias={use_bias}, TP={tp})"
)

class BitNetColumnParallelLinear(ColumnParallelLinear):
    """BitNet-enabled column-parallel linear layer.

    Extends ``ColumnParallelLinear`` with BitNet quantization while keeping all
    tensor-parallel functionality intact. Quantization is applied during
    training by overriding ``_forward_impl`` to quantize activations and
    weights with the Straight-Through Estimator (STE) before delegating to the
    parent implementation.

    Raises:
        ImportError: If the ``onebitllms`` package (which provides the Triton
            quantization kernels) is not installed.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: ModelParallelConfig,
        init_method: Callable,
        bias: bool = True,
        gather_output: bool = False,
        stride: int = 1,
        keep_master_weight_for_test: bool = False,
        skip_bias_add: bool = False,
        skip_weight_param_allocation: bool = False,
        embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
        grad_output_buffer: Optional[List[torch.Tensor]] = None,
        is_expert: bool = False,
        tp_comm_buffer_name: Optional[str] = None,
        disable_grad_reduce: bool = False,
        tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        """Initialize the BitNet column-parallel linear layer.

        All arguments are forwarded unchanged to ``ColumnParallelLinear``.
        BitNet quantization is automatically enabled when onebitllms is
        installed.
        """
        # Fail fast BEFORE the parent allocates (potentially large) weight
        # parameters: the layer is unusable without the quantization kernels.
        if not HAVE_ONEBIT:
            raise ImportError(
                "BitNet requires onebitllms to be installed. "
                "Install with: pip install onebitllms"
            )

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=init_method,
            bias=bias,
            gather_output=gather_output,
            stride=stride,
            keep_master_weight_for_test=keep_master_weight_for_test,
            skip_bias_add=skip_bias_add,
            skip_weight_param_allocation=skip_weight_param_allocation,
            embedding_activation_buffer=embedding_activation_buffer,
            grad_output_buffer=grad_output_buffer,
            is_expert=is_expert,
            tp_comm_buffer_name=tp_comm_buffer_name,
            disable_grad_reduce=disable_grad_reduce,
            tp_group=tp_group,
        )

    def _forward_impl(self, input, weight, *args, **kwargs):
        """Apply BitNet STE quantization, then run the parent's forward impl.

        Reference:
        https://github.com/tiiuae/onebitllms/blob/main/src/onebitllms/layers/bitnet.py

        Uses the Straight-Through Estimator (STE) pattern from onebitllms:
            x_quant = x + (quant(x) - x).detach()
        so the quantized values are used in the forward pass while gradients
        flow through the unquantized tensors.

        NOTE(review): this assumes the parent resolves ``_forward_impl`` via
        normal method lookup; verify that ``ColumnParallelLinear.forward`` does
        not rebind ``self._forward_impl`` as an instance attribute, which would
        silently bypass this override.
        """
        # Apply STE quantization pattern (matches onebitllms BitNetLinear.forward)
        input_quantized = input + (activation_quant_triton(input) - input).detach()
        weight_quantized = weight + (weight_quant_triton(weight) - weight).detach()

        return super()._forward_impl(input_quantized, weight_quantized, *args, **kwargs)


class BitNetRowParallelLinear(RowParallelLinear):
    """BitNet-enabled row-parallel linear layer.

    Extends ``RowParallelLinear`` with BitNet quantization while keeping all
    tensor-parallel functionality intact. Quantization is applied during
    training by overriding ``_forward_impl`` to quantize inputs and weights
    with the Straight-Through Estimator (STE) before delegating to the parent
    implementation.

    Raises:
        ImportError: If the ``onebitllms`` package (which provides the Triton
            quantization kernels) is not installed.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int,
        *,
        config: ModelParallelConfig,
        init_method: Callable,
        bias: bool,
        input_is_parallel: bool,
        skip_bias_add: bool,
        stride: int = 1,
        keep_master_weight_for_test: bool = False,
        is_expert: bool = False,
        tp_comm_buffer_name: Optional[str] = None,
        tp_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        """Initialize the BitNet row-parallel linear layer.

        All arguments are forwarded unchanged to ``RowParallelLinear``.
        BitNet quantization is automatically enabled when onebitllms is
        installed.
        """
        # Fail fast BEFORE the parent allocates (potentially large) weight
        # parameters: the layer is unusable without the quantization kernels.
        if not HAVE_ONEBIT:
            raise ImportError(
                "BitNet requires onebitllms to be installed. "
                "Install with: pip install onebitllms"
            )

        super().__init__(
            input_size=input_size,
            output_size=output_size,
            config=config,
            init_method=init_method,
            bias=bias,
            input_is_parallel=input_is_parallel,
            skip_bias_add=skip_bias_add,
            stride=stride,
            keep_master_weight_for_test=keep_master_weight_for_test,
            is_expert=is_expert,
            tp_comm_buffer_name=tp_comm_buffer_name,
            tp_group=tp_group,
        )

    def _forward_impl(self, input, weight, *args, **kwargs):
        """Apply BitNet STE quantization, then run the parent's forward impl.

        Reference:
        https://github.com/tiiuae/onebitllms/blob/main/src/onebitllms/layers/bitnet.py

        Uses the Straight-Through Estimator (STE) pattern from onebitllms:
            x_quant = x + (quant(x) - x).detach()
        so the quantized values are used in the forward pass while gradients
        flow through the unquantized tensors.

        NOTE(review): this assumes the parent resolves ``_forward_impl`` via
        normal method lookup; verify that ``RowParallelLinear.forward`` does
        not rebind ``self._forward_impl`` as an instance attribute, which would
        silently bypass this override.
        """
        # Apply STE quantization pattern (matches onebitllms BitNetLinear.forward)
        input_quantized = input + (activation_quant_triton(input) - input).detach()
        weight_quantized = weight + (weight_quant_triton(weight) - weight).detach()

        return super()._forward_impl(input_quantized, weight_quantized, *args, **kwargs)
9 changes: 8 additions & 1 deletion megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,12 @@ def validate_args(args, defaults={}):
args.moe_ffn_hidden_size = args.ffn_hidden_size
if args.rank == 0:
print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.")


# BitNet validation
if args.use_bitnet:
assert args.transformer_impl == 'local', \
'BitNet training requires --transformer-impl local , we will use onebitllm triton kernels for BitNet fwd and bwd pass ' \

# Context parallel
if args.context_parallel_size > 1:
assert not args.use_legacy_models, "Context parallelism is not supported in legacy models."
Expand Down Expand Up @@ -2490,6 +2495,8 @@ def _add_training_args(parser):
'This will significantly affect speed of training and inference as the kernels are not full optimized.')
group.add_argument('--disable-jit-fuser', action='store_true',
help='Disable the JIT fuser.')
group.add_argument('--use-bitnet', action='store_true',
help='Pretraining bitnet')

return parser

Expand Down