From 98be274aed4db61088ee9f08135d51e6ae1ccb98 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Fri, 22 Aug 2025 10:07:56 +0000 Subject: [PATCH 1/6] core: more changes --- launch.sh | 163 +++++++++++ .../core/models/mamba/mamba_layer_specs.py | 41 ++- megatron/core/models/mamba/mamba_model.py | 5 + megatron/core/ssm/mamba_block.py | 15 +- .../core/ssm/mamba_hybrid_layer_allocation.py | 14 +- megatron/core/ssm/mamba_layer.py | 10 +- megatron/core/ssm/parallel_hybrid_layer.py | 265 ++++++++++++++++++ megatron/training/arguments.py | 4 +- pretrain_mamba.py | 3 + 9 files changed, 511 insertions(+), 9 deletions(-) create mode 100644 launch.sh create mode 100644 megatron/core/ssm/parallel_hybrid_layer.py diff --git a/launch.sh b/launch.sh new file mode 100644 index 00000000000..cbf1ddd04a2 --- /dev/null +++ b/launch.sh @@ -0,0 +1,163 @@ +# ========== Configuration ========== +NUM_NODES=1 # Total number of nodes +NODE_RANK=0 # Rank of this node +NUM_GPUS_PER_NODE=1 # GPUs per node +# CONFIG_FILE="configs/falconH1/falconh1_base_config.yaml" # Path to your YAML config +MASTER_ADDR="localhost" # Or master node IP in multi-node +MASTER_PORT=29501 + +# ========== Recommended Exports ========== +export CUDA_DEVICE_ORDER="PCI_BUS_ID" +export NCCL_PROTO="Simple,LL128" +export NCCL_DEBUG="INFO" +export NCCL_SOCKET_IFNAME="eth0" +export NCCL_NET_PLUGIN=none +export PYTHONUNBUFFERED="1" +export CUDA_LAUNCH_BLOCKING="1" +export CUDA_DEVICE_MAX_CONNECTIONS="1" + +# Optional: Enable wandb cloud syncing +# export WANDB_MODE=online + +# ========== Print setup ========== +echo "Launching with:" +echo " - MASTER_ADDR: $MASTER_ADDR" +echo " - MASTER_PORT: $MASTER_PORT" +echo " - NODE_RANK: $NODE_RANK / $NUM_NODES" +echo " - GPUs per node: $NUM_GPUS_PER_NODE" +echo " - Config file: $CONFIG_FILE" +# Parallelism configuration +TP=1 # Tensor Parallel +PP=1 # Pipeline Parallel +CP=1 # Context Parallel + +# Build experiment name with parallelism config +# 
EXP_NAME="500M_MLM_tp${TP}_pp${PP}_cp${CP}" +EXP_NAME="test" + +options="\ + --micro-batch-size 8 \ + --global-batch-size 512 \ + --rampup-batch-size 64 64 4882 \ + --train-samples 210449 \ + --data-path /home/aiccu/Megatron-LM-Internal/data/mambatron_same_data_processed_text_document + --data-cache-path /gcs/data/data-cache-path \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model tiiuae/Falcon-H1-0.5B-Instruct \ + --vocab-size 32784 \ + --make-vocab-size-divisible-by 1 \ + --tensorboard-dir /gcs/data/tok-dir \ + --log-validation-ppl-to-tensorboard \ + --log-timers-to-tensorboard \ + --log-throughput \ + --log-interval 10 \ + --no-mmap-bin-files \ + --split 1000,0,0 \ + --fp32-residual-connection \ + \ + --disable-bias-linear \ + --num-layers 72 \ + --hidden-size 1024 \ + --ffn-hidden-size 2048 \ + --num-attention-heads 8 \ + --group-query-attention \ + --num-query-groups 2 \ + --seq-length 2048 \ + --max-position-embeddings 2048 \ + --rotary-base 100000000000 + --position-embedding-type rope \ + --no-rope-fusion \ + --disable-bias-linear \ + \ + --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ + --mamba-state-dim 128 \ + --mamba-head-dim 64 \ + --mamba-num-groups ${TP} \ + --reset-position-ids \ + \ + --weight-decay 0.1 \ + --optimizer adam \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --adam-eps 1e-16 \ + --use-distributed-optimizer \ + --clip-grad 1.0 \ + --bf16 \ + --init-method-std 0.02 \ + --lr 128e-5 \ + --lr-decay-style WSD \ + --lr-wsd-decay-samples 15137 \ + --lr-wsd-decay-style exponential \ + --min-lr 0.0 \ + --lr-warmup-init 0.0 \ + --lr-warmup-fraction 0.1 \ + --ckpt-format torch \ + \ + --tensor-model-parallel-size ${TP} \ + --pipeline-model-parallel-size ${PP} \ + --context-parallel-size ${CP} \ + --overlap-param-gather \ + --overlap-grad-reduce \ + --no-gradient-accumulation-fusion \ + --no-masked-softmax-fusion \ + \ + --attention-softmax-in-fp32 \ + --untie-embeddings-and-output-weights \ + --swiglu \ + 
--normalization RMSNorm \ + --norm-epsilon 1e-5 \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --use-flash-attn \ + \ + --distributed-timeout-minutes 90 \ + --num-workers 16 \ + --num-dataset-builder-threads 32 \ + \ + --no-create-attention-mask-in-dataloader \ + --mid-level-dataset-surplus 0.005 \ + \ + --parallel-hybrid-ratio 0.0 \ + \ + --save /gcs/data/save \ + --save-interval 420 \ + --wandb-project mlm-final-pr \ + --wandb-exp-name test_hf_conversion \ + \ + --dataloader-type single \ + --eval-iters 0 \ + --no-load-optim \ + --no-load-rng \ + --seed 52 \ + --override-opt_param-scheduler \ +" +# extra_options="\ +# --d-conv 4 \ +# --conv-init 1.0 \ +# --expand 1 \ +# --A-init-range 1 16 \ +# --rmsnorm \ +# --dt-min 0.001 \ +# --dt-max 0.1 \ +# --dt-init random \ +# --dt-scale 1.0 \ +# --dt-init-floor 1e-4 \ +# --conv-bias \ +# --chunk-size 128 \ +# " + + # --data-path /home/aiccu/Megatron-LM-Internal/data/merged_falcon_english_32k/merged_0 \ +# --wandb-exp-name mlm_500M_16k_seq_DL_VocabSizeCorrect + +# ========== Run ========== +source ~/miniconda3/etc/profile.d/conda.sh +conda activate megatron +# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 +export CUDA_VISIBLE_DEVICES=0 +$(which torchrun) \ + --nproc_per_node=$NUM_GPUS_PER_NODE \ + --nnodes=$NUM_NODES \ + --node_rank=$NODE_RANK \ + --master_addr=$MASTER_ADDR \ + --master_port=$MASTER_PORT \ + pretrain_mamba.py ${options} # ${extra_options} \ No newline at end of file diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index 8ef4a2ab3e4..88dbab4d070 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -6,8 +6,10 @@ TERowParallelLinear, ) from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.transformer.identity_op import IdentityOp from megatron.core.ssm.mamba_block import MambaStack, MambaStackSubmodules -from 
megatron.core.ssm.mamba_layer import MambaLayer, MambaLayerSubmodules +from megatron.core.ssm.mamba_layer import MambaLayer,MambaLayerSubmodules +from megatron.core.ssm.parallel_hybrid_layer import ParallelHybridLayer, ParallelHybridLayerSubmodules from megatron.core.ssm.mamba_mixer import MambaMixer, MambaMixerSubmodules from megatron.core.ssm.mlp_layer import MLPLayer from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules @@ -64,5 +66,42 @@ mlp_bda=get_bias_dropout_add, ), ), + + parallel_hybrid_layer=ModuleSpec( + module=ParallelHybridLayer, + submodules=ParallelHybridLayerSubmodules( + input_layernorm=IdentityOp, + mamba_layer=ModuleSpec( + module=MambaMixer, + submodules=MambaMixerSubmodules( + in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear + ), + ), + attention_layer=ModuleSpec( + module=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + ), + ), + ), + pre_mlp_layernorm=IdentityOp, + mlp_layer=ModuleSpec( + module=MLPLayer, + submodules=TransformerLayerSubmodules( + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear + ), + ), + mlp_bda=get_bias_dropout_add, + ), + ), + ), + ), ), ) diff --git a/megatron/core/models/mamba/mamba_model.py b/megatron/core/models/mamba/mamba_model.py index 8aa328c85eb..fb5c36589de 100644 --- a/megatron/core/models/mamba/mamba_model.py +++ b/megatron/core/models/mamba/mamba_model.py @@ -31,6 +31,8 @@ class MambaModel(LanguageModule): (used with pipeline parallelism). Defaults to True. 
hybrid_attention_ratio (float, optional): The target ratio of attention layers to total layers + parallel_hybrid_ratio (float, optional): The target ratio of parallel hybrid + layers to total layers hybrid_mlp_ratio (float, optional): The target ratio of mlp layers to total layers hybrid_override_pattern (str, optional): The hybrid layer pattern to override with post_process (bool, optional): Include an output layer (used with pipeline parallelism). @@ -60,6 +62,7 @@ def __init__( max_sequence_length: int, pre_process: bool = True, hybrid_attention_ratio: float = 0.0, + parallel_hybrid_ratio: float = 0.0, hybrid_mlp_ratio: float = 0.0, hybrid_override_pattern: str = None, post_process: bool = True, @@ -84,6 +87,7 @@ def __init__( self.max_sequence_length = max_sequence_length self.pre_process = pre_process self.hybrid_attention_ratio = hybrid_attention_ratio + self.parallel_hybrid_ratio = parallel_hybrid_ratio self.hybrid_mlp_ratio = hybrid_mlp_ratio self.hybrid_override_pattern = hybrid_override_pattern self.post_process = post_process @@ -121,6 +125,7 @@ def __init__( self.config, pre_process=self.pre_process, hybrid_attention_ratio=self.hybrid_attention_ratio, + parallel_hybrid_ratio=self.parallel_hybrid_ratio, hybrid_mlp_ratio=self.hybrid_mlp_ratio, hybrid_override_pattern=self.hybrid_override_pattern, post_process=self.post_process, diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index c8c19e2cd3e..8c916b2dc54 100644 --- a/megatron/core/ssm/mamba_block.py +++ b/megatron/core/ssm/mamba_block.py @@ -29,6 +29,7 @@ from megatron.core.transformer.module import MegatronModule from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_layer import TransformerLayer +from megatron.core.ssm.parallel_hybrid_layer import ParallelHybridLayer from megatron.core.transformer.utils import sharded_state_dict_default from megatron.core.utils import WrappedTensor, 
deprecate_inference_params, make_viewless_tensor @@ -86,6 +87,7 @@ class MambaStackSubmodules: mamba_layer: Union[ModuleSpec, type] = IdentityOp attention_layer: Union[ModuleSpec, type] = IdentityOp mlp_layer: Union[ModuleSpec, type] = IdentityOp + parallel_hybrid_layer: Union[ModuleSpec, type] = IdentityOp class MambaStack(MegatronModule): @@ -123,6 +125,7 @@ def __init__( pre_process: bool = True, hybrid_attention_ratio: float = 0.0, hybrid_mlp_ratio: float = 0.0, + parallel_hybrid_ratio: float = 0.0, hybrid_override_pattern: str = None, post_layer_norm: bool = True, post_process: bool = True, @@ -146,11 +149,13 @@ def __init__( self.hybrid_attention_ratio = hybrid_attention_ratio self.hybrid_mlp_ratio = hybrid_mlp_ratio self.hybrid_override_pattern = hybrid_override_pattern + self.parallel_hybrid_ratio = parallel_hybrid_ratio layer_type_list = allocate_layers( self.config.num_layers, self.hybrid_attention_ratio, self.hybrid_mlp_ratio, + self.parallel_hybrid_ratio, self.hybrid_override_pattern, ) @@ -188,6 +193,14 @@ def __init__( layer_number=i + 1, model_comm_pgs=model_comm_pgs, ) + elif layer_type == LayerSymbols.PARALLEL: + layer = build_module( + submodules.parallel_hybrid_layer, + config=self.config, + layer_number=i + 1, + model_comm_pgs=model_comm_pgs, + ) + import pdb; pdb.set_trace() else: assert False, "unexpected layer_type" self.layers.append(layer) @@ -332,7 +345,7 @@ def forward( else nullcontext() ) with inner_fp8_context: - if isinstance(layer, TransformerLayer): + if isinstance(layer, (TransformerLayer, ParallelHybridLayer)): hidden_states, _ = layer( hidden_states=hidden_states, attention_mask=attention_mask, diff --git a/megatron/core/ssm/mamba_hybrid_layer_allocation.py b/megatron/core/ssm/mamba_hybrid_layer_allocation.py index 26972b5454b..ddd45967406 100644 --- a/megatron/core/ssm/mamba_hybrid_layer_allocation.py +++ b/megatron/core/ssm/mamba_hybrid_layer_allocation.py @@ -18,13 +18,20 @@ class Symbols: MAMBA = "M" ATTENTION = "*" MLP = "-" 
+ PARALLEL = "P" VALID = {MAMBA, ATTENTION, MLP} def _allocate_auto( - total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float + total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float, target_parallel_hybrid_ratio: float ) -> list: # First, allocate attention (evenly spaced, starting and ending with mamba) + + # TODO: decide on the best allocation logic here + if target_parallel_hybrid_ratio > 0.0: + layer_type_list = [Symbols.PARALLEL] * total_layers_count + return layer_type_list + attention_layers_count: int = round(total_layers_count * target_attention_ratio) mamba_layers_count: int = total_layers_count - attention_layers_count mamba_sections_count: int = attention_layers_count + 1 @@ -85,15 +92,16 @@ def allocate_layers( total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float, + target_parallel_hybrid_ratio: float, override_pattern: str = None, ) -> list: assert total_layers_count > 0 assert target_attention_ratio >= 0.0 and target_attention_ratio <= 1.0 assert target_mlp_ratio >= 0.0 and target_mlp_ratio <= 1.0 - assert target_attention_ratio + target_mlp_ratio <= 1.0 + assert target_attention_ratio + target_mlp_ratio + target_parallel_hybrid_ratio <= 1.0 # Note: target_mamba_ratio = 1.0 - target_attention_ratio - target_mlp_ratio - layer_type_list = _allocate_auto(total_layers_count, target_attention_ratio, target_mlp_ratio) + layer_type_list = _allocate_auto(total_layers_count, target_attention_ratio, target_mlp_ratio, target_parallel_hybrid_ratio) if override_pattern is not None: layer_type_list_override = _allocate_override(total_layers_count, override_pattern) diff --git a/megatron/core/ssm/mamba_layer.py b/megatron/core/ssm/mamba_layer.py index 316e688f98b..9c96375ad3f 100644 --- a/megatron/core/ssm/mamba_layer.py +++ b/megatron/core/ssm/mamba_layer.py @@ -61,7 +61,9 @@ def __init__( submodules: MambaLayerSubmodules, layer_number: int = 1, residual_in_fp32=False, + use_norm: bool = 
True, model_comm_pgs: ModelCommProcessGroups = None, + ): """Initialize Mamba Layer.""" super().__init__(config) @@ -75,6 +77,7 @@ def __init__( self.layer_number = layer_number self.residual_in_fp32 = residual_in_fp32 self.hidden_dropout = config.hidden_dropout + self.use_norm = use_norm self.mixer = build_module( submodules.mixer, self.config, @@ -82,7 +85,9 @@ def __init__( layer_number=layer_number, model_comm_pgs=model_comm_pgs, ) - self.norm = build_module(submodules.norm, self.config, self.config.hidden_size) + + if self.use_norm: + self.norm = build_module(submodules.norm, self.config, self.config.hidden_size) self.mamba_bda = build_module(submodules.mamba_bda) self.bias_dropout_add_exec_handler = torch.enable_grad @@ -120,7 +125,8 @@ def forward( residual = residual.to(torch.float32) hidden_states = hidden_states.to(dtype=self.config.params_dtype) - hidden_states = self.norm(hidden_states) + if self.use_norm: + hidden_states = self.norm(hidden_states) mixer_out_with_bias = self.mixer(hidden_states, inference_context=inference_context) diff --git a/megatron/core/ssm/parallel_hybrid_layer.py b/megatron/core/ssm/parallel_hybrid_layer.py new file mode 100644 index 00000000000..909a135563a --- /dev/null +++ b/megatron/core/ssm/parallel_hybrid_layer.py @@ -0,0 +1,265 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, Tri Dao, Albert Gu. + +# Some of this code was adopted from https://github.com/state-spaces/mamba/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. 
+ +from dataclasses import dataclass +from typing import Optional, Union +from typing import Optional, Tuple + +import torch +from torch import Tensor + +from megatron.core.inference.contexts import BaseInferenceContext +from megatron.core.process_groups_config import ModelCommProcessGroups +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.module import MegatronModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig + +# Import existing components to compose +from megatron.core.ssm.mamba_mixer import MambaMixerSubmodules +from megatron.core.transformer.attention import SelfAttentionSubmodules + + +@dataclass +class ParallelHybridLayerSubmodules: + """ + Configuration class for specifying the submodules of a Mamba layer. + + This class defines the structure and default implementations for various + components of a Mamba layer, allowing for flexible customization of the + layer's architecture. + + Args: + mamba_layer (Union[ModuleSpec, type]): Specification for the input layer normalization. + attention_layer (Union[ModuleSpec, type]): Specification for the along-sequence mixing mechanism. + mlp_layer (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation + after the mixer. + """ + + mamba_layer: Union[ModuleSpec, type] = IdentityOp + attention_layer: Union[ModuleSpec, type] = IdentityOp + mlp_layer: Union[ModuleSpec, type] = IdentityOp + pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + input_layernorm: Union[ModuleSpec, type] = IdentityOp + + +class ParallelHybridLayer(MegatronModule): + """ + A single Mamba layer. + + Mamba layer takes input with size [s, b, h] and returns an + output of the same size. 
+ """ + + def __init__( + self, + config: TransformerConfig, + submodules: ParallelHybridLayerSubmodules, + layer_number: int = 1, + residual_in_fp32=False, + model_comm_pgs: ModelCommProcessGroups = None, + ): + """Initialize Mamba Layer.""" + super().__init__(config) + assert model_comm_pgs is not None, "model_comm_pgs must be provided for MambaLayer" + + self.config = config + self.layer_number = layer_number + self.residual_in_fp32 = residual_in_fp32 + + # Hidden dropout for BDA + self.hidden_dropout = config.hidden_dropout + + # Pre-normalization layer + self.input_layernorm = build_module( + submodules.input_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + mamba_submodules = MambaMixerSubmodules( + in_proj=submodules.mamba_layer.submodules.in_proj, + out_proj=submodules.mamba_layer.submodules.out_proj, + ) + + self.mamba_mixer = build_module( + submodules.mamba_layer.module, # Should be MambaMixer + submodules=mamba_submodules, + config=self.config, + layer_number=layer_number, + d_model=self.config.hidden_size, + model_comm_pgs=model_comm_pgs + ) + + # Attention Component: Use existing SelfAttention + attention_optional_kwargs = {} + if self.config.context_parallel_size > 1 and self.config.cp_comm_type is not None: + if isinstance(self.config.cp_comm_type, list): + attention_optional_kwargs["cp_comm_type"] = self.config.cp_comm_type[self.layer_number] + else: + attention_optional_kwargs["cp_comm_type"] = self.config.cp_comm_type + model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups() + attention_optional_kwargs["model_comm_pgs"] = model_comm_pgs + # Create submodules for SelfAttention - extract from main submodules + attention_submodules = SelfAttentionSubmodules( + linear_qkv=submodules.attention_layer.submodules.linear_qkv, + core_attention=submodules.attention_layer.submodules.core_attention, + linear_proj=submodules.attention_layer.submodules.linear_proj, + 
q_layernorm=getattr(submodules.attention_layer.submodules, 'q_layernorm', None), + k_layernorm=getattr(submodules.attention_layer.submodules, 'k_layernorm', None), + ) + + self.self_attention = build_module( + submodules.attention_layer.module, + submodules=attention_submodules, + config=self.config, + layer_number=self.layer_number, + **attention_optional_kwargs, + ) + + # Bias-Dropout-Add fusion + self.mamba_bda = build_module(submodules.mamba_layer.mamba_bda) + + self.pre_mlp_layernorm = build_module( + submodules.pre_mlp_layernorm, + config=self.config, + hidden_size=self.config.hidden_size, + eps=self.config.layernorm_epsilon, + ) + + self.bias_dropout_add_exec_handler = torch.enable_grad + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + inference_context: Optional[BaseInferenceContext] = None, + rotary_pos_emb: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + rotary_pos_cos: Optional[torch.Tensor] = None, + rotary_pos_sin: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + packed_seq_params: Optional[PackedSeqParams] = None, + sequence_len_offset: Optional[int] = None, + position_ids: Optional[torch.Tensor] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + ): + """ + Forward pass through the hybrid mixer using COMPOSITION. + + Pure orchestration - no inline reimplementation! 
+ """ + + # Save residual connection + residual = hidden_states + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + # Pre-normalization + hidden_states = hidden_states.to(dtype=self.config.params_dtype) + # hidden_states = self.norm(hidden_states) + + # Execute components and collect outputs + outputs = [] + biases = [] + + # SSM Forward: Use existing MambaMixer + mamba_output, mamba_bias = self.mamba_mixer( + hidden_states*self.config.ssm_in_multiplier, + inference_context=inference_context, + position_ids=position_ids, + ) + outputs.append(mamba_output*self.config.ssm_out_multiplier) + if mamba_bias is not None: + biases.append(mamba_bias) + + # Attention Component: Use existing SelfAttention + attn_output, attn_bias = self.self_attention( + hidden_states*self.config.attention_in_multiplier, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + outputs.append(attn_output*self.config.attention_out_multiplier) + + if attn_bias is not None: + biases.append(attn_bias) + # Combine outputs + if len(outputs) == 0: + # Fallback to identity + combined_output = hidden_states + combined_bias = None + elif len(outputs) == 1: + # Single component active + combined_output = outputs[0] + combined_bias = biases[0] if biases else None + else: + # Multiple components - add them + combined_output = sum(outputs) + combined_bias = sum(biases) if biases else None + + # Bias-Dropout-Add fusion (residual connection) + out_with_bias = (combined_output, combined_bias) + + with self.bias_dropout_add_exec_handler(): + final_output = self.mamba_bda( + training=self.training, + fused=self.config.bias_dropout_fusion + )(out_with_bias, residual, self.hidden_dropout) + return final_output + + def allocate_inference_cache(self, batch_size, 
max_seqlen, dtype=None): + """Allocate inference cache for active components.""" + caches = {} + + if self.use_mamba and self.mamba_mixer is not None: + mamba_cache = self.mamba_mixer.allocate_inference_cache( + batch_size, max_seqlen, dtype + ) + caches['mamba'] = mamba_cache + + if self.use_attention and self.self_attention is not None: + #need to be implemented + pass + + return caches + + def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): + """Provide a sharded state dictionary for distributed checkpointing.""" + from megatron.core.transformer.utils import sharded_state_dict_default + + sharded_state_dict = {} + + # Handle norm + if hasattr(self, 'norm') and self.norm is not None: + norm_sd = sharded_state_dict_default( + self.norm, f'{prefix}norm.', sharded_offsets, metadata + ) + sharded_state_dict.update(norm_sd) + + # Handle SSM component + if self.use_mamba and hasattr(self, 'mamba_mixer') and self.mamba_mixer is not None: + mamba_sd = sharded_state_dict_default( + self.mamba_mixer, f'{prefix}mamba_mixer.', sharded_offsets, metadata + ) + sharded_state_dict.update(mamba_sd) + + # Handle attention component + if self.use_attention and hasattr(self, 'self_attention') and self.self_attention is not None: + attn_sd = sharded_state_dict_default( + self.self_attention, f'{prefix}self_attention.', sharded_offsets, metadata + ) + sharded_state_dict.update(attn_sd) + + return sharded_state_dict diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 3ffaed91012..4b07b770d3b 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1200,8 +1200,7 @@ def core_transformer_config_from_args(args, config_class=None): elif hasattr(args, 'kitchen_recipe_number') and args.kitchen_recipe_number is not None: kw_args['use_kitchen'] = True kw_args['quant_recipe'] = kitchen_quantization_recipe_config(args.kitchen_recipe_number) - - + # Return config. 
return config_class(**kw_args) @@ -3000,6 +2999,7 @@ def _add_experimental_args(parser): '--hidden-size * expand // --mamba-head-dim') group.add_argument('--is-hybrid-model', default=False, action="store_true", help='Indicates whether the model is a hybrid model.') + group.add_argument('--parallel-hybrid-ratio', type=float, default=0.0, help='Ratio of parallel hybrid layers.') group.add_argument('--disable-mamba-mem-eff-path', default=False, action="store_true", help='Disable Mamba efficient path.') group.add_argument('--yaml-cfg', type=str, default=None, diff --git a/pretrain_mamba.py b/pretrain_mamba.py index bfd245f5ac8..ca201562c1e 100644 --- a/pretrain_mamba.py +++ b/pretrain_mamba.py @@ -66,6 +66,8 @@ def model_provider(pre_process=True, post_process=True) -> MambaModel: else: raise("You must provide a valid Mamba layer spec!") + import pdb; pdb.set_trace() + model = MambaModel( config=config, mamba_stack_spec=mamba_stack_spec, @@ -74,6 +76,7 @@ def model_provider(pre_process=True, post_process=True) -> MambaModel: pre_process=pre_process, hybrid_attention_ratio=args.hybrid_attention_ratio, hybrid_mlp_ratio=args.hybrid_mlp_ratio, + parallel_hybrid_ratio=args.parallel_hybrid_ratio, hybrid_override_pattern=args.hybrid_override_pattern, post_process=post_process, fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, From c8ec0e037b93e72c0c28d26c589b1888854aa617 Mon Sep 17 00:00:00 2001 From: Younes Belkada Date: Fri, 22 Aug 2025 11:30:05 +0000 Subject: [PATCH 2/6] more refactor --- .../core/models/mamba/mamba_layer_specs.py | 1 + megatron/core/ssm/mamba_block.py | 3 +- megatron/core/ssm/parallel_hybrid_layer.py | 41 +++++++++++++------ pretrain_mamba.py | 2 - 4 files changed, 31 insertions(+), 16 deletions(-) diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index 88dbab4d070..869444147e8 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ 
-77,6 +77,7 @@ in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear ), ), + parallel_hybrid_bda=get_bias_dropout_add, attention_layer=ModuleSpec( module=ModuleSpec( module=SelfAttention, diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index 8c916b2dc54..e07d47535fa 100644 --- a/megatron/core/ssm/mamba_block.py +++ b/megatron/core/ssm/mamba_block.py @@ -200,7 +200,6 @@ def __init__( layer_number=i + 1, model_comm_pgs=model_comm_pgs, ) - import pdb; pdb.set_trace() else: assert False, "unexpected layer_type" self.layers.append(layer) @@ -346,7 +345,7 @@ def forward( ) with inner_fp8_context: if isinstance(layer, (TransformerLayer, ParallelHybridLayer)): - hidden_states, _ = layer( + hidden_states = layer( hidden_states=hidden_states, attention_mask=attention_mask, inference_context=inference_context, diff --git a/megatron/core/ssm/parallel_hybrid_layer.py b/megatron/core/ssm/parallel_hybrid_layer.py index 909a135563a..095e9c921aa 100644 --- a/megatron/core/ssm/parallel_hybrid_layer.py +++ b/megatron/core/ssm/parallel_hybrid_layer.py @@ -46,6 +46,7 @@ class ParallelHybridLayerSubmodules: mlp_layer: Union[ModuleSpec, type] = IdentityOp pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp input_layernorm: Union[ModuleSpec, type] = IdentityOp + parallel_hybrid_bda: Union[ModuleSpec, type] = IdentityOp class ParallelHybridLayer(MegatronModule): @@ -106,13 +107,14 @@ def __init__( attention_optional_kwargs["cp_comm_type"] = self.config.cp_comm_type model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups() attention_optional_kwargs["model_comm_pgs"] = model_comm_pgs + # Create submodules for SelfAttention - extract from main submodules attention_submodules = SelfAttentionSubmodules( - linear_qkv=submodules.attention_layer.submodules.linear_qkv, - core_attention=submodules.attention_layer.submodules.core_attention, - linear_proj=submodules.attention_layer.submodules.linear_proj, - 
q_layernorm=getattr(submodules.attention_layer.submodules, 'q_layernorm', None), - k_layernorm=getattr(submodules.attention_layer.submodules, 'k_layernorm', None), + linear_qkv=submodules.attention_layer.module.submodules.linear_qkv, + core_attention=submodules.attention_layer.module.submodules.core_attention, + linear_proj=submodules.attention_layer.module.submodules.linear_proj, + q_layernorm=getattr(submodules.attention_layer.module.submodules, 'q_layernorm', None), + k_layernorm=getattr(submodules.attention_layer.module.submodules, 'k_layernorm', None), ) self.self_attention = build_module( @@ -124,8 +126,9 @@ def __init__( ) # Bias-Dropout-Add fusion - self.mamba_bda = build_module(submodules.mamba_layer.mamba_bda) + self.parallel_hybrid_bda = build_module(submodules.parallel_hybrid_bda) + self.pre_mlp_layernorm = build_module( submodules.pre_mlp_layernorm, config=self.config, @@ -133,6 +136,13 @@ def __init__( eps=self.config.layernorm_epsilon, ) + self.mlp = build_module( + submodules.mlp_layer.module, + submodules=submodules.mlp_layer.submodules, + config=self.config, + layer_number=self.layer_number, + ) + self.bias_dropout_add_exec_handler = torch.enable_grad def forward( @@ -171,17 +181,16 @@ def forward( # SSM Forward: Use existing MambaMixer mamba_output, mamba_bias = self.mamba_mixer( - hidden_states*self.config.ssm_in_multiplier, + hidden_states, inference_context=inference_context, - position_ids=position_ids, ) - outputs.append(mamba_output*self.config.ssm_out_multiplier) + outputs.append(mamba_output) if mamba_bias is not None: biases.append(mamba_bias) # Attention Component: Use existing SelfAttention attn_output, attn_bias = self.self_attention( - hidden_states*self.config.attention_in_multiplier, + hidden_states, attention_mask=attention_mask, inference_context=inference_context, rotary_pos_emb=rotary_pos_emb, @@ -191,7 +200,7 @@ def forward( packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, ) - 
outputs.append(attn_output*self.config.attention_out_multiplier) + outputs.append(attn_output) if attn_bias is not None: biases.append(attn_bias) @@ -213,10 +222,18 @@ def forward( out_with_bias = (combined_output, combined_bias) with self.bias_dropout_add_exec_handler(): - final_output = self.mamba_bda( + hidden_states = self.parallel_hybrid_bda( training=self.training, fused=self.config.bias_dropout_fusion )(out_with_bias, residual, self.hidden_dropout) + + # TODO: verify this + residual = hidden_states + hidden_states = self.pre_mlp_layernorm(hidden_states) + + hidden_states, _ = self.mlp(hidden_states) + final_output = hidden_states + residual + return final_output def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): diff --git a/pretrain_mamba.py b/pretrain_mamba.py index ca201562c1e..33e1a845cfd 100644 --- a/pretrain_mamba.py +++ b/pretrain_mamba.py @@ -66,8 +66,6 @@ def model_provider(pre_process=True, post_process=True) -> MambaModel: else: raise("You must provide a valid Mamba layer spec!") - import pdb; pdb.set_trace() - model = MambaModel( config=config, mamba_stack_spec=mamba_stack_spec, From 71a4a7e744c60067126c812ba0e7766c6ae52eef Mon Sep 17 00:00:00 2001 From: "dhia.rhaiem" Date: Mon, 1 Sep 2025 15:25:59 +0000 Subject: [PATCH 3/6] FalconH1 Layer Specs Restructuring + Hybrid Layer allocation + separate MLP --- launch.sh | 62 +++---- .../core/models/mamba/mamba_layer_specs.py | 18 +- megatron/core/ssm/mamba_block.py | 12 +- .../core/ssm/mamba_hybrid_layer_allocation.py | 85 +++++----- megatron/core/ssm/mamba_layer.py | 9 +- megatron/core/ssm/parallel_hybrid_layer.py | 158 +++++------------- 6 files changed, 117 insertions(+), 227 deletions(-) diff --git a/launch.sh b/launch.sh index cbf1ddd04a2..115dc4e93f9 100644 --- a/launch.sh +++ b/launch.sh @@ -1,9 +1,10 @@ +#!/bin/bash + # ========== Configuration ========== -NUM_NODES=1 # Total number of nodes -NODE_RANK=0 # Rank of this node -NUM_GPUS_PER_NODE=1 # GPUs per node -# 
CONFIG_FILE="configs/falconH1/falconh1_base_config.yaml" # Path to your YAML config -MASTER_ADDR="localhost" # Or master node IP in multi-node +NUM_NODES=1 +NODE_RANK=0 +NUM_GPUS_PER_NODE=1 +MASTER_ADDR="localhost" MASTER_PORT=29501 # ========== Recommended Exports ========== @@ -14,10 +15,7 @@ export NCCL_SOCKET_IFNAME="eth0" export NCCL_NET_PLUGIN=none export PYTHONUNBUFFERED="1" export CUDA_LAUNCH_BLOCKING="1" -export CUDA_DEVICE_MAX_CONNECTIONS="1" - -# Optional: Enable wandb cloud syncing -# export WANDB_MODE=online +export CUDA_DEVICE_MAX_CONNECTIONS="1" # ========== Print setup ========== echo "Launching with:" @@ -25,14 +23,13 @@ echo " - MASTER_ADDR: $MASTER_ADDR" echo " - MASTER_PORT: $MASTER_PORT" echo " - NODE_RANK: $NODE_RANK / $NUM_NODES" echo " - GPUs per node: $NUM_GPUS_PER_NODE" -echo " - Config file: $CONFIG_FILE" + # Parallelism configuration TP=1 # Tensor Parallel -PP=1 # Pipeline Parallel -CP=1 # Context Parallel +PP=1 # Pipeline Parallel +CP=1 # Context Parallel # Build experiment name with parallelism config -# EXP_NAME="500M_MLM_tp${TP}_pp${PP}_cp${CP}" EXP_NAME="test" options="\ @@ -40,11 +37,11 @@ options="\ --global-batch-size 512 \ --rampup-batch-size 64 64 4882 \ --train-samples 210449 \ - --data-path /home/aiccu/Megatron-LM-Internal/data/mambatron_same_data_processed_text_document + --data-path /home/aiccu/Megatron-LM/data/merged_falcon_english_32k/merged_0 \ --data-cache-path /gcs/data/data-cache-path \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model tiiuae/Falcon-H1-0.5B-Instruct \ - --vocab-size 32784 \ + --vocab-size 32784 \ --make-vocab-size-divisible-by 1 \ --tensorboard-dir /gcs/data/tok-dir \ --log-validation-ppl-to-tensorboard \ @@ -64,7 +61,7 @@ options="\ --num-query-groups 2 \ --seq-length 2048 \ --max-position-embeddings 2048 \ - --rotary-base 100000000000 + --rotary-base 100000000000 \ --position-embedding-type rope \ --no-rope-fusion \ --disable-bias-linear \ @@ -73,7 +70,7 @@ options="\ --mamba-state-dim 128 \ 
--mamba-head-dim 64 \ --mamba-num-groups ${TP} \ - --reset-position-ids \ + --reset-position-ids \ \ --weight-decay 0.1 \ --optimizer adam \ @@ -117,47 +114,32 @@ options="\ --no-create-attention-mask-in-dataloader \ --mid-level-dataset-surplus 0.005 \ \ - --parallel-hybrid-ratio 0.0 \ + --parallel-hybrid-ratio 0.5 \ + --hybrid-attention-ratio 0.0 \ + --hybrid-mlp-ratio 0.5 \ \ --save /gcs/data/save \ --save-interval 420 \ --wandb-project mlm-final-pr \ - --wandb-exp-name test_hf_conversion \ + --wandb-exp-name final-pr \ \ + --disable-msc --dataloader-type single \ --eval-iters 0 \ --no-load-optim \ --no-load-rng \ --seed 52 \ - --override-opt_param-scheduler \ -" -# extra_options="\ -# --d-conv 4 \ -# --conv-init 1.0 \ -# --expand 1 \ -# --A-init-range 1 16 \ -# --rmsnorm \ -# --dt-min 0.001 \ -# --dt-max 0.1 \ -# --dt-init random \ -# --dt-scale 1.0 \ -# --dt-init-floor 1e-4 \ -# --conv-bias \ -# --chunk-size 128 \ -# " - - # --data-path /home/aiccu/Megatron-LM-Internal/data/merged_falcon_english_32k/merged_0 \ -# --wandb-exp-name mlm_500M_16k_seq_DL_VocabSizeCorrect + --override-opt_param-scheduler" # ========== Run ========== source ~/miniconda3/etc/profile.d/conda.sh conda activate megatron -# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 export CUDA_VISIBLE_DEVICES=0 + $(which torchrun) \ --nproc_per_node=$NUM_GPUS_PER_NODE \ --nnodes=$NUM_NODES \ --node_rank=$NODE_RANK \ --master_addr=$MASTER_ADDR \ --master_port=$MASTER_PORT \ - pretrain_mamba.py ${options} # ${extra_options} \ No newline at end of file + pretrain_mamba.py ${options} \ No newline at end of file diff --git a/megatron/core/models/mamba/mamba_layer_specs.py b/megatron/core/models/mamba/mamba_layer_specs.py index 869444147e8..e60ec6ad842 100755 --- a/megatron/core/models/mamba/mamba_layer_specs.py +++ b/megatron/core/models/mamba/mamba_layer_specs.py @@ -70,15 +70,14 @@ parallel_hybrid_layer=ModuleSpec( module=ParallelHybridLayer, submodules=ParallelHybridLayerSubmodules( - 
input_layernorm=IdentityOp, - mamba_layer=ModuleSpec( + mamba_mixer=ModuleSpec( module=MambaMixer, submodules=MambaMixerSubmodules( in_proj=TELayerNormColumnParallelLinear, out_proj=TERowParallelLinear ), ), parallel_hybrid_bda=get_bias_dropout_add, - attention_layer=ModuleSpec( + self_attention=ModuleSpec( module=ModuleSpec( module=SelfAttention, params={"attn_mask_type": AttnMaskType.causal}, @@ -89,19 +88,6 @@ ), ), ), - pre_mlp_layernorm=IdentityOp, - mlp_layer=ModuleSpec( - module=MLPLayer, - submodules=TransformerLayerSubmodules( - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear - ), - ), - mlp_bda=get_bias_dropout_add, - ), - ), ), ), ), diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index e07d47535fa..74952fe8e23 100644 --- a/megatron/core/ssm/mamba_block.py +++ b/megatron/core/ssm/mamba_block.py @@ -197,7 +197,7 @@ def __init__( layer = build_module( submodules.parallel_hybrid_layer, config=self.config, - layer_number=i + 1, + layer_number=i + 1 + pp_layer_offset, model_comm_pgs=model_comm_pgs, ) else: @@ -344,7 +344,15 @@ def forward( else nullcontext() ) with inner_fp8_context: - if isinstance(layer, (TransformerLayer, ParallelHybridLayer)): + if isinstance(layer, TransformerLayer): + hidden_states, _ = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + sequence_len_offset=sequence_len_offset, + ) + if isinstance(layer, ParallelHybridLayer): hidden_states = layer( hidden_states=hidden_states, attention_mask=attention_mask, diff --git a/megatron/core/ssm/mamba_hybrid_layer_allocation.py b/megatron/core/ssm/mamba_hybrid_layer_allocation.py index ddd45967406..9332a18e73e 100644 --- a/megatron/core/ssm/mamba_hybrid_layer_allocation.py +++ b/megatron/core/ssm/mamba_hybrid_layer_allocation.py @@ -19,19 +19,12 @@ class Symbols: ATTENTION = "*" MLP = 
"-" PARALLEL = "P" - VALID = {MAMBA, ATTENTION, MLP} + VALID = {MAMBA, ATTENTION, MLP, PARALLEL} def _allocate_auto( total_layers_count: int, target_attention_ratio: float, target_mlp_ratio: float, target_parallel_hybrid_ratio: float ) -> list: - # First, allocate attention (evenly spaced, starting and ending with mamba) - - # TODO: decide on the best allocation logic here - if target_parallel_hybrid_ratio > 0.0: - layer_type_list = [Symbols.PARALLEL] * total_layers_count - return layer_type_list - attention_layers_count: int = round(total_layers_count * target_attention_ratio) mamba_layers_count: int = total_layers_count - attention_layers_count mamba_sections_count: int = attention_layers_count + 1 @@ -46,8 +39,6 @@ def _allocate_auto( else: x -= 1 - # Next, allocate mlp - # (evenly distributed, but right-justified, not replacing attention) mlp_layers_count: int = round(total_layers_count * target_mlp_ratio) if mlp_layers_count > 0: mamba_layers_count -= mlp_layers_count @@ -62,6 +53,26 @@ def _allocate_auto( else: x -= 1 + parallel_layers_count: int = round(total_layers_count * target_parallel_hybrid_ratio) + if parallel_layers_count > 0: + remaining_mamba_count = layer_type_list.count(Symbols.MAMBA) + if remaining_mamba_count > 0: + if parallel_layers_count >= remaining_mamba_count: + for l in range(total_layers_count): + if layer_type_list[l] == Symbols.MAMBA: + layer_type_list[l] = Symbols.PARALLEL + else: + mamba_to_parallel_ratio: float = (remaining_mamba_count - parallel_layers_count) / parallel_layers_count + + x: float = mamba_to_parallel_ratio + for l in range(total_layers_count): + if layer_type_list[l] == Symbols.MAMBA: + if x < 0.5: + layer_type_list[l] = Symbols.PARALLEL + x += mamba_to_parallel_ratio + else: + x -= 1 + return layer_type_list @@ -98,15 +109,15 @@ def allocate_layers( assert total_layers_count > 0 assert target_attention_ratio >= 0.0 and target_attention_ratio <= 1.0 assert target_mlp_ratio >= 0.0 and target_mlp_ratio <= 1.0 + assert 
target_parallel_hybrid_ratio >= 0.0 and target_parallel_hybrid_ratio <= 1.0 assert target_attention_ratio + target_mlp_ratio + target_parallel_hybrid_ratio <= 1.0 - # Note: target_mamba_ratio = 1.0 - target_attention_ratio - target_mlp_ratio layer_type_list = _allocate_auto(total_layers_count, target_attention_ratio, target_mlp_ratio, target_parallel_hybrid_ratio) if override_pattern is not None: layer_type_list_override = _allocate_override(total_layers_count, override_pattern) log_single_rank(logger, logging.INFO, "Using hybrid override pattern") - if (target_attention_ratio > 0.0 or target_mlp_ratio > 0.0) and not _layer_counts_match( + if (target_attention_ratio > 0.0 or target_mlp_ratio > 0.0 or target_parallel_hybrid_ratio > 0.0) and not _layer_counts_match( layer_type_list_override, layer_type_list ): raise ValueError( @@ -124,18 +135,21 @@ def allocate_layers( log_single_rank(logger, logging.INFO, f"B: {''.join(layer_type_list_override)}") layer_type_list = layer_type_list_override - if target_attention_ratio > 0.0 or target_mlp_ratio > 0.0 or override_pattern is not None: + if target_attention_ratio > 0.0 or target_mlp_ratio > 0.0 or target_parallel_hybrid_ratio > 0.0 or override_pattern is not None: actual_attention_layers_count = layer_type_list.count(Symbols.ATTENTION) actual_attention_ratio = actual_attention_layers_count / total_layers_count actual_mlp_layers_count = layer_type_list.count(Symbols.MLP) actual_mlp_ratio = actual_mlp_layers_count / total_layers_count + actual_parallel_layers_count = layer_type_list.count(Symbols.PARALLEL) + actual_parallel_ratio = actual_parallel_layers_count / total_layers_count allocation_string = "".join(layer_type_list) log_single_rank( logger, logging.INFO, f"Hybrid allocation ({Symbols.MAMBA} is mamba, " f"{Symbols.ATTENTION} is attention, " - f"{Symbols.MLP} is mlp):", + f"{Symbols.MLP} is mlp, " + f"{Symbols.PARALLEL} is parallel):", ) log_single_rank(logger, logging.INFO, allocation_string) log_single_rank( @@ 
-161,39 +175,26 @@ def allocate_layers( f"Target mlp ratio: {target_mlp_ratio:.2f}. " f"Actual mlp ratio: {actual_mlp_ratio:.2f}.", ) + log_single_rank( + logger, + logging.INFO, + f"{actual_parallel_layers_count} parallel layers in " f"{total_layers_count} total layers.", + ) + log_single_rank( + logger, + logging.INFO, + f"Target parallel ratio: {target_parallel_hybrid_ratio:.2f}. " + f"Actual parallel ratio: {actual_parallel_ratio:.2f}.", + ) return layer_type_list if __name__ == "__main__": test_cases = [ - # (10, 0.2, 0.0), - # (48, 0.0, 0.0), # will not print anything - # (48, 0.1, 0.0), - # 48, 0.3, 0.0), - # (48, 0.5, 0.0), - # (48, 0.6, 0.0), - # (48, 0.7, 0.0), - # (10, 0.0, 0.1), - # (10, 0.0, 0.3), - # (10, 0.0, 0.5), - # (10, 0.1, 0.1), - # (10, 0.2, 0.2), - # (10, 0.3, 0.3), - # (10, 0.5, 0.5), - # (48, 0.2, 0.3), - # (48, 0.5, 0.2), - # (48, 0.5, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.25, 0.25, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.25, 0.25, "MM-*MM-*MM*-MM*-MM*-MM*-M*M-M*M-M*M-M*M-*MM-*MM-"), - # (48, 0.0, 0.2, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.2, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.0, 0.0, "MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-MM*-"), - # (48, 0.5, 0.5), - # (10, 0.3, 0.2, "MMM*-*M*M-"), - # (10, 0.3, 0.2, "MM*M-*M*M-"), - (9, 0.0, 0.0, "M*-M*-M*-"), - (9, 0.0, 0.0, "MMMMMMMMM"), + (9, 0.0, 0.0, 0.0, "M*-M*-M*-"), + (9, 0.0, 0.0, 0.0, "MMMMMMMMM"), + (10, 0.2, 0.1, 0.2), ] for t in test_cases: print("") - allocate_layers(*t) + allocate_layers(*t) \ No newline at end of file diff --git a/megatron/core/ssm/mamba_layer.py b/megatron/core/ssm/mamba_layer.py index 9c96375ad3f..04a8a177335 100644 --- a/megatron/core/ssm/mamba_layer.py +++ b/megatron/core/ssm/mamba_layer.py @@ -61,7 +61,6 @@ def __init__( submodules: MambaLayerSubmodules, layer_number: int = 1, residual_in_fp32=False, - use_norm: bool = True, model_comm_pgs: 
ModelCommProcessGroups = None, ): @@ -77,7 +76,6 @@ def __init__( self.layer_number = layer_number self.residual_in_fp32 = residual_in_fp32 self.hidden_dropout = config.hidden_dropout - self.use_norm = use_norm self.mixer = build_module( submodules.mixer, self.config, @@ -85,9 +83,7 @@ def __init__( layer_number=layer_number, model_comm_pgs=model_comm_pgs, ) - - if self.use_norm: - self.norm = build_module(submodules.norm, self.config, self.config.hidden_size) + self.norm = build_module(submodules.norm, self.config, self.config.hidden_size) self.mamba_bda = build_module(submodules.mamba_bda) self.bias_dropout_add_exec_handler = torch.enable_grad @@ -125,8 +121,7 @@ def forward( residual = residual.to(torch.float32) hidden_states = hidden_states.to(dtype=self.config.params_dtype) - if self.use_norm: - hidden_states = self.norm(hidden_states) + hidden_states = self.norm(hidden_states) mixer_out_with_bias = self.mixer(hidden_states, inference_context=inference_context) diff --git a/megatron/core/ssm/parallel_hybrid_layer.py b/megatron/core/ssm/parallel_hybrid_layer.py index 095e9c921aa..322615eea94 100644 --- a/megatron/core/ssm/parallel_hybrid_layer.py +++ b/megatron/core/ssm/parallel_hybrid_layer.py @@ -6,11 +6,9 @@ # LICENSE file in the root directory of this source tree. 
from dataclasses import dataclass -from typing import Optional, Union -from typing import Optional, Tuple +from typing import Optional, Union, Tuple import torch -from torch import Tensor from megatron.core.inference.contexts import BaseInferenceContext from megatron.core.process_groups_config import ModelCommProcessGroups @@ -20,42 +18,21 @@ from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig -# Import existing components to compose from megatron.core.ssm.mamba_mixer import MambaMixerSubmodules from megatron.core.transformer.attention import SelfAttentionSubmodules @dataclass class ParallelHybridLayerSubmodules: - """ - Configuration class for specifying the submodules of a Mamba layer. - - This class defines the structure and default implementations for various - components of a Mamba layer, allowing for flexible customization of the - layer's architecture. - - Args: - mamba_layer (Union[ModuleSpec, type]): Specification for the input layer normalization. - attention_layer (Union[ModuleSpec, type]): Specification for the along-sequence mixing mechanism. - mlp_layer (Union[ModuleSpec, type]): Specification for the bias-dropout-add operation - after the mixer. - """ - - mamba_layer: Union[ModuleSpec, type] = IdentityOp - attention_layer: Union[ModuleSpec, type] = IdentityOp - mlp_layer: Union[ModuleSpec, type] = IdentityOp - pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp + """Configuration class for specifying the submodules of a parallel hybrid layer.""" + mamba_mixer: Union[ModuleSpec, type] = IdentityOp + self_attention: Union[ModuleSpec, type] = IdentityOp input_layernorm: Union[ModuleSpec, type] = IdentityOp parallel_hybrid_bda: Union[ModuleSpec, type] = IdentityOp class ParallelHybridLayer(MegatronModule): - """ - A single Mamba layer. - - Mamba layer takes input with size [s, b, h] and returns an - output of the same size. 
- """ + """A parallel hybrid layer that combines Mamba and Attention components.""" def __init__( self, @@ -65,18 +42,14 @@ def __init__( residual_in_fp32=False, model_comm_pgs: ModelCommProcessGroups = None, ): - """Initialize Mamba Layer.""" super().__init__(config) - assert model_comm_pgs is not None, "model_comm_pgs must be provided for MambaLayer" + assert model_comm_pgs is not None, "model_comm_pgs must be provided for ParallelHybridLayer" self.config = config self.layer_number = layer_number self.residual_in_fp32 = residual_in_fp32 - - # Hidden dropout for BDA self.hidden_dropout = config.hidden_dropout - # Pre-normalization layer self.input_layernorm = build_module( submodules.input_layernorm, config=self.config, @@ -85,12 +58,12 @@ def __init__( ) mamba_submodules = MambaMixerSubmodules( - in_proj=submodules.mamba_layer.submodules.in_proj, - out_proj=submodules.mamba_layer.submodules.out_proj, + in_proj=submodules.mamba_mixer.submodules.in_proj, + out_proj=submodules.mamba_mixer.submodules.out_proj, ) self.mamba_mixer = build_module( - submodules.mamba_layer.module, # Should be MambaMixer + submodules.mamba_mixer.module, submodules=mamba_submodules, config=self.config, layer_number=layer_number, @@ -98,7 +71,6 @@ def __init__( model_comm_pgs=model_comm_pgs ) - # Attention Component: Use existing SelfAttention attention_optional_kwargs = {} if self.config.context_parallel_size > 1 and self.config.cp_comm_type is not None: if isinstance(self.config.cp_comm_type, list): @@ -108,41 +80,23 @@ def __init__( model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups() attention_optional_kwargs["model_comm_pgs"] = model_comm_pgs - # Create submodules for SelfAttention - extract from main submodules attention_submodules = SelfAttentionSubmodules( - linear_qkv=submodules.attention_layer.module.submodules.linear_qkv, - core_attention=submodules.attention_layer.module.submodules.core_attention, - 
linear_proj=submodules.attention_layer.module.submodules.linear_proj, - q_layernorm=getattr(submodules.attention_layer.module.submodules, 'q_layernorm', None), - k_layernorm=getattr(submodules.attention_layer.module.submodules, 'k_layernorm', None), + linear_qkv=submodules.self_attention.module.submodules.linear_qkv, + core_attention=submodules.self_attention.module.submodules.core_attention, + linear_proj=submodules.self_attention.module.submodules.linear_proj, + q_layernorm=getattr(submodules.self_attention.module.submodules, 'q_layernorm', None), + k_layernorm=getattr(submodules.self_attention.module.submodules, 'k_layernorm', None), ) self.self_attention = build_module( - submodules.attention_layer.module, + submodules.self_attention.module, submodules=attention_submodules, config=self.config, layer_number=self.layer_number, **attention_optional_kwargs, ) - # Bias-Dropout-Add fusion self.parallel_hybrid_bda = build_module(submodules.parallel_hybrid_bda) - - - self.pre_mlp_layernorm = build_module( - submodules.pre_mlp_layernorm, - config=self.config, - hidden_size=self.config.hidden_size, - eps=self.config.layernorm_epsilon, - ) - - self.mlp = build_module( - submodules.mlp_layer.module, - submodules=submodules.mlp_layer.submodules, - config=self.config, - layer_number=self.layer_number, - ) - self.bias_dropout_add_exec_handler = torch.enable_grad def forward( @@ -160,26 +114,16 @@ def forward( *, inference_params: Optional[BaseInferenceContext] = None, ): - """ - Forward pass through the hybrid mixer using COMPOSITION. - - Pure orchestration - no inline reimplementation! 
- """ - - # Save residual connection residual = hidden_states if self.residual_in_fp32: residual = residual.to(torch.float32) - # Pre-normalization hidden_states = hidden_states.to(dtype=self.config.params_dtype) - # hidden_states = self.norm(hidden_states) + hidden_states = self.input_layernorm(hidden_states) - # Execute components and collect outputs outputs = [] biases = [] - # SSM Forward: Use existing MambaMixer mamba_output, mamba_bias = self.mamba_mixer( hidden_states, inference_context=inference_context, @@ -188,7 +132,6 @@ def forward( if mamba_bias is not None: biases.append(mamba_bias) - # Attention Component: Use existing SelfAttention attn_output, attn_bias = self.self_attention( hidden_states, attention_mask=attention_mask, @@ -201,24 +144,12 @@ def forward( sequence_len_offset=sequence_len_offset, ) outputs.append(attn_output) - if attn_bias is not None: biases.append(attn_bias) - # Combine outputs - if len(outputs) == 0: - # Fallback to identity - combined_output = hidden_states - combined_bias = None - elif len(outputs) == 1: - # Single component active - combined_output = outputs[0] - combined_bias = biases[0] if biases else None - else: - # Multiple components - add them - combined_output = sum(outputs) - combined_bias = sum(biases) if biases else None - - # Bias-Dropout-Add fusion (residual connection) + + combined_output = sum(outputs) + combined_bias = sum(biases) if biases else None + out_with_bias = (combined_output, combined_bias) with self.bias_dropout_add_exec_handler(): @@ -227,56 +158,43 @@ def forward( fused=self.config.bias_dropout_fusion )(out_with_bias, residual, self.hidden_dropout) - # TODO: verify this - residual = hidden_states - hidden_states = self.pre_mlp_layernorm(hidden_states) - - hidden_states, _ = self.mlp(hidden_states) - final_output = hidden_states + residual - - return final_output + return hidden_states def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): - """Allocate inference cache for active 
components.""" + """Allocate inference cache for both components.""" caches = {} - if self.use_mamba and self.mamba_mixer is not None: + if self.mamba_mixer is not None: mamba_cache = self.mamba_mixer.allocate_inference_cache( batch_size, max_seqlen, dtype ) caches['mamba'] = mamba_cache - if self.use_attention and self.self_attention is not None: - #need to be implemented + if self.self_attention is not None: pass return caches def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): - """Provide a sharded state dictionary for distributed checkpointing.""" from megatron.core.transformer.utils import sharded_state_dict_default sharded_state_dict = {} - # Handle norm - if hasattr(self, 'norm') and self.norm is not None: - norm_sd = sharded_state_dict_default( - self.norm, f'{prefix}norm.', sharded_offsets, metadata - ) - sharded_state_dict.update(norm_sd) + norm_sd = sharded_state_dict_default( + self.input_layernorm, f'{prefix}input_layernorm.', sharded_offsets, metadata + ) + sharded_state_dict.update(norm_sd) - # Handle SSM component - if self.use_mamba and hasattr(self, 'mamba_mixer') and self.mamba_mixer is not None: - mamba_sd = sharded_state_dict_default( - self.mamba_mixer, f'{prefix}mamba_mixer.', sharded_offsets, metadata - ) - sharded_state_dict.update(mamba_sd) + mamba_sd = sharded_state_dict_default( + self.mamba_mixer, f'{prefix}mamba_mixer.', sharded_offsets, metadata + ) + sharded_state_dict.update(mamba_sd) - # Handle attention component - if self.use_attention and hasattr(self, 'self_attention') and self.self_attention is not None: - attn_sd = sharded_state_dict_default( - self.self_attention, f'{prefix}self_attention.', sharded_offsets, metadata - ) - sharded_state_dict.update(attn_sd) + attn_sd = sharded_state_dict_default( + self.self_attention, f'{prefix}self_attention.', sharded_offsets, metadata + ) + sharded_state_dict.update(attn_sd) return sharded_state_dict + + \ No newline at end of file From 
e82671298af8db194b4e92104485781e9e530b51 Mon Sep 17 00:00:00 2001 From: dhia eddine rhaiem Date: Tue, 2 Sep 2025 18:12:19 +0000 Subject: [PATCH 4/6] Convert to HF: FalconH1 loader and saver --- tools/checkpoint/loader_parallelhybrid.py | 319 +++++++++++++ tools/checkpoint/saver_parallelhybrid_hf.py | 500 ++++++++++++++++++++ 2 files changed, 819 insertions(+) create mode 100644 tools/checkpoint/loader_parallelhybrid.py create mode 100644 tools/checkpoint/saver_parallelhybrid_hf.py diff --git a/tools/checkpoint/loader_parallelhybrid.py b/tools/checkpoint/loader_parallelhybrid.py new file mode 100644 index 00000000000..b32f76215a3 --- /dev/null +++ b/tools/checkpoint/loader_parallelhybrid.py @@ -0,0 +1,319 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import json +import os +import sys +import torch +import types + +from loader_base import MegatronCheckpointLoaderBase + + +def add_arguments(parser): + """Add command-line arguments relevant to Falcon-H1 model loading.""" + group = parser.add_argument_group(title='Falcon-H1 loader') + + group.add_argument('--true-vocab-size', type=int, default=None, + help='Original size of vocab; if specified, trims padding from embedding table.') + group.add_argument('--vocab-file', type=str, default=None, + help='Path to a vocab file. 
If specified, determines vocab size to trim padding.') + group.add_argument('--megatron-path', type=str, default=None, + help='Base directory of Megatron repository') + group.add_argument('--position-embedding-type', + type=str, + default='learned_absolute', + choices=['learned_absolute', 'rope'], + help='Type of position embedding.') + group.add_argument('--loader-transformer-impl', default='local', + choices=['local', 'transformer_engine'], + help='Which Transformer implementation to use.') + + +class MegatronCheckpointLoaderFalconH1(MegatronCheckpointLoaderBase): + """ + Falcon-H1 specific checkpoint loader that handles hybrid architecture + with alternating Mamba+Attention layers and MLP-only layers. + + Architecture: + - Even layers (0,2,4,...): Hybrid (Mamba mixer + Self-attention) + - Odd layers (1,3,5,...): MLP-only + """ + + def build_sys_argv(self): + """ + Construct a sys.argv list for Megatron's argument parser. + """ + return [ + *super().build_sys_argv(), + '--position-embedding-type', self.args.position_embedding_type, + ] + + def import_model_provider(self): + """Return the Mamba model provider for Falcon-H1.""" + from pretrain_mamba import model_provider + return model_provider + + def is_hybrid_layer(self, layer_idx): + """Determine if a layer is hybrid (Mamba + Attention) or MLP-only.""" + return layer_idx % 2 == 0 + + def extract_mamba_weights(self, model, layer_idx): + """Extract Mamba mixer weights from a hybrid layer.""" + layer_name = f"decoder.layers.{layer_idx}.mamba_mixer" + + mamba_weights = {} + + # Get the mamba mixer module + mamba_mixer = None + for name, module in model.named_modules(): + if name == layer_name: + mamba_mixer = module + break + + if mamba_mixer is None: + raise ValueError(f"Could not find mamba_mixer at layer {layer_idx}") + + # Extract Mamba-specific parameters + mamba_weights["A_log"] = getattr(mamba_mixer, 'A_log', None) + mamba_weights["D"] = getattr(mamba_mixer, 'D', None) + mamba_weights["dt_bias"] = 
getattr(mamba_mixer, 'dt_bias', None) + + # Conv1D weights + if hasattr(mamba_mixer, 'conv1d'): + mamba_weights["conv1d_weight"] = mamba_mixer.conv1d.weight + mamba_weights["conv1d_bias"] = mamba_mixer.conv1d.bias + + # Input and output projections + if hasattr(mamba_mixer, 'in_proj'): + mamba_weights["in_proj_weight"] = mamba_mixer.in_proj.weight + # Note: pre_norm_weight is extracted separately above + + if hasattr(mamba_mixer, 'out_proj'): + mamba_weights["out_proj_weight"] = mamba_mixer.out_proj.weight + + # Norm weights - GET BOTH TYPES + if hasattr(mamba_mixer, 'norm'): + mamba_weights["internal_norm_weight"] = mamba_mixer.norm.weight + + # Pre-norm weight (from in_proj layer norm) + if hasattr(mamba_mixer, 'in_proj') and hasattr(mamba_mixer.in_proj, 'layer_norm_weight'): + mamba_weights["pre_norm_weight"] = mamba_mixer.in_proj.layer_norm_weight + + return mamba_weights + + def extract_attention_weights(self, model, layer_idx): + """Extract self-attention weights from a hybrid layer.""" + layer_name = f"decoder.layers.{layer_idx}.self_attention" + + attention_weights = {} + + # Get the self attention module + self_attention = None + for name, module in model.named_modules(): + if name == layer_name: + self_attention = module + break + + if self_attention is None: + raise ValueError(f"Could not find self_attention at layer {layer_idx}") + + # QKV projection + if hasattr(self_attention, 'linear_qkv'): + attention_weights["qkv_weight"] = self_attention.linear_qkv.weight + attention_weights["qkv_norm_weight"] = getattr(self_attention.linear_qkv, 'layer_norm_weight', None) + + # Output projection + if hasattr(self_attention, 'linear_proj'): + attention_weights["proj_weight"] = self_attention.linear_proj.weight + + return attention_weights + + def extract_mlp_weights(self, model, layer_idx): + """Extract MLP weights from an MLP-only layer.""" + layer_name = f"decoder.layers.{layer_idx}.mlp" + + mlp_weights = {} + + # Get the MLP module + mlp = None + for name, 
module in model.named_modules(): + if name == layer_name: + mlp = module + break + + if mlp is None: + raise ValueError(f"Could not find mlp at layer {layer_idx}") + + # FC1 (first linear layer) + if hasattr(mlp, 'linear_fc1'): + mlp_weights["fc1_weight"] = mlp.linear_fc1.weight + mlp_weights["fc1_norm_weight"] = getattr(mlp.linear_fc1, 'layer_norm_weight', None) + + # FC2 (second linear layer) + if hasattr(mlp, 'linear_fc2'): + mlp_weights["fc2_weight"] = mlp.linear_fc2.weight + + return mlp_weights + + def send_model_over_queue(self): + """Send Falcon-H1 model over the queue with proper hybrid layer handling.""" + # Send metadata first + self.send_metadata_over_queue() + + # Get model parameters + tp_size = self.margs.tensor_model_parallel_size + pp_size = self.margs.pipeline_model_parallel_size + vp_size = self.margs.virtual_pipeline_model_parallel_size or 1 + + # Get first pipeline models for embeddings/final norm + first_pipeline_models = self.all_models[0][0] + + # 1) Send embeddings + message = {} + for i, model in enumerate(first_pipeline_models): + # Extract embedding weights + for name, param in model.named_parameters(): + if 'embedding.word_embeddings.weight' in name: + if i == 0: + message["word embeddings"] = param + else: + message["word embeddings"] = torch.cat([message["word embeddings"], param], dim=0) + elif 'position_embeddings.weight' in name and self.md.position_embedding_type == 'learned_absolute': + if i == 0: # Only take from rank 0 + message["position embeddings"] = param + + if "position embeddings" not in message: + message["position embeddings"] = None + + self.queue_put("embeddings", message) + + # 2) Process each layer based on type + total_layer_num = 0 + for vp_rank in range(vp_size): + for pp_rank in range(pp_size): + models = self.all_models[pp_rank][vp_rank] + + # Determine number of layers in this model shard + model = models[0] + layer_count = 0 + max_layer_idx = -1 + for name, _ in model.named_parameters(): + if 
'decoder.layers.' in name: + # Extract layer index + parts = name.split('.') + if len(parts) > 2 and parts[2].isdigit(): + layer_idx = int(parts[2]) + max_layer_idx = max(max_layer_idx, layer_idx) + + num_layers = max_layer_idx + 1 if max_layer_idx >= 0 else 0 + + for layer_idx in range(num_layers): + if self.is_hybrid_layer(total_layer_num): + # Process hybrid layer (Mamba + Attention) + message = {} + + # Collect Mamba weights across TP ranks + mamba_weights_per_rank = [] + attention_weights_per_rank = [] + + for model_tp in models: + mamba_weights = self.extract_mamba_weights(model_tp, layer_idx) + attention_weights = self.extract_attention_weights(model_tp, layer_idx) + mamba_weights_per_rank.append(mamba_weights) + attention_weights_per_rank.append(attention_weights) + + # Mamba components (typically not sharded across TP) + message["mamba A_log"] = mamba_weights_per_rank[0]["A_log"] + message["mamba D"] = mamba_weights_per_rank[0]["D"] + message["mamba dt_bias"] = mamba_weights_per_rank[0]["dt_bias"] + message["mamba conv1d weight"] = mamba_weights_per_rank[0]["conv1d_weight"] + message["mamba conv1d bias"] = mamba_weights_per_rank[0]["conv1d_bias"] + message["mamba pre norm weight"] = mamba_weights_per_rank[0]["pre_norm_weight"] + message["mamba internal norm weight"] = mamba_weights_per_rank[0]["internal_norm_weight"] + + # Mamba projections (may be sharded) + if len(mamba_weights_per_rank) > 1 and mamba_weights_per_rank[1]["in_proj_weight"] is not None: + # Concatenate across TP ranks + message["mamba in_proj weight"] = torch.cat([w["in_proj_weight"] for w in mamba_weights_per_rank], dim=0) + message["mamba out_proj weight"] = torch.cat([w["out_proj_weight"] for w in mamba_weights_per_rank], dim=1) + else: + message["mamba in_proj weight"] = mamba_weights_per_rank[0]["in_proj_weight"] + message["mamba out_proj weight"] = mamba_weights_per_rank[0]["out_proj_weight"] + + # Attention components (sharded across TP) + message["attention input norm weight"] = 
attention_weights_per_rank[0]["qkv_norm_weight"] + + # Concatenate QKV and dense weights across TP ranks + if len(attention_weights_per_rank) > 1: + message["attention qkv weight"] = torch.cat([w["qkv_weight"] for w in attention_weights_per_rank], dim=0) + message["attention dense weight"] = torch.cat([w["proj_weight"] for w in attention_weights_per_rank], dim=1) + else: + message["attention qkv weight"] = attention_weights_per_rank[0]["qkv_weight"] + message["attention dense weight"] = attention_weights_per_rank[0]["proj_weight"] + + self.queue_put(f"hybrid layer {total_layer_num}", message) + + else: + # Process MLP-only layer + message = {} + + # Collect MLP weights across TP ranks + mlp_weights_per_rank = [] + for model_tp in models: + mlp_weights = self.extract_mlp_weights(model_tp, layer_idx) + mlp_weights_per_rank.append(mlp_weights) + + # MLP norm (not sharded) + message["mlp input norm weight"] = mlp_weights_per_rank[0]["fc1_norm_weight"] + + # MLP weights (sharded across TP) + if len(mlp_weights_per_rank) > 1: + message["mlp fc1 weight"] = torch.cat([w["fc1_weight"] for w in mlp_weights_per_rank], dim=0) + message["mlp fc2 weight"] = torch.cat([w["fc2_weight"] for w in mlp_weights_per_rank], dim=1) + else: + message["mlp fc1 weight"] = mlp_weights_per_rank[0]["fc1_weight"] + message["mlp fc2 weight"] = mlp_weights_per_rank[0]["fc2_weight"] + + self.queue_put(f"mlp layer {total_layer_num}", message) + + total_layer_num += 1 + + # 3) Send final norm + message = {} + for name, param in models[0].named_parameters(): + if 'decoder.final_norm.weight' in name: + message["weight"] = param + break + self.queue_put("final norm", message) + + # 4) Send output layer + if self.md.output_layer: + message = {} + output_weights = [] + for model in models: + for name, param in model.named_parameters(): + if 'output_layer.weight' in name: + output_weights.append(param) + break + + if output_weights: + if len(output_weights) > 1: + message["weight"] = 
torch.cat(output_weights, dim=0) + else: + message["weight"] = output_weights[0] + self.queue_put("output layer", message) + + self.queue.put("done") + + +def load_checkpoint(queue, args): + """ + Required top-level function that creates the loader, + calls its .load(), and handles exceptions by signaling 'exit'. + """ + loader = MegatronCheckpointLoaderFalconH1(args, queue) + try: + loader.load() + except Exception as e: + queue.put("exit") + raise e \ No newline at end of file diff --git a/tools/checkpoint/saver_parallelhybrid_hf.py b/tools/checkpoint/saver_parallelhybrid_hf.py new file mode 100644 index 00000000000..0414fc2096c --- /dev/null +++ b/tools/checkpoint/saver_parallelhybrid_hf.py @@ -0,0 +1,500 @@ +import sys +import os +import gc +import math +import json +from pathlib import Path +from shutil import rmtree + +import torch +import torch.multiprocessing as mp +from transformers import ( + AutoModelForCausalLM, + FalconH1Config, + FalconH1ForCausalLM, + GenerationConfig, +) + +sys.path.append(os.path.abspath( + os.path.join(os.path.dirname(__file__), + os.path.pardir, + os.path.pardir))) +try: + from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding +except ModuleNotFoundError: + print("Unable to import Megatron. 
Exiting.")
+    exit(1)
+
+def add_arguments(parser):
+    group = parser.add_argument_group(title="Parallel Hybrid HF saver.")
+    group.add_argument(
+        "--hf-tokenizer",
+        type=str,
+        default=None,
+        help="HF tokenizer (example: tiiuae/Falcon-H1-0.5B-Instruct)",
+    )
+    group.add_argument(
+        "--check-eq-hf",
+        type=str,
+        default=None,
+        help="check equality with HF model, example: tiiuae/Falcon-H1-1.5B-Instruct",
+    )
+    group.add_argument(
+        "--save-chat-model",
+        action='store_true',
+        help="flag to save chat model or not",
+    )
+
+def perform_check(
+    state_dict: dict[str, torch.Tensor], ref_state_dict: dict[str, torch.Tensor]
+) -> dict[str, torch.Tensor]:
+    """
+    Given a reference state dict, check that state_dict is equal to it
+    then pop the keys from ref_state_dict
+    """
+    for key in state_dict:
+        if key in ref_state_dict:
+            if not torch.equal(ref_state_dict[key], state_dict[key]):
+                print(f"Warning: Mismatch found in {key}")
+            ref_state_dict.pop(key)
+        else:
+            print(f"Warning: Key {key} not found in reference model")
+    return ref_state_dict
+
+def save_layer(
+    state_dict: dict[str, torch.Tensor],
+    index_dict: dict,
+    dir_path: str,
+    filename: str,
+    check_reference: bool = False,
+    ref_state_dict: dict[str, torch.Tensor] = None,
+) -> tuple[dict, dict[str, torch.Tensor]]:
+    """check state dict against a reference one if needed
+    update index_dict
+    save state dict
+    """
+    if check_reference and ref_state_dict is not None:
+        ref_state_dict = perform_check(state_dict, ref_state_dict)
+    for layer_name, weight_matrix in state_dict.items():
+        index_dict["weight_map"][layer_name] = filename
+        index_dict["metadata"]["total_size"] += weight_matrix.numel()
+    print(f"saving state dict to {dir_path}/{filename}")
+    torch.save(state_dict, f"{dir_path}/{filename}")
+    return index_dict, ref_state_dict
+
+def is_hybrid_layer(layer_idx: int) -> bool:
+    """Determine if a layer is hybrid (Mamba + Attention) or MLP-only"""
+    return layer_idx % 2 == 0
+
+def 
process_hybrid_layer_weights(message: dict, layer_idx: int, falcon_h1_config: FalconH1Config) -> dict[str, torch.Tensor]: + """Process weights for hybrid layers (Mamba + Attention)""" + state_dict = {} + + # Mamba mixer components + state_dict[f"model.layers.{layer_idx}.mamba.A_log"] = message["mamba A_log"] + state_dict[f"model.layers.{layer_idx}.mamba.D"] = message["mamba D"] + state_dict[f"model.layers.{layer_idx}.mamba.dt_bias"] = message["mamba dt_bias"] + state_dict[f"model.layers.{layer_idx}.mamba.conv1d.weight"] = message["mamba conv1d weight"] + state_dict[f"model.layers.{layer_idx}.mamba.conv1d.bias"] = message["mamba conv1d bias"] + state_dict[f"model.layers.{layer_idx}.mamba.in_proj.weight"] = message["mamba in_proj weight"] + state_dict[f"model.layers.{layer_idx}.mamba.out_proj.weight"] = message["mamba out_proj weight"] + + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = message["mamba pre norm weight"] + state_dict[f"model.layers.{layer_idx}.mamba.norm.weight"] = message["mamba internal norm weight"] + + # Self-attention components - PROPER QKV SPLITTING + qkv_weight = message["attention qkv weight"] + + # using standard Llama QKV layout + head_size = falcon_h1_config.hidden_size // falcon_h1_config.num_attention_heads # 128 + heads_per_group = falcon_h1_config.num_attention_heads // falcon_h1_config.num_key_value_heads # 4 + qkv_total_heads = falcon_h1_config.num_attention_heads + 2 * falcon_h1_config.num_key_value_heads # 12 + + # Reshape QKV to [12, 128, 1024] like Llama does + qkv_weights = qkv_weight.reshape([qkv_total_heads, head_size, falcon_h1_config.hidden_size]) + + # Create slices for Q, K, V exactly like Llama saver + q_slice = torch.cat([ + torch.arange( + (heads_per_group + 2) * i, + (heads_per_group + 2) * i + heads_per_group, + ) + for i in range(falcon_h1_config.num_key_value_heads) + ]) + k_slice = torch.arange(heads_per_group, qkv_total_heads, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, 
qkv_total_heads, (heads_per_group + 2)) + + # Extract Q, K, V using Llama's slicing approach + state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"] = qkv_weights[q_slice].reshape(-1, falcon_h1_config.hidden_size) + state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"] = qkv_weights[k_slice].reshape(-1, falcon_h1_config.hidden_size) + state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"] = qkv_weights[v_slice].reshape(-1, falcon_h1_config.hidden_size) + + # Attention output projection + state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = message["attention dense weight"] + + # Attention layer norm + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = message["attention input norm weight"] + + return state_dict + +def process_mlp_layer_weights(message: dict, layer_idx: int, falcon_h1_config: FalconH1Config) -> dict[str, torch.Tensor]: + """Process weights for MLP-only layers""" + state_dict = {} + + # MLP components - FIXED NAMES TO FEED_FORWARD + mlp_fc1_weight = message["mlp fc1 weight"] + + # Split gate and up projections (assuming SwiGLU like Llama) + intermediate_size = falcon_h1_config.intermediate_size + + # Split the fc1 weight into gate_proj and up_proj + gate_proj_weight = mlp_fc1_weight[:intermediate_size, :] + up_proj_weight = mlp_fc1_weight[intermediate_size:, :] + + state_dict[f"model.layers.{layer_idx}.feed_forward.gate_proj.weight"] = gate_proj_weight + state_dict[f"model.layers.{layer_idx}.feed_forward.up_proj.weight"] = up_proj_weight + state_dict[f"model.layers.{layer_idx}.feed_forward.down_proj.weight"] = message["mlp fc2 weight"] + + # MLP layer norm - FIXED NAME + state_dict[f"model.layers.{layer_idx}.pre_ff_layernorm.weight"] = message["mlp input norm weight"] + + return state_dict + +def save_checkpoint(queue: mp.Queue, args): + def queue_get(name=None): + val = queue.get() + if val == "exit": + print("Loader exited, exiting saver") + exit(1) + if name is not None and args.checking and 
val["name"] != name: + val_name = val["name"] + print( + f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.' + ) + exit(1) + if name is not None: + print(f"received {name}") + return val + + md = queue_get() + + ### Verify compatibility of args + if not hasattr(md, "checkpoint_args"): + raise ValueError("missing checkpoint_args in metadata") + + # Falcon-H1 specific validations + if not hasattr(md.checkpoint_args, 'hybrid_architecture'): + print("Warning: hybrid_architecture not specified in checkpoint_args, assuming Falcon-H1") + + torch_dtype = torch.float32 + if md.checkpoint_args.bf16: + torch_dtype = torch.bfloat16 + if md.checkpoint_args.fp16: + raise ValueError("bf16 and fp16 cannot be both set.") + elif md.checkpoint_args.fp16: + torch_dtype = torch.float16 + if md.checkpoint_args.bf16: + raise ValueError("bf16 and fp16 cannot be both set.") + + ### init + save_dir = Path(args.save_dir) + tmp_save_dir = save_dir / "tmp" + save_dir.mkdir(exist_ok=True) + tmp_save_dir.mkdir(exist_ok=True) + index_dict = { + "weight_map": {}, + "metadata": {"total_size": 0}, + } + tokenizer = None + ref_state_dict = None + + ### prepare a reference model if needed + if args.check_eq_hf: + print(f"preparing checks with given HF model {args.check_eq_hf}") + ref_model = AutoModelForCausalLM.from_pretrained(args.check_eq_hf, trust_remote_code=True) + ref_state_dict = ref_model.state_dict() + + ### save tokenizer conf files + if args.hf_tokenizer: + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(args.hf_tokenizer) + print(f"saving tokenizer to {args.save_dir}") + tokenizer.save_pretrained(args.save_dir) + + ### save config.json + falcon_h1_config = FalconH1Config( + # Basic model parameters from checkpoint + vocab_size=md.true_vocab_size if md.true_vocab_size else md.checkpoint_args.padded_vocab_size, + hidden_size=md.checkpoint_args.hidden_size, + intermediate_size=md.checkpoint_args.ffn_hidden_size, + 
num_hidden_layers=md.checkpoint_args.num_layers, + num_attention_heads=md.checkpoint_args.num_attention_heads, + num_key_value_heads=md.checkpoint_args.num_query_groups, + max_position_embeddings=md.checkpoint_args.max_position_embeddings, + rms_norm_eps=md.checkpoint_args.norm_epsilon, + tie_word_embeddings=not md.checkpoint_args.untie_embeddings_and_output_weights, + attention_dropout=md.checkpoint_args.attention_dropout, + + # Mamba parameters from checkpoint + mamba_d_state=md.checkpoint_args.mamba_state_dim, + mamba_d_conv=md.checkpoint_args.d_conv, + mamba_expand=md.checkpoint_args.expand, + mamba_d_ssm=md.checkpoint_args.d_inner, + mamba_n_heads=md.checkpoint_args.d_inner // md.checkpoint_args.mamba_head_dim, + mamba_d_head=md.checkpoint_args.mamba_head_dim, + mamba_n_groups=md.checkpoint_args.mamba_num_groups, + mamba_chunk_size=md.checkpoint_args.chunk_size, + mamba_conv_bias=md.checkpoint_args.conv_bias, + mamba_proj_bias=md.checkpoint_args.add_bias_linear, + mamba_norm_before_gate=md.checkpoint_args.norm_before_gate, + mamba_rms_norm=md.checkpoint_args.rmsnorm, + + # RoPE parameters from checkpoint + rope_theta=md.checkpoint_args.rotary_base, + + # Bias parameters from checkpoint + attention_bias=md.checkpoint_args.add_bias_linear, + mlp_bias=md.checkpoint_args.add_bias_linear, + projectors_bias=md.checkpoint_args.add_bias_linear, + + # Token IDs - from tokenizer if available, otherwise defaults + pad_token_id=getattr(tokenizer, 'pad_token_id', 0) if tokenizer else 0, + bos_token_id=getattr(tokenizer, 'bos_token_id', 1) if tokenizer else 1, + eos_token_id=getattr(tokenizer, 'eos_token_id', 2) if tokenizer else 2, + + # Parameters using FalconH1Config defaults (not in checkpoint) + hidden_act="silu", + initializer_range=0.02, + use_cache=True, + num_logits_to_keep=1, + rope_scaling=None, + + # Model metadata + torch_dtype=torch_dtype, + architectures=["FalconH1ForCausalLM"], + model_type="falcon_h1", + transformers_version="4.52.0", + ) + + if 
args.hf_tokenizer: + falcon_h1_config.eos_token_id = tokenizer.eos_token_id + falcon_h1_config.bos_token_id = tokenizer.bos_token_id + + print(f"saving config.json to {tmp_save_dir}") + falcon_h1_config.save_pretrained(tmp_save_dir) + + ### save embedding layer + def pad_weight(orig_word_embed, true_vocab_size): + if true_vocab_size is not None: + # figure out what our padded vocab size is + orig_vocab_size = orig_word_embed.shape[0] + md.checkpoint_args.padded_vocab_size = _vocab_size_with_padding(true_vocab_size, md.checkpoint_args) + + # Cut out extra padding we don't need + if orig_vocab_size > md.checkpoint_args.padded_vocab_size: + full_word_embed = orig_word_embed[0:md.checkpoint_args.padded_vocab_size,:] + + # Expanding embedding to larger size by replicating final entry + elif orig_vocab_size < md.checkpoint_args.padded_vocab_size: + padding_size = md.checkpoint_args.padded_vocab_size - orig_vocab_size + full_word_embed = torch.cat(( + orig_word_embed, + orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1))) + + # Same size! + else: + full_word_embed = orig_word_embed + else: + print("Original vocab size not specified, leaving embedding table as-is. 
" + "If you've changed the tensor parallel size this could cause problems.") + md.checkpoint_args.padded_vocab_size = orig_word_embed.shape[0] + full_word_embed = orig_word_embed + return full_word_embed + + state_dict = { + "model.embed_tokens.weight": pad_weight(queue_get("embeddings")["word embeddings"], md.true_vocab_size) + } + index_dict, ref_state_dict = save_layer( + state_dict, + index_dict, + dir_path=tmp_save_dir, + filename="pytorch_model-embedding.bin", + check_reference=args.check_eq_hf, + ref_state_dict=ref_state_dict, + ) + + for i_layer in range(falcon_h1_config.num_hidden_layers): + state_dict = {} + + if is_hybrid_layer(i_layer): + # Process hybrid layer (Mamba + Attention) - EVEN layers + message = queue_get(f"hybrid layer {i_layer}") + + # Add Mamba + Attention components from Megatron + hybrid_weights = process_hybrid_layer_weights(message, i_layer, falcon_h1_config) + state_dict.update(hybrid_weights) + + # Add MISSING MLP components (configured to output zeros = identity for addition) + mlp_intermediate_size = falcon_h1_config.intermediate_size + state_dict.update({ + # Gate and up can be anything since down_proj will zero everything out + f"model.layers.{i_layer}.feed_forward.gate_proj.weight": torch.randn( + mlp_intermediate_size, falcon_h1_config.hidden_size, + dtype=torch_dtype + ) * 0.01, + f"model.layers.{i_layer}.feed_forward.up_proj.weight": torch.randn( + mlp_intermediate_size, falcon_h1_config.hidden_size, + dtype=torch_dtype + ) * 0.01, + # KEY: down_proj = 0 makes entire MLP output zero + f"model.layers.{i_layer}.feed_forward.down_proj.weight": torch.zeros( + falcon_h1_config.hidden_size, mlp_intermediate_size, + dtype=torch_dtype + ), + f"model.layers.{i_layer}.pre_ff_layernorm.weight": torch.ones( + falcon_h1_config.hidden_size, dtype=torch_dtype + ), + }) + + else: + # Process MLP-only layer - ODD layers + message = queue_get(f"mlp layer {i_layer}") + + # Add MLP components from Megatron + mlp_weights = 
process_mlp_layer_weights(message, i_layer, falcon_h1_config) + state_dict.update(mlp_weights) + + # Add MISSING Mamba components (configured to output zeros = identity for addition) + mamba_intermediate_size = ( + falcon_h1_config.mamba_d_ssm if falcon_h1_config.mamba_d_ssm + else int(falcon_h1_config.mamba_expand * falcon_h1_config.hidden_size) + ) + conv_dim = mamba_intermediate_size + 2 * falcon_h1_config.mamba_n_groups * falcon_h1_config.mamba_d_state + projection_size = mamba_intermediate_size + conv_dim + falcon_h1_config.mamba_n_heads + + state_dict.update({ + f"model.layers.{i_layer}.mamba.A_log": torch.log(torch.arange(1, falcon_h1_config.mamba_n_heads + 1, dtype=torch_dtype)), + f"model.layers.{i_layer}.mamba.D": torch.ones(falcon_h1_config.mamba_n_heads, dtype=torch_dtype), + f"model.layers.{i_layer}.mamba.dt_bias": torch.ones(falcon_h1_config.mamba_n_heads, dtype=torch_dtype), + f"model.layers.{i_layer}.mamba.conv1d.weight": torch.randn( + conv_dim, 1, falcon_h1_config.mamba_d_conv, dtype=torch_dtype + ) * 0.01, + f"model.layers.{i_layer}.mamba.conv1d.bias": torch.zeros(conv_dim, dtype=torch_dtype), + f"model.layers.{i_layer}.mamba.in_proj.weight": torch.randn( + projection_size, falcon_h1_config.hidden_size, dtype=torch_dtype + ) * 0.01, + # KEY: out_proj = 0 makes entire Mamba output zero + f"model.layers.{i_layer}.mamba.out_proj.weight": torch.zeros( + falcon_h1_config.hidden_size, mamba_intermediate_size, dtype=torch_dtype + ), + f"model.layers.{i_layer}.mamba.norm.weight": torch.ones(mamba_intermediate_size, dtype=torch_dtype), + }) + + # Add MISSING Attention components (configured to output zeros = identity for addition) + head_dim = falcon_h1_config.hidden_size // falcon_h1_config.num_attention_heads + state_dict.update({ + f"model.layers.{i_layer}.self_attn.q_proj.weight": torch.randn( + falcon_h1_config.num_attention_heads * head_dim, + falcon_h1_config.hidden_size, dtype=torch_dtype + ) * 0.01, + 
f"model.layers.{i_layer}.self_attn.k_proj.weight": torch.randn( + falcon_h1_config.num_key_value_heads * head_dim, + falcon_h1_config.hidden_size, dtype=torch_dtype + ) * 0.01, + f"model.layers.{i_layer}.self_attn.v_proj.weight": torch.randn( + falcon_h1_config.num_key_value_heads * head_dim, + falcon_h1_config.hidden_size, dtype=torch_dtype + ) * 0.01, + # KEY: o_proj = 0 makes entire attention output zero + f"model.layers.{i_layer}.self_attn.o_proj.weight": torch.zeros( + falcon_h1_config.hidden_size, + falcon_h1_config.num_attention_heads * head_dim, + dtype=torch_dtype + ), + f"model.layers.{i_layer}.input_layernorm.weight": torch.ones( + falcon_h1_config.hidden_size, dtype=torch_dtype + ), + }) + index_dict, ref_state_dict = save_layer( + state_dict, + index_dict, + dir_path=tmp_save_dir, + filename=f"pytorch_model-{i_layer + 1}.bin", + check_reference=args.check_eq_hf, + ref_state_dict=ref_state_dict, + ) + + + ### save final norm and output layer + state_dict = { + "model.final_layernorm.weight": queue_get("final norm")["weight"] +} + if md.checkpoint_args.untie_embeddings_and_output_weights: + state_dict["lm_head.weight"] = pad_weight(queue_get("output layer")["weight"], md.true_vocab_size) + + index_dict, ref_state_dict = save_layer( + state_dict, + index_dict, + dir_path=tmp_save_dir, + filename="pytorch_model-lm-head.bin", + check_reference=args.check_eq_hf, + ref_state_dict=ref_state_dict, + ) + + # final check + if ref_state_dict: + remaining_keys = list(ref_state_dict.keys()) + print(f"Warning: reference state dict has {len(remaining_keys)} additional layers not present in converted model:") + for key in remaining_keys[:10]: # Show first 10 + print(f" - {key}") + if len(remaining_keys) > 10: + print(f" ... 
and {len(remaining_keys) - 10} more") + + ### save index dict + index_dict["metadata"]["total_size"] *= { + torch.float32: 4, + torch.float16: 2, + torch.bfloat16: 2, + }[torch_dtype] + print(f"saving {tmp_save_dir}/pytorch_model.bin.index.json") + with open(f"{tmp_save_dir}/pytorch_model.bin.index.json", "w") as f: + json.dump(index_dict, f) + + ### load then save model in HF format + # Make space so we can load the model properly now. + del state_dict + gc.collect() + print(f"Loading the converted pytorch checkpoint in a Falcon-H1 HF model from {tmp_save_dir}") + model = FalconH1ForCausalLM.from_pretrained( + str(tmp_save_dir), torch_dtype=torch_dtype, low_cpu_mem_usage=True, trust_remote_code=True + ) + + # Avoid saving this as part of the config. + if hasattr(model.config, '_name_or_path'): + del model.config._name_or_path + model.config.torch_dtype = torch_dtype + print(f"Saving in the Transformers safe tensors format to {args.save_dir}") + model.save_pretrained(args.save_dir, safe_serialization=True) + + ### save chat config + generation_config = ( + GenerationConfig( + do_sample=True, + temperature=0.6, + top_p=0.9, + bos_token_id=falcon_h1_config.bos_token_id, + eos_token_id=falcon_h1_config.eos_token_id, + ) + if args.save_chat_model + else GenerationConfig( + _from_model_config=True, + bos_token_id=falcon_h1_config.bos_token_id, + eos_token_id=falcon_h1_config.eos_token_id, + ) + ) + print(f"Saving generation config to {args.save_dir}") + generation_config.save_pretrained(args.save_dir) + + ### cleanup tmp + print(f"Deleting {tmp_save_dir}") + rmtree(tmp_save_dir) From 84fbf74a1fa8e3c336a67fa0d10cf3a3df6b865b Mon Sep 17 00:00:00 2001 From: dhia eddine rhaiem Date: Tue, 2 Sep 2025 18:59:09 +0000 Subject: [PATCH 5/6] clean --- launch.sh | 145 ------------------------------------------------------ 1 file changed, 145 deletions(-) delete mode 100644 launch.sh diff --git a/launch.sh b/launch.sh deleted file mode 100644 index 115dc4e93f9..00000000000 --- 
a/launch.sh +++ /dev/null @@ -1,145 +0,0 @@ -#!/bin/bash - -# ========== Configuration ========== -NUM_NODES=1 -NODE_RANK=0 -NUM_GPUS_PER_NODE=1 -MASTER_ADDR="localhost" -MASTER_PORT=29501 - -# ========== Recommended Exports ========== -export CUDA_DEVICE_ORDER="PCI_BUS_ID" -export NCCL_PROTO="Simple,LL128" -export NCCL_DEBUG="INFO" -export NCCL_SOCKET_IFNAME="eth0" -export NCCL_NET_PLUGIN=none -export PYTHONUNBUFFERED="1" -export CUDA_LAUNCH_BLOCKING="1" -export CUDA_DEVICE_MAX_CONNECTIONS="1" - -# ========== Print setup ========== -echo "Launching with:" -echo " - MASTER_ADDR: $MASTER_ADDR" -echo " - MASTER_PORT: $MASTER_PORT" -echo " - NODE_RANK: $NODE_RANK / $NUM_NODES" -echo " - GPUs per node: $NUM_GPUS_PER_NODE" - -# Parallelism configuration -TP=1 # Tensor Parallel -PP=1 # Pipeline Parallel -CP=1 # Context Parallel - -# Build experiment name with parallelism config -EXP_NAME="test" - -options="\ - --micro-batch-size 8 \ - --global-batch-size 512 \ - --rampup-batch-size 64 64 4882 \ - --train-samples 210449 \ - --data-path /home/aiccu/Megatron-LM/data/merged_falcon_english_32k/merged_0 \ - --data-cache-path /gcs/data/data-cache-path \ - --tokenizer-type HuggingFaceTokenizer \ - --tokenizer-model tiiuae/Falcon-H1-0.5B-Instruct \ - --vocab-size 32784 \ - --make-vocab-size-divisible-by 1 \ - --tensorboard-dir /gcs/data/tok-dir \ - --log-validation-ppl-to-tensorboard \ - --log-timers-to-tensorboard \ - --log-throughput \ - --log-interval 10 \ - --no-mmap-bin-files \ - --split 1000,0,0 \ - --fp32-residual-connection \ - \ - --disable-bias-linear \ - --num-layers 72 \ - --hidden-size 1024 \ - --ffn-hidden-size 2048 \ - --num-attention-heads 8 \ - --group-query-attention \ - --num-query-groups 2 \ - --seq-length 2048 \ - --max-position-embeddings 2048 \ - --rotary-base 100000000000 \ - --position-embedding-type rope \ - --no-rope-fusion \ - --disable-bias-linear \ - \ - --spec megatron.core.models.mamba.mamba_layer_specs mamba_stack_spec \ - --mamba-state-dim 128 \ 
- --mamba-head-dim 64 \ - --mamba-num-groups ${TP} \ - --reset-position-ids \ - \ - --weight-decay 0.1 \ - --optimizer adam \ - --adam-beta1 0.9 \ - --adam-beta2 0.95 \ - --adam-eps 1e-16 \ - --use-distributed-optimizer \ - --clip-grad 1.0 \ - --bf16 \ - --init-method-std 0.02 \ - --lr 128e-5 \ - --lr-decay-style WSD \ - --lr-wsd-decay-samples 15137 \ - --lr-wsd-decay-style exponential \ - --min-lr 0.0 \ - --lr-warmup-init 0.0 \ - --lr-warmup-fraction 0.1 \ - --ckpt-format torch \ - \ - --tensor-model-parallel-size ${TP} \ - --pipeline-model-parallel-size ${PP} \ - --context-parallel-size ${CP} \ - --overlap-param-gather \ - --overlap-grad-reduce \ - --no-gradient-accumulation-fusion \ - --no-masked-softmax-fusion \ - \ - --attention-softmax-in-fp32 \ - --untie-embeddings-and-output-weights \ - --swiglu \ - --normalization RMSNorm \ - --norm-epsilon 1e-5 \ - --attention-dropout 0.0 \ - --hidden-dropout 0.0 \ - --use-flash-attn \ - \ - --distributed-timeout-minutes 90 \ - --num-workers 16 \ - --num-dataset-builder-threads 32 \ - \ - --no-create-attention-mask-in-dataloader \ - --mid-level-dataset-surplus 0.005 \ - \ - --parallel-hybrid-ratio 0.5 \ - --hybrid-attention-ratio 0.0 \ - --hybrid-mlp-ratio 0.5 \ - \ - --save /gcs/data/save \ - --save-interval 420 \ - --wandb-project mlm-final-pr \ - --wandb-exp-name final-pr \ - \ - --disable-msc - --dataloader-type single \ - --eval-iters 0 \ - --no-load-optim \ - --no-load-rng \ - --seed 52 \ - --override-opt_param-scheduler" - -# ========== Run ========== -source ~/miniconda3/etc/profile.d/conda.sh -conda activate megatron -export CUDA_VISIBLE_DEVICES=0 - -$(which torchrun) \ - --nproc_per_node=$NUM_GPUS_PER_NODE \ - --nnodes=$NUM_NODES \ - --node_rank=$NODE_RANK \ - --master_addr=$MASTER_ADDR \ - --master_port=$MASTER_PORT \ - pretrain_mamba.py ${options} \ No newline at end of file From d35964ada0c326717589dcaff5fcf3d42d38766f Mon Sep 17 00:00:00 2001 From: "dhia.rhaiem" Date: Fri, 31 Oct 2025 02:44:01 +0000 Subject: 
[PATCH 6/6] make final commit for training and conversion --- mamba_builders.py | 1 + megatron/core/ssm/mamba_block.py | 4 +- megatron/core/ssm/parallel_hybrid_layer.py | 22 +-- megatron/training/checkpointing.py | 6 +- tools/checkpoint/convert.py | 2 +- tools/checkpoint/loader_base.py | 77 +++++++++- tools/checkpoint/saver_parallelhybrid_hf.py | 150 ++++++++++---------- tools/checkpoint/schema_core.py | 23 +++ 8 files changed, 190 insertions(+), 95 deletions(-) diff --git a/mamba_builders.py b/mamba_builders.py index 0ccfc29b86c..1cffbaf0f4e 100644 --- a/mamba_builders.py +++ b/mamba_builders.py @@ -27,6 +27,7 @@ def mamba_builder(args, pre_process, post_process, vp_stage=None, config=None): pre_process=pre_process, hybrid_attention_ratio=args.hybrid_attention_ratio, hybrid_mlp_ratio=args.hybrid_mlp_ratio, + parallel_hybrid_ratio=args.parallel_hybrid_ratio, hybrid_override_pattern=args.hybrid_override_pattern, post_process=post_process, fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, diff --git a/megatron/core/ssm/mamba_block.py b/megatron/core/ssm/mamba_block.py index 65466169e67..aade4447768 100644 --- a/megatron/core/ssm/mamba_block.py +++ b/megatron/core/ssm/mamba_block.py @@ -105,6 +105,8 @@ class MambaStack(MegatronModule): total layers. Defaults to 0.0. hybrid_mlp_ratio (float, optional): the target ratio of mlp layers to total layers. Defaults to 0.0. + parallel_hybrid_ratio (float, optional): the target ratio of parallel hybrid layers + (combined transformer+SSM) to total layers. Defaults to 0.0. hybrid_override_pattern (str, optional): the hybrid layer pattern to override with. Defaults to None. post_layer_norm (bool, optional): whether to include a final layer norm. 
@@ -198,7 +200,7 @@ def __init__( submodules.parallel_hybrid_layer, config=self.config, layer_number=i + 1 + pp_layer_offset, - model_comm_pgs=model_comm_pgs, + pg_collection=pg_collection, ) else: assert False, "unexpected layer_type" diff --git a/megatron/core/ssm/parallel_hybrid_layer.py b/megatron/core/ssm/parallel_hybrid_layer.py index 322615eea94..77487849ca2 100644 --- a/megatron/core/ssm/parallel_hybrid_layer.py +++ b/megatron/core/ssm/parallel_hybrid_layer.py @@ -11,14 +11,14 @@ import torch from megatron.core.inference.contexts import BaseInferenceContext -from megatron.core.process_groups_config import ModelCommProcessGroups +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.module import MegatronModule from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.ssm.mamba_mixer import MambaMixerSubmodules +from megatron.core.ssm.mamba_mixer import MambaMixerSubmodules from megatron.core.transformer.attention import SelfAttentionSubmodules @@ -40,10 +40,12 @@ def __init__( submodules: ParallelHybridLayerSubmodules, layer_number: int = 1, residual_in_fp32=False, - model_comm_pgs: ModelCommProcessGroups = None, + pg_collection: Optional[ProcessGroupCollection] = None, ): super().__init__(config) - assert model_comm_pgs is not None, "model_comm_pgs must be provided for ParallelHybridLayer" + if pg_collection is None: + pg_collection = ProcessGroupCollection.use_mpu_process_groups() + self.pg_collection = pg_collection self.config = config self.layer_number = layer_number @@ -51,8 +53,8 @@ def __init__( self.hidden_dropout = config.hidden_dropout self.input_layernorm = build_module( - submodules.input_layernorm, - config=self.config, + submodules.input_layernorm, + 
config=self.config, hidden_size=self.config.hidden_size, eps=self.config.layernorm_epsilon, ) @@ -68,7 +70,7 @@ def __init__( config=self.config, layer_number=layer_number, d_model=self.config.hidden_size, - model_comm_pgs=model_comm_pgs + pg_collection=pg_collection ) attention_optional_kwargs = {} @@ -77,8 +79,7 @@ def __init__( attention_optional_kwargs["cp_comm_type"] = self.config.cp_comm_type[self.layer_number] else: attention_optional_kwargs["cp_comm_type"] = self.config.cp_comm_type - model_comm_pgs = ModelCommProcessGroups.use_mpu_process_groups() - attention_optional_kwargs["model_comm_pgs"] = model_comm_pgs + attention_optional_kwargs["pg_collection"] = pg_collection attention_submodules = SelfAttentionSubmodules( linear_qkv=submodules.self_attention.module.submodules.linear_qkv, @@ -154,7 +155,7 @@ def forward( with self.bias_dropout_add_exec_handler(): hidden_states = self.parallel_hybrid_bda( - training=self.training, + training=self.training, fused=self.config.bias_dropout_fusion )(out_with_bias, residual, self.hidden_dropout) @@ -197,4 +198,3 @@ def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None): return sharded_state_dict - \ No newline at end of file diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index 7c5487e05fd..ae96adc2f9f 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -1280,11 +1280,13 @@ def _set_arg(arg_name, old_arg_name=None, force=False): _set_arg('apply_query_key_layer_scaling', force=True) _set_arg('attention_dropout', force=True) _set_arg('hidden_dropout', force=True) + _set_arg('gradient_accumulation_fusion', force=True) _set_arg('hybrid_override_pattern', force=True) _set_arg('spec', force=True) _set_arg('hybrid_attention_ratio', force=True) _set_arg('hybrid_mlp_ratio', force=True) + _set_arg('parallel_hybrid_ratio', force=True) _set_arg('num_experts', force=True) _set_arg('moe_layer_freq', force=True) @@ -1360,7 +1362,7 @@ def 
load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', strict=strict, load_arg=load_arg ) - + # Since load_modelopt_checkpoint doesn't return iteration count, we need to get it if torch.distributed.is_initialized(): tracker_filename = get_checkpoint_tracker_filename(load_dir) @@ -1372,7 +1374,7 @@ def load_checkpoint(ddp_model, optimizer, opt_param_scheduler, load_arg='load', iteration = 0 else: iteration = 0 - + # We don't have a reliable way to get num_floating_point_operations_so_far from ModelOpt format return iteration, 0 diff --git a/tools/checkpoint/convert.py b/tools/checkpoint/convert.py index e9eb7e99b60..6193ec2eb0d 100644 --- a/tools/checkpoint/convert.py +++ b/tools/checkpoint/convert.py @@ -112,7 +112,7 @@ def main(): allow_abbrev=False, conflict_handler='resolve') parser.add_argument('--model-type', type=str, required=True, - choices=['GPT', 'BERT'], + choices=['GPT', 'BERT', 'Mamba'], help='Type of the model') parser.add_argument('--loader', type=str, default='megatron', help='Module name to load checkpoint, should be on python path') diff --git a/tools/checkpoint/loader_base.py b/tools/checkpoint/loader_base.py index ef9d3688a69..c87e4f1bfec 100644 --- a/tools/checkpoint/loader_base.py +++ b/tools/checkpoint/loader_base.py @@ -68,7 +68,7 @@ def parse_megatron_args(self): # Expert parallelism requires sequence parallelism if margs.expert_model_parallel_size > 1: margs.sequence_parallel = True - + margs = self._maybe_parse_additional_megatron_args(margs, checkpoint_args) # Validate final arguments @@ -146,14 +146,72 @@ def initialize_megatron_env(self): mpu.set_pipeline_model_parallel_world_size(self.margs.pipeline_model_parallel_size) mpu.set_virtual_pipeline_model_parallel_world_size(self.margs.virtual_pipeline_model_parallel_size) mpu.set_expert_model_parallel_world_size(self.margs.expert_model_parallel_size) - + + # For backward compatibility during local parallel states refactoring fake_tp_group = 
_ConverterFakeProcessGroup(size=self.margs.tensor_model_parallel_size) fake_ep_group = _ConverterFakeProcessGroup(size=self.margs.expert_model_parallel_size) + # ADD: Create fake pp and cp groups + fake_pp_group = _ConverterFakeProcessGroup(size=self.margs.pipeline_model_parallel_size) + fake_cp_group = _ConverterFakeProcessGroup(size=1) # Context parallel is always 1 for conversion + + # Set all process groups mpu._TENSOR_MODEL_PARALLEL_GROUP = fake_tp_group mpu._EXPERT_MODEL_PARALLEL_GROUP = fake_ep_group + mpu._PIPELINE_MODEL_PARALLEL_GROUP = fake_pp_group + mpu._CONTEXT_PARALLEL_GROUP = fake_cp_group + + # Also set combined groups that might be needed + mpu._MODEL_PARALLEL_GROUP = fake_tp_group # Simplified: just use tp group + mpu._TENSOR_AND_CONTEXT_PARALLEL_GROUP = fake_tp_group # Combined tp+cp + + # ============================================ + # Add RNG Tracker Init for checkpoint conversion + # ============================================ + print("Initializing RNG tracker for checkpoint conversion...") + try: + from megatron.core.tensor_parallel.random import ( + get_cuda_rng_tracker, + initialize_rng_tracker, + model_parallel_cuda_manual_seed, + _MODEL_PARALLEL_RNG_TRACKER_NAME, + _DATA_PARALLEL_RNG_TRACKER_NAME + ) + + # Initialize RNG tracker (equivalent to what _set_random_seed does) + initialize_rng_tracker( + use_te_rng_tracker=getattr(self.margs, 'te_rng_tracker', False), + inference_rng_tracker=getattr(self.margs, 'inference_rng_tracker', False), + use_cudagraphable_rng=getattr(self.margs, 'enable_cuda_graph', False) + ) + + # Add required RNG states (this is what model_parallel_cuda_manual_seed does) + seed = getattr(self.margs, 'seed', 1234) + rng_tracker = get_cuda_rng_tracker() + + # Add model-parallel-rng state + if not hasattr(rng_tracker, '_states') or _MODEL_PARALLEL_RNG_TRACKER_NAME not in rng_tracker._states: + rng_tracker.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, seed) + print(f"Added {_MODEL_PARALLEL_RNG_TRACKER_NAME} RNG state with seed 
{seed}") + + # Add data-parallel-rng state + if not hasattr(rng_tracker, '_states') or _DATA_PARALLEL_RNG_TRACKER_NAME not in rng_tracker._states: + rng_tracker.add(_DATA_PARALLEL_RNG_TRACKER_NAME, seed) + print(f"Added {_DATA_PARALLEL_RNG_TRACKER_NAME} RNG state with seed {seed}") + + print("RNG tracker initialization completed successfully") + + # Debug: Show available RNG states + if hasattr(rng_tracker, '_states'): + print(f"Available RNG states: {list(rng_tracker._states.keys())}") + + except Exception as e: + print(f"Warning: Failed to initialize RNG tracker: {e}") + print("Model building may fail if RNG states are required") + # Don't exit - let it try without RNG tracker fused_kernels.load(self.margs) + def compute_true_vocab_size(self): """Determine the 'true' (non-padded) vocab size.""" if self.args.true_vocab_size is not None: @@ -205,7 +263,14 @@ def get_models_for_pipeline_stage(count, dtype): mpu.set_virtual_pipeline_model_parallel_rank(i) pre_process = mpu.is_pipeline_first_stage() post_process = mpu.is_pipeline_last_stage() - this_model = model_provider(pre_process=pre_process, + if self.args.model_type == 'Mamba': + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))) + from mamba_builders import mamba_builder + this_model = model_provider(mamba_builder, + pre_process=pre_process, + post_process=post_process).to(dtype) + else: + this_model = model_provider(pre_process=pre_process, post_process=post_process).to(dtype) model_list.append(this_model) @@ -242,7 +307,7 @@ def get_models_for_pipeline_stage(count, dtype): all_models.append(get_models_for_pipeline_stage(tp_size, dtype)) return all_models, consumed_train_samples, consumed_valid_samples - + def send_metadata_over_queue(self): # Let the consumer know the overall metadata: self.md.consumed_train_samples = self.consumed_train_samples @@ -401,7 +466,7 @@ def load(self): # 2) Ensure required arguments are present self.ensure_required_arguments() - # 3) Import 
the correct model provider (GPT or BERT) + # 3) Import the correct model provider (GPT or BERT or Mamba) model_provider = self.import_model_provider() # 4) Initialize the Megatron environment @@ -422,7 +487,7 @@ def load(self): self.md.params_dtype ) - # 8) Send model over the queue + # 8) Send model over the queue self.send_model_over_queue() def build_checkpoint_metadata(self, true_vocab_size): diff --git a/tools/checkpoint/saver_parallelhybrid_hf.py b/tools/checkpoint/saver_parallelhybrid_hf.py index 0414fc2096c..d833a9e9f76 100644 --- a/tools/checkpoint/saver_parallelhybrid_hf.py +++ b/tools/checkpoint/saver_parallelhybrid_hf.py @@ -89,7 +89,7 @@ def is_hybrid_layer(layer_idx: int) -> bool: def process_hybrid_layer_weights(message: dict, layer_idx: int, falcon_h1_config: FalconH1Config) -> dict[str, torch.Tensor]: """Process weights for hybrid layers (Mamba + Attention)""" state_dict = {} - + # Mamba mixer components state_dict[f"model.layers.{layer_idx}.mamba.A_log"] = message["mamba A_log"] state_dict[f"model.layers.{layer_idx}.mamba.D"] = message["mamba D"] @@ -98,21 +98,21 @@ def process_hybrid_layer_weights(message: dict, layer_idx: int, falcon_h1_config state_dict[f"model.layers.{layer_idx}.mamba.conv1d.bias"] = message["mamba conv1d bias"] state_dict[f"model.layers.{layer_idx}.mamba.in_proj.weight"] = message["mamba in_proj weight"] state_dict[f"model.layers.{layer_idx}.mamba.out_proj.weight"] = message["mamba out_proj weight"] - + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = message["mamba pre norm weight"] state_dict[f"model.layers.{layer_idx}.mamba.norm.weight"] = message["mamba internal norm weight"] - + # Self-attention components - PROPER QKV SPLITTING qkv_weight = message["attention qkv weight"] - + # using standard Llama QKV layout head_size = falcon_h1_config.hidden_size // falcon_h1_config.num_attention_heads # 128 heads_per_group = falcon_h1_config.num_attention_heads // falcon_h1_config.num_key_value_heads # 4 
qkv_total_heads = falcon_h1_config.num_attention_heads + 2 * falcon_h1_config.num_key_value_heads # 12 - + # Reshape QKV to [12, 128, 1024] like Llama does qkv_weights = qkv_weight.reshape([qkv_total_heads, head_size, falcon_h1_config.hidden_size]) - + # Create slices for Q, K, V exactly like Llama saver q_slice = torch.cat([ torch.arange( @@ -123,41 +123,41 @@ def process_hybrid_layer_weights(message: dict, layer_idx: int, falcon_h1_config ]) k_slice = torch.arange(heads_per_group, qkv_total_heads, (heads_per_group + 2)) v_slice = torch.arange(heads_per_group + 1, qkv_total_heads, (heads_per_group + 2)) - + # Extract Q, K, V using Llama's slicing approach state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"] = qkv_weights[q_slice].reshape(-1, falcon_h1_config.hidden_size) state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"] = qkv_weights[k_slice].reshape(-1, falcon_h1_config.hidden_size) state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"] = qkv_weights[v_slice].reshape(-1, falcon_h1_config.hidden_size) - + # Attention output projection state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = message["attention dense weight"] - - # Attention layer norm + + # Attention layer norm state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = message["attention input norm weight"] - + return state_dict def process_mlp_layer_weights(message: dict, layer_idx: int, falcon_h1_config: FalconH1Config) -> dict[str, torch.Tensor]: """Process weights for MLP-only layers""" state_dict = {} - + # MLP components - FIXED NAMES TO FEED_FORWARD mlp_fc1_weight = message["mlp fc1 weight"] - + # Split gate and up projections (assuming SwiGLU like Llama) intermediate_size = falcon_h1_config.intermediate_size - + # Split the fc1 weight into gate_proj and up_proj gate_proj_weight = mlp_fc1_weight[:intermediate_size, :] up_proj_weight = mlp_fc1_weight[intermediate_size:, :] - + 
state_dict[f"model.layers.{layer_idx}.feed_forward.gate_proj.weight"] = gate_proj_weight state_dict[f"model.layers.{layer_idx}.feed_forward.up_proj.weight"] = up_proj_weight state_dict[f"model.layers.{layer_idx}.feed_forward.down_proj.weight"] = message["mlp fc2 weight"] - + # MLP layer norm - FIXED NAME state_dict[f"model.layers.{layer_idx}.pre_ff_layernorm.weight"] = message["mlp input norm weight"] - + return state_dict def save_checkpoint(queue: mp.Queue, args): @@ -181,11 +181,11 @@ def queue_get(name=None): ### Verify compatibility of args if not hasattr(md, "checkpoint_args"): raise ValueError("missing checkpoint_args in metadata") - + # Falcon-H1 specific validations if not hasattr(md.checkpoint_args, 'hybrid_architecture'): print("Warning: hybrid_architecture not specified in checkpoint_args, assuming Falcon-H1") - + torch_dtype = torch.float32 if md.checkpoint_args.bf16: torch_dtype = torch.bfloat16 @@ -222,64 +222,66 @@ def queue_get(name=None): tokenizer.save_pretrained(args.save_dir) ### save config.json + mamba_expand = 2 + mamba_intermediate_size = mamba_expand * md.checkpoint_args.hidden_size + mamba_n_heads = mamba_intermediate_size // md.checkpoint_args.mamba_head_dim falcon_h1_config = FalconH1Config( # Basic model parameters from checkpoint vocab_size=md.true_vocab_size if md.true_vocab_size else md.checkpoint_args.padded_vocab_size, - hidden_size=md.checkpoint_args.hidden_size, - intermediate_size=md.checkpoint_args.ffn_hidden_size, - num_hidden_layers=md.checkpoint_args.num_layers, - num_attention_heads=md.checkpoint_args.num_attention_heads, - num_key_value_heads=md.checkpoint_args.num_query_groups, - max_position_embeddings=md.checkpoint_args.max_position_embeddings, - rms_norm_eps=md.checkpoint_args.norm_epsilon, + hidden_size=md.checkpoint_args.hidden_size, + intermediate_size=md.checkpoint_args.ffn_hidden_size, + num_hidden_layers=md.checkpoint_args.num_layers, + num_attention_heads=md.checkpoint_args.num_attention_heads, + 
num_key_value_heads=md.checkpoint_args.num_query_groups, + max_position_embeddings=md.checkpoint_args.max_position_embeddings, + rms_norm_eps=md.checkpoint_args.norm_epsilon, tie_word_embeddings=not md.checkpoint_args.untie_embeddings_and_output_weights, - attention_dropout=md.checkpoint_args.attention_dropout, - + attention_dropout=md.checkpoint_args.attention_dropout, # Mamba parameters from checkpoint - mamba_d_state=md.checkpoint_args.mamba_state_dim, - mamba_d_conv=md.checkpoint_args.d_conv, - mamba_expand=md.checkpoint_args.expand, - mamba_d_ssm=md.checkpoint_args.d_inner, - mamba_n_heads=md.checkpoint_args.d_inner // md.checkpoint_args.mamba_head_dim, - mamba_d_head=md.checkpoint_args.mamba_head_dim, - mamba_n_groups=md.checkpoint_args.mamba_num_groups, - mamba_chunk_size=md.checkpoint_args.chunk_size, - mamba_conv_bias=md.checkpoint_args.conv_bias, - mamba_proj_bias=md.checkpoint_args.add_bias_linear, - mamba_norm_before_gate=md.checkpoint_args.norm_before_gate, - mamba_rms_norm=md.checkpoint_args.rmsnorm, - + mamba_d_state=md.checkpoint_args.mamba_state_dim, + mamba_d_head=md.checkpoint_args.mamba_head_dim, + mamba_n_groups=md.checkpoint_args.mamba_num_groups, + mamba_d_ssm=mamba_intermediate_size, + mamba_n_heads=mamba_n_heads, + mamba_expand=mamba_expand, + mamba_d_conv=4, + mamba_chunk_size=256, + mamba_conv_bias=True, + mamba_proj_bias=md.checkpoint_args.add_bias_linear, + mamba_norm_before_gate=True, + mamba_rms_norm=md.checkpoint_args.normalization == "RMSNorm", + # RoPE parameters from checkpoint - rope_theta=md.checkpoint_args.rotary_base, - + rope_theta=md.checkpoint_args.rotary_base, + # Bias parameters from checkpoint - attention_bias=md.checkpoint_args.add_bias_linear, - mlp_bias=md.checkpoint_args.add_bias_linear, - projectors_bias=md.checkpoint_args.add_bias_linear, - + attention_bias=md.checkpoint_args.add_bias_linear, + mlp_bias=md.checkpoint_args.add_bias_linear, + projectors_bias=md.checkpoint_args.add_bias_linear, + # Token IDs - from 
tokenizer if available, otherwise defaults pad_token_id=getattr(tokenizer, 'pad_token_id', 0) if tokenizer else 0, bos_token_id=getattr(tokenizer, 'bos_token_id', 1) if tokenizer else 1, eos_token_id=getattr(tokenizer, 'eos_token_id', 2) if tokenizer else 2, - + # Parameters using FalconH1Config defaults (not in checkpoint) - hidden_act="silu", - initializer_range=0.02, - use_cache=True, - num_logits_to_keep=1, - rope_scaling=None, - + hidden_act="silu", + initializer_range=0.02, + use_cache=True, + num_logits_to_keep=1, + rope_scaling=None, + # Model metadata torch_dtype=torch_dtype, architectures=["FalconH1ForCausalLM"], model_type="falcon_h1", transformers_version="4.52.0", ) - + if args.hf_tokenizer: falcon_h1_config.eos_token_id = tokenizer.eos_token_id falcon_h1_config.bos_token_id = tokenizer.bos_token_id - + print(f"saving config.json to {tmp_save_dir}") falcon_h1_config.save_pretrained(tmp_save_dir) @@ -325,53 +327,53 @@ def pad_weight(orig_word_embed, true_vocab_size): for i_layer in range(falcon_h1_config.num_hidden_layers): state_dict = {} - + if is_hybrid_layer(i_layer): # Process hybrid layer (Mamba + Attention) - EVEN layers message = queue_get(f"hybrid layer {i_layer}") - + # Add Mamba + Attention components from Megatron hybrid_weights = process_hybrid_layer_weights(message, i_layer, falcon_h1_config) state_dict.update(hybrid_weights) - + # Add MISSING MLP components (configured to output zeros = identity for addition) - mlp_intermediate_size = falcon_h1_config.intermediate_size + mlp_intermediate_size = falcon_h1_config.intermediate_size state_dict.update({ # Gate and up can be anything since down_proj will zero everything out f"model.layers.{i_layer}.feed_forward.gate_proj.weight": torch.randn( - mlp_intermediate_size, falcon_h1_config.hidden_size, + mlp_intermediate_size, falcon_h1_config.hidden_size, dtype=torch_dtype ) * 0.01, f"model.layers.{i_layer}.feed_forward.up_proj.weight": torch.randn( - mlp_intermediate_size, 
falcon_h1_config.hidden_size, + mlp_intermediate_size, falcon_h1_config.hidden_size, dtype=torch_dtype ) * 0.01, # KEY: down_proj = 0 makes entire MLP output zero f"model.layers.{i_layer}.feed_forward.down_proj.weight": torch.zeros( - falcon_h1_config.hidden_size, mlp_intermediate_size, + falcon_h1_config.hidden_size, mlp_intermediate_size, dtype=torch_dtype ), f"model.layers.{i_layer}.pre_ff_layernorm.weight": torch.ones( falcon_h1_config.hidden_size, dtype=torch_dtype ), }) - + else: # Process MLP-only layer - ODD layers message = queue_get(f"mlp layer {i_layer}") - + # Add MLP components from Megatron mlp_weights = process_mlp_layer_weights(message, i_layer, falcon_h1_config) state_dict.update(mlp_weights) - + # Add MISSING Mamba components (configured to output zeros = identity for addition) mamba_intermediate_size = ( - falcon_h1_config.mamba_d_ssm if falcon_h1_config.mamba_d_ssm + falcon_h1_config.mamba_d_ssm if falcon_h1_config.mamba_d_ssm else int(falcon_h1_config.mamba_expand * falcon_h1_config.hidden_size) ) conv_dim = mamba_intermediate_size + 2 * falcon_h1_config.mamba_n_groups * falcon_h1_config.mamba_d_state projection_size = mamba_intermediate_size + conv_dim + falcon_h1_config.mamba_n_heads - + state_dict.update({ f"model.layers.{i_layer}.mamba.A_log": torch.log(torch.arange(1, falcon_h1_config.mamba_n_heads + 1, dtype=torch_dtype)), f"model.layers.{i_layer}.mamba.D": torch.ones(falcon_h1_config.mamba_n_heads, dtype=torch_dtype), @@ -389,7 +391,7 @@ def pad_weight(orig_word_embed, true_vocab_size): ), f"model.layers.{i_layer}.mamba.norm.weight": torch.ones(mamba_intermediate_size, dtype=torch_dtype), }) - + # Add MISSING Attention components (configured to output zeros = identity for addition) head_dim = falcon_h1_config.hidden_size // falcon_h1_config.num_attention_heads state_dict.update({ @@ -414,7 +416,7 @@ def pad_weight(orig_word_embed, true_vocab_size): f"model.layers.{i_layer}.input_layernorm.weight": torch.ones( 
falcon_h1_config.hidden_size, dtype=torch_dtype ), - }) + }) index_dict, ref_state_dict = save_layer( state_dict, index_dict, @@ -431,7 +433,7 @@ def pad_weight(orig_word_embed, true_vocab_size): } if md.checkpoint_args.untie_embeddings_and_output_weights: state_dict["lm_head.weight"] = pad_weight(queue_get("output layer")["weight"], md.true_vocab_size) - + index_dict, ref_state_dict = save_layer( state_dict, index_dict, @@ -440,7 +442,7 @@ def pad_weight(orig_word_embed, true_vocab_size): check_reference=args.check_eq_hf, ref_state_dict=ref_state_dict, ) - + # final check if ref_state_dict: remaining_keys = list(ref_state_dict.keys()) @@ -468,7 +470,7 @@ def pad_weight(orig_word_embed, true_vocab_size): model = FalconH1ForCausalLM.from_pretrained( str(tmp_save_dir), torch_dtype=torch_dtype, low_cpu_mem_usage=True, trust_remote_code=True ) - + # Avoid saving this as part of the config. if hasattr(model.config, '_name_or_path'): del model.config._name_or_path @@ -494,7 +496,7 @@ def pad_weight(orig_word_embed, true_vocab_size): ) print(f"Saving generation config to {args.save_dir}") generation_config.save_pretrained(args.save_dir) - + ### cleanup tmp print(f"Deleting {tmp_save_dir}") rmtree(tmp_save_dir) diff --git a/tools/checkpoint/schema_core.py b/tools/checkpoint/schema_core.py index 529bef2525c..3449f600f97 100644 --- a/tools/checkpoint/schema_core.py +++ b/tools/checkpoint/schema_core.py @@ -11,6 +11,7 @@ def get_core_transformer_block_key(model_key): return { "GPT" : "decoder", "BERT" : "encoder", + "Mamba": "decoder" }[model_key] @@ -70,6 +71,17 @@ def __init__(self, model_type, prefix, extra_layer_schema): "mlp_fc2_weight" : "mlp.linear_fc2.weight", "mlp_fc2_bias" : "mlp.linear_fc2.bias", + # MambaMixer. 
+ "mamba_mixer_norm_weight" : "mamba_mixer.in_proj.layer_norm_weight", + "mamba_mixer_in_proj_weight" : "mamba_mixer.in_proj.weight", + "mamba_mixer_conv1d_weight" : "mamba_mixer.conv1d.weight", + "mamba_mixer_conv1d_bias" : "mamba_mixer.conv1d.bias", + "mamba_mixer_out_proj_weight" : "mamba_mixer.out_proj.weight", + "mamba_mixer_dt_bias" : "mamba_mixer.dt_bias", + "mamba_mixer_A_log" : "mamba_mixer.A_log", + "mamba_mixer_D" : "mamba_mixer.D", + "mamba_mixer_internal_norm_weight" : "mamba_mixer.norm.weight", # RMSNorm inside mixer + } | extra_layer_schema, prefix=prefix) @@ -95,6 +107,17 @@ def __init__(self, model_type, prefix, extra_layer_schema): "mlp_fc2_weight" : "mlp.linear_fc2.weight", "mlp_fc2_bias" : "mlp.linear_fc2.bias", + # MambaMixer. + "mamba_mixer_norm_weight" : "mamba_mixer.in_proj.layer_norm_weight", + "mamba_mixer_in_proj_weight" : "mamba_mixer.in_proj.weight", + "mamba_mixer_conv1d_weight" : "mamba_mixer.conv1d.weight", + "mamba_mixer_conv1d_bias" : "mamba_mixer.conv1d.bias", + "mamba_mixer_out_proj_weight" : "mamba_mixer.out_proj.weight", + "mamba_mixer_dt_bias" : "mamba_mixer.dt_bias", + "mamba_mixer_A_log" : "mamba_mixer.A_log", + "mamba_mixer_D" : "mamba_mixer.D", + "mamba_mixer_internal_norm_weight" : "mamba_mixer.norm.weight", + } | extra_layer_schema, prefix=prefix)