diff --git a/.gitignore b/.gitignore index 22ec582c..a380c1a8 100644 --- a/.gitignore +++ b/.gitignore @@ -35,6 +35,7 @@ wheels/ /package /temp MANIFEST +.locks/ # PyInstaller # Usually these files are written by a python script from a template @@ -93,7 +94,6 @@ celerybeat-schedule *.sage.py # Environments -.locks .env .venv env/ diff --git a/cookbook/megatron/__init__.py b/cookbook/megatron/__init__.py new file mode 100644 index 00000000..49006762 --- /dev/null +++ b/cookbook/megatron/__init__.py @@ -0,0 +1 @@ +# Copyright (c) twinkle authors. All rights reserved. diff --git a/cookbook/megatron/lora.py b/cookbook/megatron/lora.py new file mode 100644 index 00000000..74a39156 --- /dev/null +++ b/cookbook/megatron/lora.py @@ -0,0 +1,195 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron-Core LoRA training example. + +Supports both local (torchrun) and Ray execution modes. + +Usage (Local mode): + torchrun --nproc_per_node=4 cookbook/megatron/lora.py --tp_size 2 --pp_size 2 + +Usage (Ray mode): + TRUST_REMOTE_CODE=1 python cookbook/megatron/lora.py --mode ray --tp_size 2 --pp_size 2 --num_gpus 4 +""" +import argparse +import os + +import numpy as np +# CRITICAL: Set CUDA device before any CUDA imports (local mode only) +import torch +from peft import LoraConfig +from torch.optim import AdamW +from torch.optim.lr_scheduler import LinearLR + +import twinkle +from twinkle import (DeviceGroup, DeviceMesh, Platform, get_device_placement, + get_logger) +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import MegatronCrossEntropyLoss +from twinkle.model import MegatronModel +from twinkle.processor import InputProcessor + +# Parse arguments first to determine mode +parser = argparse.ArgumentParser() +parser.add_argument('--mode', + type=str, + default='local', + choices=['local', 'ray']) +parser.add_argument('--tp_size', type=int, default=1) +parser.add_argument('--pp_size', type=int, 
default=1) +parser.add_argument('--cp_size', type=int, default=1) +parser.add_argument('--num_gpus', + type=int, + default=4, + help='Number of GPUs (Ray mode only)') +parser.add_argument('--max_steps', type=int, default=None) +parser.add_argument('--model', + type=str, + default='ms://Qwen/Qwen2.5-7B-Instruct') +GAS = 16 # gradient accumulation steps +args = parser.parse_args() + +# Set mode in environment before importing twinkle +os.environ['TWINKLE_MODE'] = args.mode + +if args.mode == 'local': + LOCAL_RANK = int(os.environ.get('LOCAL_RANK', '0')) + torch.cuda.set_device(LOCAL_RANK) + +logger = get_logger() + + +def create_dataset(): + dataset = Dataset( + dataset_meta=DatasetMeta('ms://modelscope/competition_math')) + dataset.set_template('Qwen3Template', + model_id='ms://Qwen/Qwen2.5-7B-Instruct') + dataset.map('CompetitionMathProcessor') + dataset.encode(batched=True, load_from_cache_file=False) + return dataset + + +def train(): + # Get parallelism config + TP_SIZE = args.tp_size + PP_SIZE = args.pp_size + CP_SIZE = args.cp_size + + if args.mode == 'local': + WORLD_SIZE = int(os.environ.get('WORLD_SIZE', '1')) + else: + WORLD_SIZE = args.num_gpus + + DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE) + + # Device mesh: Match Megatron's order "tp-cp-ep-dp-pp" from innermost to outermost + device_mesh = DeviceMesh( + device_type='cuda', + mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, CP_SIZE, TP_SIZE), + mesh_dim_names=('pp', 'dp', 'cp', 'tp'), + ) + + # Device group name - used as remote_group in Ray mode + GROUP_NAME = 'model' + + device_group = [ + DeviceGroup( + name=GROUP_NAME, + ranks=list(range(WORLD_SIZE)), + device_type=Platform.get_platform().device_prefix(), + ) + ] + + twinkle.initialize( + mode=args.mode, + nproc_per_node=WORLD_SIZE, + groups=device_group, + global_device_mesh=device_mesh, + lazy_collect=False, + ) + + # Use smaller batch size for single GPU to avoid OOM + batch_size = 2 if WORLD_SIZE == 1 else 8 + + # In Ray mode, pass 
remote_group and device_mesh + if args.mode == 'ray': + dataloader = DataLoader( + dataset=create_dataset, + batch_size=batch_size, + remote_group=GROUP_NAME, + device_mesh=device_mesh, + ) + model = MegatronModel( + pretrained_model_name_or_path=args.model, + tensor_model_parallel_size=TP_SIZE, + pipeline_model_parallel_size=PP_SIZE, + context_parallel_size=CP_SIZE, + mixed_precision='bf16', + recompute_granularity='full' if WORLD_SIZE <= 2 else 'selective', + remote_group=GROUP_NAME, + device_mesh=device_mesh, + ) + else: + dataloader = DataLoader(dataset=create_dataset, batch_size=batch_size) + model = MegatronModel( + pretrained_model_name_or_path=args.model, + tensor_model_parallel_size=TP_SIZE, + pipeline_model_parallel_size=PP_SIZE, + context_parallel_size=CP_SIZE, + mixed_precision='bf16', + recompute_granularity='full' if WORLD_SIZE <= 2 else 'selective', + ) + + lora_config = LoraConfig(target_modules='all-linear') + adapter_name = 'lora' + model.add_adapter_to_model(adapter_name, + lora_config, + gradient_accumulation_steps=GAS) + model.set_template('Qwen3Template', adapter_name=adapter_name) + model.set_processor(InputProcessor, + padding_side='right', + adapter_name=adapter_name) + model.set_loss(MegatronCrossEntropyLoss, adapter_name=adapter_name) + model.set_optimizer(AdamW, lr=1e-4, adapter_name=adapter_name) + model.set_lr_scheduler(LinearLR, adapter_name=adapter_name) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs(adapter_name=adapter_name)) + + for step, batch in enumerate(dataloader): + output = model.forward_backward(inputs=batch, + adapter_name=adapter_name) + if step % GAS == 0: + logger.info(f'Step {step // 16}, loss: {output}') + model.clip_grad_norm(1.0, adapter_name=adapter_name) + model.step(adapter_name=adapter_name) + model.zero_grad(adapter_name=adapter_name) + model.lr_step(adapter_name=adapter_name) + if step > 0 and step % (100 * GAS) == 0: + model.save('./output/megatron_lora', 
adapter_name=adapter_name) + # Early stop for testing + if args.max_steps and step >= args.max_steps * GAS: + logger.info(f'Reached max_steps ({args.max_steps}), stopping.') + break + model.save('./output/megatron_lora', adapter_name=adapter_name) + logger.info('Training completed!') + +def cleanup(): + """Clean up distributed resources.""" + import torch.distributed as dist + try: + if dist.is_initialized(): + dist.barrier() + from megatron.core import parallel_state as mpu + if mpu.is_initialized(): + mpu.destroy_model_parallel() + except Exception as e: + logger.warning(f"Error during cleanup: {e}") + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == '__main__': + try: + train() + finally: + cleanup() diff --git a/cookbook/megatron/moe_lora.py b/cookbook/megatron/moe_lora.py new file mode 100644 index 00000000..1cb72a7e --- /dev/null +++ b/cookbook/megatron/moe_lora.py @@ -0,0 +1,250 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron-Core MoE (Mixture of Experts) LoRA training example. + +Supports Expert Parallel (EP) training in both local (torchrun) and Ray modes. 
+ +Usage (Local mode with EP=2): + torchrun --nproc_per_node=4 cookbook/megatron/moe_lora.py --tp_size 2 --pp_size 1 --ep_size 2 + +Usage (Ray mode with EP=2): + TRUST_REMOTE_CODE=1 python cookbook/megatron/moe_lora.py --mode ray --tp_size 2 --pp_size 1 --ep_size 2 --num_gpus 4 +""" +import argparse +import os + +import numpy as np +# CRITICAL: Set CUDA device before any CUDA imports (local mode only) +import torch +from peft import LoraConfig +from torch.optim import AdamW +from torch.optim.lr_scheduler import LinearLR + +import twinkle +from twinkle import (DeviceGroup, DeviceMesh, Platform, get_device_placement, + get_logger) +from twinkle.dataloader import DataLoader +from twinkle.dataset import Dataset, DatasetMeta +from twinkle.loss import MegatronCrossEntropyLoss +from twinkle.model import MegatronModel +from twinkle.processor import InputProcessor +GAS = 16 # gradient accumulation steps +# Parse arguments first to determine mode +parser = argparse.ArgumentParser() +parser.add_argument('--mode', + type=str, + default='local', + choices=['local', 'ray']) +parser.add_argument('--tp_size', type=int, default=2) +parser.add_argument('--pp_size', type=int, default=1) +parser.add_argument('--cp_size', type=int, default=1) +parser.add_argument('--ep_size', + type=int, + default=2, + help='Expert parallel size') +parser.add_argument('--num_gpus', + type=int, + default=4, + help='Number of GPUs (Ray mode only)') +parser.add_argument('--max_steps', type=int, default=5) +parser.add_argument( + '--model', + type=str, + default='ms://Qwen/Qwen3-30B-A3B', + help='MoE model path. 
Default: Qwen3-30B-A3B (128 experts)') +parser.add_argument( + '--sequence_parallel', + action='store_true', + default=False, + help='Enable sequence parallel (auto-enabled for MoE with TP > 1)') +args = parser.parse_args() + +# Set mode in environment before importing twinkle +os.environ['TWINKLE_MODE'] = args.mode + +if args.mode == 'local': + LOCAL_RANK = int(os.environ.get('LOCAL_RANK', '0')) + torch.cuda.set_device(LOCAL_RANK) + +logger = get_logger() + + +def create_dataset(): + """Create dataset for MoE training.""" + dataset = Dataset( + dataset_meta=DatasetMeta('ms://modelscope/competition_math')) + # Use Qwen3 template for MoE model + dataset.set_template('Qwen3Template', model_id=args.model) + dataset.map('CompetitionMathProcessor') + dataset.encode(batched=True, load_from_cache_file=False) + return dataset + + +def train(): + # Get parallelism config + TP_SIZE = args.tp_size + PP_SIZE = args.pp_size + CP_SIZE = args.cp_size + EP_SIZE = args.ep_size + + if args.mode == 'local': + WORLD_SIZE = int(os.environ.get('WORLD_SIZE', '1')) + else: + WORLD_SIZE = args.num_gpus + + # DP calculation follows Megatron's logic: DP = world_size / (TP * PP * CP) + # EP is NOT included in DP calculation - it's handled separately by Megatron + # for MoE expert layers. Expert data parallel size is computed internally by Megatron. + DP_SIZE = WORLD_SIZE // (TP_SIZE * PP_SIZE * CP_SIZE) + + # Validate that world size supports the parallelism config + # For MoE, EP must divide the data parallel replicas correctly + if DP_SIZE < 1: + raise ValueError( + f'Not enough GPUs ({WORLD_SIZE}) for parallelism config: ' + f'TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}. 
' + f'Required at least: {TP_SIZE * PP_SIZE * CP_SIZE}') + + # EP should divide into world_size / (TP * PP) for proper expert parallelism + # This ensures expert_data_parallel_size = world_size / (ETP * EP * PP) is valid + expert_data_parallel_size = WORLD_SIZE // (TP_SIZE * EP_SIZE * PP_SIZE) + if expert_data_parallel_size < 1: + raise ValueError( + f'Not enough GPUs ({WORLD_SIZE}) for expert parallelism: ' + f'TP={TP_SIZE}, PP={PP_SIZE}, EP={EP_SIZE}. ' + f'Required at least: {TP_SIZE * EP_SIZE * PP_SIZE}') + + logger.info( + f'Parallelism config: TP={TP_SIZE}, PP={PP_SIZE}, CP={CP_SIZE}, EP={EP_SIZE}, DP={DP_SIZE}' + ) + logger.info( + f'Expert data parallel size: {expert_data_parallel_size}' + ) + + # Device mesh: Match Megatron's order "tp-cp-ep-dp-pp" from innermost to outermost + # Note: EP is not a separate dimension in the device mesh because: + # 1. Megatron handles EP internally in initialize_model_parallel() + # 2. For non-expert layers, DP = world_size / (TP * PP * CP) + # 3. 
For expert layers, expert_data_parallel_size = world_size / (ETP * EP * PP) + # The device mesh is used by twinkle for data sharding, which follows DP_SIZE + device_mesh = DeviceMesh( + device_type='cuda', + mesh=np.arange(WORLD_SIZE).reshape(PP_SIZE, DP_SIZE, CP_SIZE, TP_SIZE), + mesh_dim_names=('pp', 'dp', 'cp', 'tp'), + ) + + # Device group name - used as remote_group in Ray mode + GROUP_NAME = 'model' + + device_group = [ + DeviceGroup( + name=GROUP_NAME, + ranks=list(range(WORLD_SIZE)), + device_type=Platform.get_platform().device_prefix(), + ) + ] + + twinkle.initialize( + mode=args.mode, + nproc_per_node=WORLD_SIZE, + groups=device_group, + global_device_mesh=device_mesh, + lazy_collect=False, + ) + + # Smaller batch size for MoE models (larger memory footprint) + batch_size = 2 + + # In Ray mode, pass remote_group and device_mesh + if args.mode == 'ray': + dataloader = DataLoader( + dataset=create_dataset, + batch_size=batch_size, + remote_group=GROUP_NAME, + device_mesh=device_mesh, + ) + model = MegatronModel( + pretrained_model_name_or_path=args.model, + tensor_model_parallel_size=TP_SIZE, + pipeline_model_parallel_size=PP_SIZE, + context_parallel_size=CP_SIZE, + expert_model_parallel_size=EP_SIZE, + sequence_parallel=args.sequence_parallel, + mixed_precision='bf16', + recompute_granularity='selective', + remote_group=GROUP_NAME, + device_mesh=device_mesh, + ) + else: + dataloader = DataLoader(dataset=create_dataset, batch_size=batch_size) + model = MegatronModel( + pretrained_model_name_or_path=args.model, + tensor_model_parallel_size=TP_SIZE, + pipeline_model_parallel_size=PP_SIZE, + context_parallel_size=CP_SIZE, + expert_model_parallel_size=EP_SIZE, + sequence_parallel=args.sequence_parallel, + mixed_precision='bf16', + recompute_granularity='selective', + ) + + # LoRA config - target all linear layers in MoE (including experts) + lora_config = LoraConfig( + target_modules='all-linear', + r=8, + lora_alpha=8, + lora_dropout=0.0, + ) + adapter_name = 
'lora' + model.add_adapter_to_model(adapter_name, + lora_config, + gradient_accumulation_steps=GAS) + model.set_template('Qwen3Template', adapter_name=adapter_name) + model.set_processor(InputProcessor, + padding_side='right', + adapter_name=adapter_name) + model.set_loss(MegatronCrossEntropyLoss, adapter_name=adapter_name) + model.set_optimizer(AdamW, lr=1e-4, adapter_name=adapter_name) + model.set_lr_scheduler(LinearLR, adapter_name=adapter_name) + + logger.info(get_device_placement()) + logger.info(model.get_train_configs(adapter_name=adapter_name)) + + for step, batch in enumerate(dataloader): + output = model.forward_backward(inputs=batch, + adapter_name=adapter_name) + if step % GAS == 0: + logger.info(f'Step {step // GAS}, loss: {output}') + model.clip_grad_norm(1.0, adapter_name=adapter_name) + model.step(adapter_name=adapter_name) + model.zero_grad(adapter_name=adapter_name) + model.lr_step(adapter_name=adapter_name) + if step > 0 and step % (100 * GAS) == 0: + model.save('./output/megatron_moe_lora', adapter_name=adapter_name) + # Early stop for testing + if args.max_steps and step >= args.max_steps * GAS: + logger.info(f'Reached max_steps ({args.max_steps}), stopping.') + break + model.save('./output/megatron_moe_lora', adapter_name=adapter_name) + logger.info('Training completed!') + + +def cleanup(): + """Clean up distributed resources.""" + import torch.distributed as dist + try: + if dist.is_initialized(): + dist.barrier() + from megatron.core import parallel_state as mpu + if mpu.is_initialized(): + mpu.destroy_model_parallel() + except Exception as e: + logger.warning(f"Error during cleanup: {e}") + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == '__main__': + try: + train() + finally: + cleanup() diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 18bbfa12..d157569e 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -244,7 +244,7 @@ def 
render_mesh_grid(mesh_array, dim_names): lines.extend(section_bottom()) lines.append("") - return "\n".join(lines) + return "\n" + "\n".join(lines) def _get_workers(workers, execute): @@ -484,7 +484,8 @@ def __next__(_self): def remote_function(dispatch: Union[Literal['slice', 'all'], Callable] = 'slice', execute: Literal['first', 'peer', 'all'] = 'all', - collect: Union[Literal['none', 'flatten', 'mean', 'sum', 'first'], Callable] = 'none'): + collect: Union[Literal['none', 'flatten', 'mean', 'sum', 'first'], Callable] = 'none', + sync: bool = False): """Patch each method called from remote(which class should be decorated with `remote_class`) with this decorator. Args: @@ -503,6 +504,8 @@ def remote_function(dispatch: Union[Literal['slice', 'all'], Callable] = 'slice' 'sum': Return the sum value of all processes 'first': Return the first worker's result but executed in each process, usually works for scenarios of all-gather. Callable: A callable that handles the collection + sync: If True, use synchronous execution (execute_all_sync) instead of async. + Required for methods with NCCL collective operations (e.g., Megatron forward_backward). 
""" def decorator(func: Callable[..., T1]) -> Callable[..., T1]: @@ -522,7 +525,8 @@ def wrapper(self, *args, **kwargs) -> T1: from ._ray import RayHelper _workers_and_args = _dispatch_args(_get_workers(self._actors, execute), dispatch, execute, device_mesh, args, kwargs) - result = RayHelper.execute_all_async(func.__name__, _workers_and_args) + execute_method = RayHelper.execute_all_async if not sync else RayHelper.execute_all_sync + result = execute_method(func.__name__, _workers_and_args) result_func = RayHelper.do_get_and_collect_func(_collect_func, collect, result) lazy_collect = _lazy_collect if func.__name__ == '__iter__': @@ -549,6 +553,7 @@ def wrapper(self, *args, **kwargs) -> T1: wrapper._collect = collect wrapper._dispatch = dispatch wrapper._lazy_collect = _lazy_collect + wrapper._sync = sync return wrapper return decorator diff --git a/src/twinkle/loss/__init__.py b/src/twinkle/loss/__init__.py index 8afecfae..c501225e 100644 --- a/src/twinkle/loss/__init__.py +++ b/src/twinkle/loss/__init__.py @@ -11,6 +11,7 @@ from .listwise_reranker import ListwiseRerankerLoss from .listwise_generative_reranker import ListwiseGenerativeRerankerLoss from .grpo import GRPOLoss +from .vocab_parallel_cross_entropy import MegatronCrossEntropyLoss from .base import Loss torch_loss_mapping = { @@ -26,4 +27,5 @@ 'listwise_reranker': ListwiseRerankerLoss, 'listwise_generative_reranker': ListwiseGenerativeRerankerLoss, 'grpo': GRPOLoss, -} \ No newline at end of file + 'megatron_cross_entropy': MegatronCrossEntropyLoss, +} diff --git a/src/twinkle/loss/vocab_parallel_cross_entropy.py b/src/twinkle/loss/vocab_parallel_cross_entropy.py new file mode 100644 index 00000000..1c035c4f --- /dev/null +++ b/src/twinkle/loss/vocab_parallel_cross_entropy.py @@ -0,0 +1,46 @@ +# Copyright (c) twinkle authors. All rights reserved. 
+"""Vocab-parallel cross entropy loss for Megatron backend with Tensor Parallelism.""" +import torch + +from .base import Loss + + +class VocabParallelCrossEntropyLoss(Loss): + """Vocab-parallel cross entropy loss for Megatron training with TP > 1. + + This loss uses Megatron's tensor_parallel.vocab_parallel_cross_entropy to + correctly compute cross entropy when vocabulary is sharded across TP ranks. + + NOTE: Labels are expected to be pre-shifted by the template (using np.roll). + This loss does NOT perform additional shifting. + + Args: + ignore_index: The label value to ignore when computing loss. Default: -100. + """ + + def __init__(self, ignore_index: int = -100): + super().__init__() + self.ignore_index = ignore_index + + def __call__(self, inputs, outputs, **kwargs): + from megatron.core import tensor_parallel + + logits = outputs['logits'] + labels = inputs['labels'] + + # Transpose: [batch, seq, vocab] -> [seq, batch, vocab] + logits_sbv = logits.transpose(0, 1).contiguous() + labels_sb = labels.transpose(0, 1).contiguous() + + # Compute vocab-parallel cross entropy + per_token_loss = tensor_parallel.vocab_parallel_cross_entropy(logits_sbv, labels_sb) + per_token_loss = per_token_loss.transpose(0, 1).contiguous() + + # Apply loss mask + loss_mask = (labels != self.ignore_index).float() + loss = (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1) + + return loss + + +MegatronCrossEntropyLoss = VocabParallelCrossEntropyLoss diff --git a/src/twinkle/megatron/__init__.py b/src/twinkle/megatron/__init__.py new file mode 100644 index 00000000..c8b0c3d9 --- /dev/null +++ b/src/twinkle/megatron/__init__.py @@ -0,0 +1,72 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron-Core integration for twinkle training framework. 
+ +This module provides independent implementation for Megatron support, +""" + +from .model import (BridgeConfig, LazyTensor, MegatronModelInitializer, + Qwen3ModelMeta, SafetensorLoader, StreamingSafetensorSaver, + TwinkleBridgeAdapter, TwinkleGPTBridge, + create_megatron_args, get_model_default_config, + initialize_megatron_model, is_last_rank, + load_hf_weights_to_megatron, mock_megatron_args, + restore_megatron_args, set_megatron_args) +from .tuners import LoraParallelLinear, dispatch_megatron + +from .model import deep_getattr as bridge_deep_getattr # Bridge classes; Helper functions; Avoid conflict with utils.deep_getattr; Legacy compatibility; Initializer; Qwen3 support +from .utils import ( # Layer finding; Model preparation; Config conversion; Utilities; Multi-tenant support; Training state + MegatronTrainerState, TenantProcessGroupManager, convert_hf_config, + deep_getattr, find_all_linears, find_embedding, find_router, + forward_step_helper, get_model_parameter_info, get_padding_to, + get_target_modules, get_tenant_manager, patch_deepcopy, prepare_lora_model, + prepare_mcore_model, set_linear_is_expert, tuners_sharded_state_dict) + +__all__ = [ + # Tuners + 'LoraParallelLinear', + 'dispatch_megatron', + # Layer finding + 'find_all_linears', + 'find_router', + 'find_embedding', + 'get_target_modules', + 'set_linear_is_expert', + # Model preparation + 'prepare_mcore_model', + 'prepare_lora_model', + # Config conversion + 'convert_hf_config', + # Utilities + 'get_model_parameter_info', + 'get_padding_to', + 'patch_deepcopy', + 'tuners_sharded_state_dict', + 'forward_step_helper', + 'deep_getattr', + # Multi-tenant support + 'TenantProcessGroupManager', + 'get_tenant_manager', + # Training state + 'MegatronTrainerState', + # Bridge classes + 'TwinkleBridgeAdapter', + 'TwinkleGPTBridge', + 'BridgeConfig', + 'SafetensorLoader', + 'StreamingSafetensorSaver', + 'LazyTensor', + # Helper functions + 'load_hf_weights_to_megatron', + 'is_last_rank', + # Legacy 
compatibility + 'create_megatron_args', + 'set_megatron_args', + 'restore_megatron_args', + 'mock_megatron_args', + # Initializer + 'MegatronModelInitializer', + 'initialize_megatron_model', + # Qwen3 support + 'Qwen3ModelMeta', + 'get_model_default_config', +] diff --git a/src/twinkle/megatron/model/__init__.py b/src/twinkle/megatron/model/__init__.py new file mode 100644 index 00000000..e3bc5599 --- /dev/null +++ b/src/twinkle/megatron/model/__init__.py @@ -0,0 +1,46 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron model initialization and weight conversion. + +This module provides independent implementation for weight loading/saving, +and multi-tenant model wrapper for LoRA training. +""" + +from .bridge import ( # Main classes; Helper functions; Legacy compatibility + BridgeConfig, LazyTensor, SafetensorLoader, StreamingSafetensorSaver, + TwinkleBridgeAdapter, TwinkleBridgeInitializer, TwinkleGPTBridge, + create_megatron_args, deep_getattr, is_last_rank, + load_hf_weights_to_megatron, mock_megatron_args, restore_megatron_args, + set_megatron_args) +from .initializer import MegatronModelInitializer, initialize_megatron_model +from .multi_tenant_megatron import (MegatronMultiAdapter, + MultiTenantMegatronModel) +from .qwen3 import Qwen3ModelMeta, get_model_default_config + +__all__ = [ + # Bridge classes + 'TwinkleBridgeAdapter', + 'TwinkleBridgeInitializer', + 'TwinkleGPTBridge', + 'BridgeConfig', + 'SafetensorLoader', + 'StreamingSafetensorSaver', + 'LazyTensor', + # Helper functions + 'deep_getattr', + 'is_last_rank', + 'load_hf_weights_to_megatron', + # Legacy compatibility + 'create_megatron_args', + 'set_megatron_args', + 'restore_megatron_args', + 'mock_megatron_args', + # Initializer + 'MegatronModelInitializer', + 'initialize_megatron_model', + # Model metadata + 'Qwen3ModelMeta', + 'get_model_default_config', + # Multi-tenant + 'MultiTenantMegatronModel', + 'MegatronMultiAdapter', +] diff --git 
a/src/twinkle/megatron/model/bridge.py b/src/twinkle/megatron/model/bridge.py new file mode 100644 index 00000000..4aab500e --- /dev/null +++ b/src/twinkle/megatron/model/bridge.py @@ -0,0 +1,2112 @@ +# Copyright (c) twinkle authors. All rights reserved. +# GPT Bridge for HuggingFace to Megatron-Core weight conversion. +import glob +import json +import math +import os +from copy import copy +from dataclasses import dataclass, field +from types import SimpleNamespace +from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +from twinkle.hub import HubOperation + +try: + from megatron.core import parallel_state as mpu + MEGATRON_AVAILABLE = True +except ImportError: + MEGATRON_AVAILABLE = False + mpu = None + +try: + from safetensors import safe_open + from safetensors.torch import save_file + SAFETENSORS_AVAILABLE = True +except ImportError: + SAFETENSORS_AVAILABLE = False + + +def deep_getattr(obj, attr: str, default=None): + """Get nested attribute from object using dot notation.""" + try: + for key in attr.split('.'): + obj = getattr(obj, key) + return obj + except AttributeError: + return default + + +def is_last_rank() -> bool: + """Check if current process is the last rank for writing. + + For DP > 1, we want only DP rank 0 to write to avoid conflicts. + For PP, we want the last PP stage. + For TP, all TP ranks participate in gather, but only one writes. 
+ """ + if not dist.is_initialized(): + return True + + try: + from megatron.core import parallel_state as mpu + if mpu.is_initialized(): + # Only DP rank 0 writes + dp_rank = mpu.get_data_parallel_rank() + if dp_rank != 0: + return False + # For PP, only last stage needs to write certain weights + # (handled separately in export_weights) + return True + except (ImportError, AssertionError): + pass + + return dist.get_rank() == dist.get_world_size() - 1 + + +class LazyTensor: + """Lazy tensor wrapper for deferred loading.""" + def __init__(self, loader, key: str): + self._loader = loader + self._key = key + + def load(self) -> torch.Tensor: + """Load the tensor.""" + return self._loader.get_tensor(self._key) + + +class SafetensorLoader: + """Lazy loader for safetensor files.""" + def __init__(self, model_dir: str, is_peft_format: bool = False): + self.model_dir = model_dir + self.is_peft_format = is_peft_format + self._handles = {} + self._index = None + self._key_to_file = {} + self._load_index() + + def _load_index(self): + """Load safetensor index file if exists.""" + # Try adapter format first for PEFT + if self.is_peft_format: + adapter_file = os.path.join(self.model_dir, + 'adapter_model.safetensors') + if os.path.exists(adapter_file): + handle = safe_open(adapter_file, framework='pt', device='cpu') + for key in handle.keys(): + self._key_to_file[key] = adapter_file + self._handles[adapter_file] = handle + return + + # Standard index file + index_file = os.path.join(self.model_dir, + 'model.safetensors.index.json') + if os.path.exists(index_file): + with open(index_file, 'r') as f: + self._index = json.load(f) + for key, filename in self._index['weight_map'].items(): + self._key_to_file[key] = os.path.join(self.model_dir, filename) + else: + # Single file model + single_file = os.path.join(self.model_dir, 'model.safetensors') + if os.path.exists(single_file): + handle = safe_open(single_file, framework='pt', device='cpu') + for key in handle.keys(): + 
self._key_to_file[key] = single_file + self._handles[single_file] = handle + else: + # Try to find any safetensor file + files = glob.glob(os.path.join(self.model_dir, + '*.safetensors')) + for filepath in files: + handle = safe_open(filepath, framework='pt', device='cpu') + for key in handle.keys(): + self._key_to_file[key] = filepath + self._handles[filepath] = handle + + def _get_handle(self, filepath: str): + """Get or create file handle.""" + if filepath not in self._handles: + self._handles[filepath] = safe_open(filepath, + framework='pt', + device='cpu') + return self._handles[filepath] + + def get_tensor(self, key: str) -> torch.Tensor: + """Load a single tensor.""" + filepath = self._key_to_file.get(key) + if filepath is None: + raise KeyError(f'Tensor key not found: {key}') + handle = self._get_handle(filepath) + return handle.get_tensor(key) + + def get_lazy(self, key: str) -> LazyTensor: + """Get a lazy tensor reference.""" + if key not in self._key_to_file: + raise KeyError(f'Tensor key not found: {key}') + return LazyTensor(self, key) + + def get_state_dict(self) -> Dict[str, LazyTensor]: + """Get lazy state dict.""" + return {key: LazyTensor(self, key) for key in self._key_to_file} + + def keys(self) -> List[str]: + """Get all tensor keys.""" + return list(self._key_to_file.keys()) + + def __contains__(self, key: str) -> bool: + return key in self._key_to_file + + def close(self): + """Close all file handles.""" + self._handles.clear() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class StreamingSafetensorSaver: + """Streaming saver for safetensor files.""" + def __init__(self, + save_dir: str, + max_shard_size: str = '5GB', + is_peft_format: bool = False): + self.save_dir = save_dir + self.is_peft_format = is_peft_format + os.makedirs(save_dir, exist_ok=True) + + # Parse max shard size + size_str = max_shard_size.upper() + if size_str.endswith('GB'): + self.max_shard_bytes = int(float(size_str[:-2]) * 
1024**3) + elif size_str.endswith('MB'): + self.max_shard_bytes = int(float(size_str[:-2]) * 1024**2) + else: + self.max_shard_bytes = int(size_str) + + self.current_shard = {} + self.current_shard_size = 0 + self.shard_idx = 1 + self.weight_map = {} + + def add_tensor(self, key: str, tensor: torch.Tensor): + """Add tensor to the current shard.""" + if tensor is None: + return + + tensor_size = tensor.numel() * tensor.element_size() + + # Flush if needed + if self.current_shard_size + tensor_size > self.max_shard_bytes and self.current_shard: + self._flush_shard() + + self.current_shard[key] = tensor.contiguous() + self.current_shard_size += tensor_size + + def _flush_shard(self): + """Flush current shard to disk.""" + if not self.current_shard: + return + + if self.is_peft_format: + filename = 'adapter_model.safetensors' + else: + filename = f'model-{self.shard_idx:05d}-of-XXXXX.safetensors' + + filepath = os.path.join(self.save_dir, filename) + save_file(self.current_shard, filepath) + + for key in self.current_shard: + self.weight_map[key] = filename + + self.current_shard = {} + self.current_shard_size = 0 + self.shard_idx += 1 + + def finalize(self): + """Finalize and write index.""" + self._flush_shard() + + if self.is_peft_format: + return # PEFT format doesn't need index + + # Fix shard filenames + total_shards = self.shard_idx - 1 + if total_shards == 0: + return + + for old_name in list(self.weight_map.values()): + new_name = old_name.replace('XXXXX', f'{total_shards:05d}') + if old_name != new_name: + old_path = os.path.join(self.save_dir, old_name) + new_path = os.path.join(self.save_dir, new_name) + if os.path.exists(old_path): + os.rename(old_path, new_path) + for key in self.weight_map: + if self.weight_map[key] == old_name: + self.weight_map[key] = new_name + + if total_shards > 1: + index = { + 'metadata': { + 'total_size': + sum(t.numel() * t.element_size() + for t in self.current_shard.values()) + }, + 'weight_map': self.weight_map + } + with 
@dataclass
class BridgeConfig:
    """Configuration for GPTBridge.

    Captures both the parallel layout (TP/PP/EP/ETP sizes) and the model
    architecture facts needed to map weights between HuggingFace and
    Megatron-Core formats.
    """
    # Parallelism
    tp_size: int = 1
    pp_size: int = 1
    ep_size: int = 1
    etp_size: int = 1

    # Model architecture
    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_key_value_heads: int = 32
    num_layers: int = 32
    vocab_size: int = 32000
    padded_vocab_size: int = 32000
    intermediate_size: int = 11008
    # head_dim; if None it is computed as hidden_size // num_attention_heads.
    # Bug fix: annotated Optional[int] (was ``int = None``).
    kv_channels: Optional[int] = None

    # Options
    add_qkv_bias: bool = False
    add_bias_linear: bool = False
    qk_layernorm: bool = False
    tie_word_embeddings: bool = False

    # MoE
    num_experts: int = 0
    num_experts_per_tok: int = 2
    shared_expert_intermediate_size: int = 0

    model_type: str = 'qwen2'
    max_shard_size: str = '5GB'

    @classmethod
    def from_hf_config(
        cls,
        hf_config: Any,
        tp_size: int = 1,
        pp_size: int = 1,
        ep_size: int = 1,
        padded_vocab_size: Optional[int] = None,
    ) -> 'BridgeConfig':
        """Create BridgeConfig from a HuggingFace config.

        Args:
            hf_config: HuggingFace model config object (attributes are read
                with permissive ``getattr`` defaults).
            tp_size / pp_size / ep_size: parallel sizes; etp_size follows
                tp_size.
            padded_vocab_size: explicit padded vocab; when None it is derived
                from ``vocab_size`` rounded up to a multiple of 64.
        """
        vocab_size = getattr(hf_config, 'vocab_size', 32000)
        if padded_vocab_size is None:
            padded_vocab_size = vocab_size
            # Pad to multiple of 64 for efficiency
            if padded_vocab_size % 64 != 0:
                padded_vocab_size = ((padded_vocab_size // 64) + 1) * 64

        num_attention_heads = getattr(hf_config, 'num_attention_heads', 32)
        num_key_value_heads = getattr(hf_config, 'num_key_value_heads',
                                      num_attention_heads)

        # MoE config — different model families use different attribute names.
        num_experts = getattr(hf_config, 'num_experts', 0) or \
            getattr(hf_config, 'n_routed_experts', 0) or \
            getattr(hf_config, 'num_local_experts', 0)
        num_experts_per_tok = getattr(hf_config, 'num_experts_per_tok', 2) or \
            getattr(hf_config, 'moe_topk', 2)
        shared_expert_size = getattr(hf_config,
                                     'shared_expert_intermediate_size', 0)

        # Determine QKV bias setting.
        # Qwen2 has attention bias by default (hardcoded in transformers),
        # but the config doesn't carry an 'attention_bias' field.
        model_type = getattr(hf_config, 'model_type', 'qwen2')
        if hasattr(hf_config, 'attention_bias'):
            add_qkv_bias = hf_config.attention_bias
        elif model_type in ('qwen2', 'qwen2_5'):
            # Qwen2/Qwen2.5 uses bias=True for Q, K, V projections
            add_qkv_bias = True
        else:
            add_qkv_bias = False

        # Determine QK layernorm setting.
        # Qwen3 uses QK layernorm but has no explicit config attribute.
        qk_layernorm = getattr(hf_config, 'qk_layernorm', False) or \
            getattr(hf_config, 'use_qk_norm', False)
        if not qk_layernorm and model_type in ('qwen3', 'qwen3_moe'):
            # Qwen3 (dense and MoE) always has q_norm/k_norm weights.
            qk_layernorm = True

        # kv_channels (head_dim) — Qwen3 declares it explicitly.
        kv_channels = getattr(hf_config, 'head_dim', None)

        return cls(
            tp_size=tp_size,
            pp_size=pp_size,
            ep_size=ep_size,
            etp_size=tp_size,
            hidden_size=getattr(hf_config, 'hidden_size', 4096),
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            num_layers=getattr(hf_config, 'num_hidden_layers', 32),
            vocab_size=vocab_size,
            padded_vocab_size=padded_vocab_size,
            intermediate_size=getattr(hf_config, 'intermediate_size', 11008),
            add_qkv_bias=add_qkv_bias,
            add_bias_linear=getattr(hf_config, 'mlp_bias', False),
            qk_layernorm=qk_layernorm,
            tie_word_embeddings=getattr(hf_config, 'tie_word_embeddings',
                                        False),
            num_experts=num_experts,
            num_experts_per_tok=num_experts_per_tok,
            shared_expert_intermediate_size=shared_expert_size,
            model_type=model_type,
            kv_channels=kv_channels,  # Explicit head_dim for Qwen3
        )
+ """ + + # HuggingFace model structure constants (Qwen2/Qwen3 compatible) + HF_LAYERS_PREFIX = 'model.layers' + HF_EMBED_KEY = 'model.embed_tokens.weight' + HF_FINAL_LAYERNORM_KEY = 'model.norm.weight' + HF_LM_HEAD_KEY = 'lm_head.weight' + + def __init__(self, + config: BridgeConfig, + hf_config: Any = None, + disable_tqdm: bool = False): + """Initialize the bridge. + + Args: + config: Bridge configuration. + hf_config: HuggingFace model config (for reference). + disable_tqdm: Whether to disable progress bar. + """ + self.config = config + self.hf_config = hf_config + self.disable_tqdm = disable_tqdm or not is_last_rank() + + # Parallel state + self.tp_size = config.tp_size + self.pp_size = config.pp_size + self.ep_size = config.ep_size + self.etp_size = config.etp_size + + # Get parallel ranks + if MEGATRON_AVAILABLE and mpu.is_initialized(): + self.tp_rank = mpu.get_tensor_model_parallel_rank() + self.pp_rank = mpu.get_pipeline_model_parallel_rank() + self.tp_group = mpu.get_tensor_model_parallel_group() + self.pp_group = mpu.get_pipeline_model_parallel_group() + try: + self.ep_rank = mpu.get_expert_model_parallel_rank() + self.ep_group = mpu.get_expert_model_parallel_group() + self.etp_rank = mpu.get_expert_tensor_parallel_rank() + self.etp_group = mpu.get_expert_tensor_parallel_group() + except (AttributeError, AssertionError): + self.ep_rank = 0 + self.ep_group = None + self.etp_rank = 0 + self.etp_group = None + else: + self.tp_rank = 0 + self.pp_rank = 0 + self.tp_group = None + self.pp_group = None + self.ep_rank = 0 + self.ep_group = None + self.etp_rank = 0 + self.etp_group = None + + # PEFT tracking + self._is_peft_format = False + self._adapter_name = 'default' + self._peft_target_modules: Set[str] = set() + self._peft_modules_to_save: Set[str] = set() + self._target_device = None + self._only_last_rank = False + + def _get_tp_split_dim(self, mg_key: Optional[str]) -> Optional[int]: + """Determine which dimension to split for tensor parallelism.""" + 
if mg_key is None: + return None + + # ColumnParallel (split output dim) + dim0_keys = { + 'word_embeddings', + 'linear_qkv', + 'output_layer', + 'linear_q_proj', + 'linear_q_up_proj', + 'linear_kv_up_proj', + 'eh_proj', # MTP + } + # RowParallel (split input dim) + dim1_keys = {'linear_proj', 'linear_fc2'} + + # Handle LoRA keys + if 'lora_A' not in mg_key and 'lora_B' not in mg_key: + key_parts = mg_key.rsplit('.', 2) + if len(key_parts) >= 2: + key = key_parts[-2] + suffix = key_parts[-1] + + if suffix == 'layer_norm_weight': + return None + elif key in dim0_keys: + return 0 + elif key in {'linear_fc1'} and suffix != 'bias': + return 1 + elif key in dim1_keys and suffix != 'bias': + return 1 + else: + # LoRA weights + key_parts = mg_key.rsplit('.', 3) + if len(key_parts) >= 2: + key = key_parts[0] + lora_name = key_parts[1] if len(key_parts) > 1 else '' + if lora_name == 'lora_A': + if key in dim1_keys: + return 1 + elif lora_name == 'lora_B': + if key in dim0_keys: + return 0 + elif key == 'linear_fc1': + return 1 + + return None + + def _split_tp(self, + tensor: torch.Tensor, + tp_dim: Optional[int], + is_expert: bool = False) -> torch.Tensor: + """Split tensor for tensor parallelism.""" + tp_size = self.etp_size if is_expert else self.tp_size + tp_rank = self.etp_rank if is_expert else self.tp_rank + + if tp_dim is None or tp_size <= 1: + return tensor + return tensor.chunk(tp_size, dim=tp_dim)[tp_rank] + + def _all_gather_tp(self, + tensor: Optional[torch.Tensor], + tp_dim: Optional[int], + is_expert: bool = False) -> Optional[torch.Tensor]: + """All-gather tensor across tensor parallel group.""" + if tensor is None: + return None + + tensor = tensor.to('cuda') + tp_size = self.etp_size if is_expert else self.tp_size + tp_group = self.etp_group if is_expert else self.tp_group + + if tp_dim is None or tp_size <= 1: + return tensor + + if tp_dim == 0: + tensor_shape = list(tensor.shape) + tensor_shape[0] *= tp_size + output = tensor.new_empty(tensor_shape) + 
# NOTE(review): interior of TwinkleGPTBridge (continued).

def _all_gather_tp(self,
                   tensor: Optional[torch.Tensor],
                   tp_dim: Optional[int],
                   is_expert: bool = False) -> Optional[torch.Tensor]:
    """All-gather a TP-sharded tensor along ``tp_dim`` across the TP group."""
    if tensor is None:
        return None

    # Collectives need device tensors; move before the no-op checks so the
    # result is always on CUDA even when no gather happens.
    tensor = tensor.to('cuda')

    size = self.etp_size if is_expert else self.tp_size
    group = self.etp_group if is_expert else self.tp_group
    if tp_dim is None or size <= 1:
        return tensor

    if tp_dim == 0:
        # Contiguous fast path: gather into one pre-sized output buffer.
        out_shape = list(tensor.shape)
        out_shape[0] *= size
        gathered = tensor.new_empty(out_shape)
        dist.all_gather_into_tensor(gathered, tensor, group=group)
        return gathered

    # Generic path: gather per-rank shards, then concatenate on tp_dim.
    shards = [torch.empty_like(tensor) for _ in range(size)]
    dist.all_gather(shards, tensor, group=group)
    return torch.cat(shards, dim=tp_dim)

def _set_weight(
    self,
    mg_param: Union[torch.Tensor, nn.Parameter, List],
    hf_weight: torch.Tensor,
    mg_key: str,
    is_expert: bool = False,
):
    """Copy a HuggingFace weight into Megatron parameter(s).

    The weight is TP-split according to ``mg_key``; when several Megatron
    parameters are given, the shard is chunked evenly along dim 0 between
    them and reshaped to each parameter's shape.
    """
    shard = self._split_tp(hf_weight, self._get_tp_split_dim(mg_key),
                           is_expert)

    params = mg_param if isinstance(mg_param, (list, tuple)) else [mg_param]
    for param, piece in zip(params, shard.chunk(len(params), dim=0)):
        param.data.copy_(piece.reshape(*param.shape))

def _get_weight(
    self,
    mg_weight: Optional[Union[torch.Tensor, List[torch.Tensor]]],
    mg_key: Optional[str],
    is_expert: bool = False,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Return ``(tensor, None)`` for a Megatron weight, gathered across TP."""
    if mg_weight is None:
        return None, None

    pieces = mg_weight if isinstance(mg_weight,
                                     (list, tuple)) else [mg_weight]
    tensor = torch.cat(pieces, dim=0)
    tensor = self._all_gather_tp(tensor, self._get_tp_split_dim(mg_key),
                                 is_expert)

    if self._target_device is not None and tensor is not None:
        tensor = tensor.to(device=self._target_device)

    # Every rank joins the gather above; non-writer ranks drop the result
    # only afterwards so the collective stays balanced.
    if self._only_last_rank and not is_last_rank():
        return None, None

    return tensor, None

# =========================================================================
# Weight Loading Methods
# =========================================================================

def _load_embedding(self, mg_model, loader: SafetensorLoader):
    """Load the (vocab-padded) word-embedding weight into the model."""
    embed_module = deep_getattr(mg_model, 'embedding.word_embeddings')
    if embed_module is None:
        return

    hf_weight = loader.get_tensor(self.HF_EMBED_KEY)

    # Pad extra vocab rows up to the TP-friendly padded size.
    pad_rows = self.config.padded_vocab_size - hf_weight.shape[0]
    if pad_rows > 0:
        hf_weight = F.pad(hf_weight, (0, 0, 0, pad_rows))

    self._set_weight(embed_module.weight, hf_weight,
                     'word_embeddings.weight')
if hf_weight.shape[0] < self.config.padded_vocab_size: + hf_weight = F.pad( + hf_weight, + (0, 0, 0, self.config.padded_vocab_size - hf_weight.shape[0])) + + self._set_weight(embed_module.weight, hf_weight, + 'word_embeddings.weight') + + def _load_output_layer(self, mg_model, loader: SafetensorLoader): + """Load output layer (lm_head) weights.""" + output_module = deep_getattr(mg_model, 'output_layer') + if output_module is None or output_module.weight is None: + return + + # Check if weights are tied + if self.config.tie_word_embeddings: + hf_weight = loader.get_tensor(self.HF_EMBED_KEY) + else: + hf_weight = loader.get_tensor(self.HF_LM_HEAD_KEY) + + # Pad vocabulary if needed + if hf_weight.shape[0] < self.config.padded_vocab_size: + hf_weight = F.pad( + hf_weight, + (0, 0, 0, self.config.padded_vocab_size - hf_weight.shape[0])) + + self._set_weight(output_module.weight, hf_weight, + 'output_layer.weight') + + def _load_final_layernorm(self, mg_model, loader: SafetensorLoader): + """Load final layer norm weights.""" + ln_module = deep_getattr(mg_model, 'decoder.final_layernorm') + if ln_module is None: + return + + hf_weight = loader.get_tensor(self.HF_FINAL_LAYERNORM_KEY) + ln_module.weight.data.copy_(hf_weight) + + def _load_attention(self, mg_layer, loader: SafetensorLoader, + layer_idx: int): + """Load attention layer weights.""" + mg_attn = mg_layer.self_attention + prefix = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.self_attn.' + + num_heads = self.config.num_attention_heads + num_kv_heads = self.config.num_key_value_heads + hidden_size = self.config.hidden_size + # Use kv_channels (head_dim) from config if available (for Qwen3 etc.) 
+ head_dim = getattr(self.config, 'kv_channels', + hidden_size // num_heads) + heads_per_group = num_heads // num_kv_heads + + # Load Q, K, V weights and merge into linear_qkv + q_weight = loader.get_tensor(f'{prefix}q_proj.weight') + k_weight = loader.get_tensor(f'{prefix}k_proj.weight') + v_weight = loader.get_tensor(f'{prefix}v_proj.weight') + + # Infer head_dim from actual weight shapes if needed + actual_kv_dim = k_weight.shape[0] // num_kv_heads + if actual_kv_dim != head_dim: + head_dim = actual_kv_dim + + # Reshape for GQA + q_weight = q_weight.reshape(num_kv_heads, heads_per_group * head_dim, + hidden_size) + k_weight = k_weight.reshape(num_kv_heads, head_dim, hidden_size) + v_weight = v_weight.reshape(num_kv_heads, head_dim, hidden_size) + + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=1) + qkv_weight = qkv_weight.reshape(-1, hidden_size) + + self._set_weight(mg_attn.linear_qkv.weight, qkv_weight, + 'linear_qkv.weight') + + # Load O projection + o_weight = loader.get_tensor(f'{prefix}o_proj.weight') + self._set_weight(mg_attn.linear_proj.weight, o_weight, + 'linear_proj.weight') + + # Load biases if present + if self.config.add_qkv_bias: + try: + q_bias = loader.get_tensor(f'{prefix}q_proj.bias') + k_bias = loader.get_tensor(f'{prefix}k_proj.bias') + v_bias = loader.get_tensor(f'{prefix}v_proj.bias') + + # Infer head_dim from actual bias shapes if needed + actual_bias_head_dim = k_bias.shape[0] // num_kv_heads + + q_bias = q_bias.reshape(num_kv_heads, + heads_per_group * actual_bias_head_dim) + k_bias = k_bias.reshape(num_kv_heads, actual_bias_head_dim) + v_bias = v_bias.reshape(num_kv_heads, actual_bias_head_dim) + + qkv_bias = torch.cat([q_bias, k_bias, v_bias], + dim=1).reshape(-1) + self._set_weight(mg_attn.linear_qkv.bias, qkv_bias, + 'linear_qkv.bias') + except KeyError: + pass + + # Load input layernorm (may be fused) + ln_key = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.input_layernorm.weight' + ln_weight = loader.get_tensor(ln_key) + + 
ln_param = deep_getattr(mg_attn, 'linear_qkv.layer_norm_weight') + if ln_param is not None: + ln_param.data.copy_(ln_weight) + else: + ln_module = deep_getattr(mg_layer, 'input_layernorm') + if ln_module is not None: + ln_module.weight.data.copy_(ln_weight) + + # QK layernorm (Qwen3) + if self.config.qk_layernorm: + try: + q_norm = loader.get_tensor(f'{prefix}q_norm.weight') + k_norm = loader.get_tensor(f'{prefix}k_norm.weight') + q_ln = deep_getattr(mg_attn, 'q_layernorm') + k_ln = deep_getattr(mg_attn, 'k_layernorm') + if q_ln is not None: + q_ln.weight.data.copy_(q_norm) + if k_ln is not None: + k_ln.weight.data.copy_(k_norm) + except KeyError: + pass + + def _load_mlp(self, mg_layer, loader: SafetensorLoader, layer_idx: int): + """Load MLP layer weights.""" + mg_mlp = mg_layer.mlp + prefix = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.mlp.' + + # Check if gate_up_proj is fused + try: + gate_weight = loader.get_tensor(f'{prefix}gate_proj.weight') + up_weight = loader.get_tensor(f'{prefix}up_proj.weight') + + # Stack gate and up projections (shape: [2, intermediate, hidden]) + fc1_weight = torch.stack([gate_weight, up_weight], dim=0) + self._set_weight(mg_mlp.linear_fc1.weight, fc1_weight, + 'linear_fc1.weight') + except KeyError: + # Try gate_up_proj (fused) + try: + gate_up_weight = loader.get_tensor( + f'{prefix}gate_up_proj.weight') + gate_up_weight = gate_up_weight.view(2, -1, + gate_up_weight.shape[-1]) + self._set_weight(mg_mlp.linear_fc1.weight, gate_up_weight, + 'linear_fc1.weight') + except KeyError: + pass + + # Load down projection + try: + down_weight = loader.get_tensor(f'{prefix}down_proj.weight') + self._set_weight(mg_mlp.linear_fc2.weight, down_weight, + 'linear_fc2.weight') + except KeyError: + pass + + # Load post attention layernorm + ln_key = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.post_attention_layernorm.weight' + try: + ln_weight = loader.get_tensor(ln_key) + + ln_param = deep_getattr(mg_mlp, 'linear_fc1.layer_norm_weight') + if ln_param is not 
None: + ln_param.data.copy_(ln_weight) + else: + ln_module = deep_getattr(mg_layer, 'pre_mlp_layernorm') + if ln_module is not None: + ln_module.weight.data.copy_(ln_weight) + except KeyError: + pass + + def _load_moe(self, mg_layer, loader: SafetensorLoader, layer_idx: int): + """Load MoE layer weights. + + Handles Expert Parallel (EP) sharding - each EP rank loads only its + assigned subset of experts based on ep_rank and ep_size. + + For EP=2 with 128 experts: + - EP rank 0 loads experts 0-63 + - EP rank 1 loads experts 64-127 + """ + mg_mlp = mg_layer.mlp + prefix = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.mlp.' + + # Load router (replicated across all ranks) + try: + router_key = None + for key in ['gate.weight', 'router.weight', 'gate.wg.weight']: + full_key = f'{prefix}{key}' + if full_key in loader: + router_key = full_key + break + + if router_key: + router_weight = loader.get_tensor(router_key) + router_module = deep_getattr(mg_mlp, 'router') + if router_module is not None and hasattr( + router_module, 'weight'): + router_module.weight.data.copy_(router_weight) + + # Load expert bias if present (for sigmoid routers like Qwen3) + for bias_key in [ + 'gate.e_score_correction_bias', + 'moe_statics.e_score_correction_bias' + ]: + full_bias_key = f'{prefix}{bias_key}' + if full_bias_key in loader: + try: + expert_bias = loader.get_tensor(full_bias_key) + if router_module is not None and hasattr( + router_module, 'expert_bias'): + router_module.expert_bias.data.copy_(expert_bias) + break + except KeyError: + continue + except KeyError: + pass + + # Load shared experts if present + if self.config.shared_expert_intermediate_size > 0: + for shared_key in [ + 'shared_expert', 'shared_experts', 'shared_mlp' + ]: + try: + gate_weight = loader.get_tensor( + f'{prefix}{shared_key}.gate_proj.weight') + up_weight = loader.get_tensor( + f'{prefix}{shared_key}.up_proj.weight') + down_weight = loader.get_tensor( + f'{prefix}{shared_key}.down_proj.weight') + + shared_module = 
deep_getattr(mg_mlp, 'shared_experts') + if shared_module is not None: + fc1_weight = torch.stack([gate_weight, up_weight], + dim=0) + self._set_weight(shared_module.linear_fc1.weight, + fc1_weight, 'linear_fc1.weight') + self._set_weight(shared_module.linear_fc2.weight, + down_weight, 'linear_fc2.weight') + break + except KeyError: + continue + + # Load shared expert gate if present + for gate_key in ['shared_expert_gate.weight']: + full_gate_key = f'{prefix}{gate_key}' + if full_gate_key in loader: + try: + gate_weight = loader.get_tensor(full_gate_key) + shared_module = deep_getattr(mg_mlp, 'shared_experts') + if shared_module is not None and hasattr( + shared_module, 'gate_weight'): + shared_module.gate_weight.data.copy_(gate_weight) + break + except KeyError: + continue + + # Load experts with EP sharding + num_local_experts = self.config.num_experts // self.ep_size + start_expert_idx = self.ep_rank * num_local_experts + experts_module = deep_getattr(mg_mlp, 'experts') + + if experts_module is not None: + # Determine expert module type + if hasattr(experts_module, 'weight1'): + # GroupedMLP format - weights are merged: [hidden, num_experts * ffn_hidden] + # Need to collect all experts and set at once + fc1_weights = [] # gate and up weights interleaved + fc2_weights = [] # down weights + + for local_idx in range(num_local_experts): + global_idx = start_expert_idx + local_idx + try: + gate_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.gate_proj.weight') + up_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.up_proj.weight') + down_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.down_proj.weight') + + # Stack gate and up for gated linear unit + fc1_weights.append(gate_weight) # [ffn_hidden, hidden] + fc1_weights.append(up_weight) # [ffn_hidden, hidden] + fc2_weights.append(down_weight) # [hidden, ffn_hidden] + except KeyError as e: + print( + f'Warning: Missing expert {global_idx} weights: {e}' + ) + continue + + if 
fc1_weights and fc2_weights: + # GroupedMLP weight1: [hidden, num_experts * 2 * ffn_hidden] (transposed) + # HF format: [num_experts * 2, ffn_hidden, hidden] + fc1_stacked = torch.cat( + fc1_weights, + dim=0) # [num_experts*2*ffn_hidden, hidden] + fc1_stacked = fc1_stacked.t().contiguous( + ) # [hidden, num_experts*2*ffn_hidden] + + # GroupedMLP weight2: [num_experts * ffn_hidden, hidden] + fc2_stacked = torch.cat( + fc2_weights, dim=0) # [num_experts*hidden, ffn_hidden] + + # Set weights directly + if experts_module.weight1.shape == fc1_stacked.shape: + experts_module.weight1.data.copy_(fc1_stacked) + else: + # Handle TP split + tp_rank = self.etp_rank + tp_size = self.etp_size + if tp_size > 1: + # Split along last dim for weight1 + chunk_size = fc1_stacked.shape[1] // tp_size + fc1_chunk = fc1_stacked[:, tp_rank * + chunk_size:(tp_rank + 1) * + chunk_size] + experts_module.weight1.data.copy_(fc1_chunk) + else: + experts_module.weight1.data.copy_(fc1_stacked) + + if experts_module.weight2.shape == fc2_stacked.shape: + experts_module.weight2.data.copy_(fc2_stacked) + else: + # Handle TP split + tp_rank = self.etp_rank + tp_size = self.etp_size + if tp_size > 1: + # Split along first dim for weight2 + chunk_size = fc2_stacked.shape[0] // tp_size + fc2_chunk = fc2_stacked[tp_rank * + chunk_size:(tp_rank + 1) * + chunk_size, :] + experts_module.weight2.data.copy_(fc2_chunk) + else: + experts_module.weight2.data.copy_(fc2_stacked) + + elif hasattr(experts_module, 'local_experts'): + # SequentialMLP format with local_experts list + for local_idx in range(num_local_experts): + global_idx = start_expert_idx + local_idx + try: + gate_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.gate_proj.weight') + up_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.up_proj.weight') + down_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.down_proj.weight') + + expert = experts_module.local_experts[local_idx] + if hasattr(expert, 'linear_fc1'): + 
fc1_weight = torch.stack([gate_weight, up_weight], + dim=0) + self._set_weight(expert.linear_fc1.weight, + fc1_weight, 'linear_fc1.weight') + self._set_weight(expert.linear_fc2.weight, + down_weight, 'linear_fc2.weight') + except KeyError: + continue + + elif hasattr(experts_module, 'linear_fc1'): + # TEGroupedLinear format - weights stored as weight0, weight1, etc. + for local_idx in range(num_local_experts): + global_idx = start_expert_idx + local_idx + try: + gate_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.gate_proj.weight') + up_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.up_proj.weight') + down_weight = loader.get_tensor( + f'{prefix}experts.{global_idx}.down_proj.weight') + + fc1_weight = torch.stack([gate_weight, up_weight], + dim=0) + fc1_param = getattr(experts_module.linear_fc1, + f'weight{local_idx}', None) + if fc1_param is not None: + self._set_weight(fc1_param, + fc1_weight, + 'linear_fc1.weight', + is_expert=True) + + fc2_param = getattr(experts_module.linear_fc2, + f'weight{local_idx}', None) + if fc2_param is not None: + self._set_weight(fc2_param, + down_weight, + 'linear_fc2.weight', + is_expert=True) + except KeyError: + continue + + # Load post attention layernorm (pre_mlp_layernorm for MoE) + ln_key = f'{self.HF_LAYERS_PREFIX}.{layer_idx}.post_attention_layernorm.weight' + try: + ln_weight = loader.get_tensor(ln_key) + # Try pre_mlp_layernorm first (used in MoE layers) + ln_module = deep_getattr(mg_layer, 'pre_mlp_layernorm') + if ln_module is not None and hasattr(ln_module, 'weight'): + ln_module.weight.data.copy_(ln_weight) + else: + # Fallback to linear_fc1.layer_norm_weight + ln_param = deep_getattr(mg_mlp, 'linear_fc1.layer_norm_weight') + if ln_param is not None: + ln_param.data.copy_(ln_weight) + except KeyError: + pass + + def _load_layer(self, mg_layer, loader: SafetensorLoader, layer_idx: int): + """Load a single transformer layer.""" + self._load_attention(mg_layer, loader, layer_idx) + + # Check 
if MoE layer + if self.config.num_experts > 0: + self._load_moe(mg_layer, loader, layer_idx) + else: + self._load_mlp(mg_layer, loader, layer_idx) + + def load_weights( + self, + mg_model: nn.Module, + model_path: str, + is_peft_format: bool = False, + adapter_name: str = 'default', + ) -> None: + """Load HuggingFace weights into Megatron model. + + Args: + mg_model: Megatron GPT model. + model_path: Path to HuggingFace checkpoint. + is_peft_format: Whether loading PEFT adapter weights. + adapter_name: Name of the adapter for PEFT. + """ + self._is_peft_format = is_peft_format + self._adapter_name = adapter_name + + with torch.no_grad(): + with SafetensorLoader(model_path, + is_peft_format=is_peft_format) as loader: + if is_peft_format: + self._load_peft_weights(mg_model, loader) + else: + self._load_base_weights(mg_model, loader) + + def _load_base_weights(self, mg_model: nn.Module, + loader: SafetensorLoader): + """Load base model weights.""" + # Get decoder + decoder = deep_getattr(mg_model, 'decoder') + if decoder is None: + decoder = mg_model + + layers = getattr(decoder, 'layers', []) + + # Load pre-process (embedding) on first PP rank + if self.pp_size <= 1 or self.pp_rank == 0: + try: + self._load_embedding(mg_model, loader) + except Exception as e: + print(f'Warning: Failed to load embedding: {e}') + + # Load transformer layers + prog_bar = tqdm(layers, + desc='Loading weights', + disable=self.disable_tqdm) + for mg_layer in prog_bar: + layer_idx = mg_layer.layer_number - 1 # 1-indexed to 0-indexed + try: + self._load_layer(mg_layer, loader, layer_idx) + except Exception as e: + print(f'Warning: Failed to load layer {layer_idx}: {e}') + + # Load post-process on last PP rank + if self.pp_size <= 1 or self.pp_rank == self.pp_size - 1: + try: + self._load_final_layernorm(mg_model, loader) + self._load_output_layer(mg_model, loader) + except Exception as e: + print(f'Warning: Failed to load post-process: {e}') + + def _load_peft_weights(self, mg_model: 
nn.Module, + loader: SafetensorLoader): + """Load PEFT/LoRA adapter weights.""" + state_dict = loader.get_state_dict() + hf_prefix = 'base_model.model.' if self._is_peft_format else '' + + # Build mapping from HF keys to Megatron keys + for key, lazy_tensor in state_dict.items(): + # Remove base_model.model. prefix + if key.startswith(hf_prefix): + key = key[len(hf_prefix):] + + # Parse the key to find target module + if '.lora_A.' in key or '.lora_B.' in key: + tensor = lazy_tensor.load() + self._load_peft_tensor(mg_model, key, tensor) + + def _load_peft_tensor(self, mg_model: nn.Module, key: str, + tensor: torch.Tensor): + """Load a single PEFT tensor into the model.""" + # Parse key: model.layers.0.self_attn.q_proj.lora_A.weight + parts = key.split('.') + + # Find layer index + layer_idx = None + for i, p in enumerate(parts): + if p == 'layers' and i + 1 < len(parts): + layer_idx = int(parts[i + 1]) + break + + if layer_idx is None: + return + + # Get layer + decoder = deep_getattr(mg_model, 'decoder') + if decoder is None: + decoder = mg_model + + layers = getattr(decoder, 'layers', []) + for layer in layers: + if layer.layer_number - 1 == layer_idx: + mg_layer = layer + break + else: + return + + # Determine target and lora type + is_lora_A = '.lora_A.' in key + is_lora_B = '.lora_B.' 
in key + + if 'self_attn' in key: + mg_attn = mg_layer.self_attention + if 'q_proj' in key or 'k_proj' in key or 'v_proj' in key: + target = deep_getattr(mg_attn, 'linear_qkv') + elif 'o_proj' in key: + target = deep_getattr(mg_attn, 'linear_proj') + else: + return + elif 'mlp' in key: + mg_mlp = mg_layer.mlp + if 'gate_proj' in key or 'up_proj' in key: + target = deep_getattr(mg_mlp, 'linear_fc1') + elif 'down_proj' in key: + target = deep_getattr(mg_mlp, 'linear_fc2') + else: + return + else: + return + + if target is None: + return + + # Get LoRA module + if is_lora_A: + lora_module = deep_getattr(target, f'lora_A.{self._adapter_name}') + else: + lora_module = deep_getattr(target, f'lora_B.{self._adapter_name}') + + if lora_module is not None and hasattr(lora_module, 'weight'): + lora_module.weight.data.copy_(tensor) + + # ========================================================================= + # Weight Saving Methods + # ========================================================================= + + def export_weights( + self, + mg_models: Union[nn.Module, List[nn.Module]], + target_device: Optional[str] = None, + only_last_rank: bool = False, + is_peft_format: bool = False, + tqdm_desc: str = 'Exporting: ', + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Export weights from Megatron model to HuggingFace format. + + Yields: + Tuples of (key, tensor) for each weight. + """ + self._target_device = target_device + self._only_last_rank = only_last_rank + self._is_peft_format = is_peft_format + self._adapter_name = 'default' + self._peft_target_modules = set() + self._peft_modules_to_save = set() + + if not isinstance(mg_models, (list, tuple)): + mg_models = [mg_models] + + hf_prefix = 'base_model.model.' 
if is_peft_format else '' + + with torch.no_grad(): + # For now, handle single model + mg_model = mg_models[0] + + decoder = deep_getattr(mg_model, 'decoder') + if decoder is None: + decoder = mg_model + + layers = getattr(decoder, 'layers', []) + + if not is_peft_format: + # Export embedding + if self.pp_size <= 1 or self.pp_rank == 0: + embed = deep_getattr(mg_model, + 'embedding.word_embeddings.weight') + if embed is not None: + weight, _ = self._get_weight(embed.data, + 'word_embeddings.weight') + if weight is not None: + weight = weight[:self.config.vocab_size] + yield f'{hf_prefix}{self.HF_EMBED_KEY}', weight + + # Export layers + prog_bar = tqdm(layers, desc=tqdm_desc, disable=self.disable_tqdm) + for mg_layer in prog_bar: + layer_idx = mg_layer.layer_number - 1 + yield from self._export_layer(mg_layer, layer_idx, hf_prefix, + is_peft_format) + + if not is_peft_format: + # Export final layernorm and output layer + if self.pp_size <= 1 or self.pp_rank == self.pp_size - 1: + ln_module = deep_getattr(mg_model, + 'decoder.final_layernorm') + if ln_module is not None: + yield f'{hf_prefix}{self.HF_FINAL_LAYERNORM_KEY}', ln_module.weight.data.clone( + ) + + output = deep_getattr(mg_model, 'output_layer.weight') + if output is not None: + weight, _ = self._get_weight(output.data, + 'output_layer.weight') + if weight is not None: + weight = weight[:self.config.vocab_size] + yield f'{hf_prefix}{self.HF_LM_HEAD_KEY}', weight + + def _export_layer( + self, + mg_layer, + layer_idx: int, + hf_prefix: str, + is_peft_format: bool, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Export a single layer.""" + prefix = f'{hf_prefix}{self.HF_LAYERS_PREFIX}.{layer_idx}.' 
+ + mg_attn = mg_layer.self_attention + mg_mlp = mg_layer.mlp + + num_heads = self.config.num_attention_heads + num_kv_heads = self.config.num_key_value_heads + hidden_size = self.config.hidden_size + head_dim = hidden_size // num_heads + heads_per_group = num_heads // num_kv_heads + q_dim = heads_per_group * head_dim + kv_dim = head_dim + + if not is_peft_format: + # Export QKV + qkv_weight, _ = self._get_weight(mg_attn.linear_qkv.weight.data, + 'linear_qkv.weight') + if qkv_weight is not None: + qkv_weight = qkv_weight.reshape(num_kv_heads, -1, hidden_size) + yield f'{prefix}self_attn.q_proj.weight', qkv_weight[:, : + q_dim, :].reshape( + -1, + hidden_size + ).clone() + yield f'{prefix}self_attn.k_proj.weight', qkv_weight[:, q_dim: + q_dim + + kv_dim, :].reshape( + -1, + hidden_size + ).clone() + yield f'{prefix}self_attn.v_proj.weight', qkv_weight[:, + -kv_dim:, :].reshape( + -1, + hidden_size + ).clone() + + # Export O + o_weight, _ = self._get_weight(mg_attn.linear_proj.weight.data, + 'linear_proj.weight') + if o_weight is not None: + yield f'{prefix}self_attn.o_proj.weight', o_weight + + # Export layernorms + ln = deep_getattr(mg_attn, 'linear_qkv.layer_norm_weight') + if ln is not None: + yield f'{prefix}input_layernorm.weight', ln.data.clone() + + # Export MLP + fc1_weight, _ = self._get_weight(mg_mlp.linear_fc1.weight.data, + 'linear_fc1.weight') + if fc1_weight is not None: + fc1_weight = fc1_weight.view(2, -1, hidden_size) + yield f'{prefix}mlp.gate_proj.weight', fc1_weight[0].clone() + yield f'{prefix}mlp.up_proj.weight', fc1_weight[1].clone() + + fc2_weight, _ = self._get_weight(mg_mlp.linear_fc2.weight.data, + 'linear_fc2.weight') + if fc2_weight is not None: + yield f'{prefix}mlp.down_proj.weight', fc2_weight + + ln2 = deep_getattr(mg_mlp, 'linear_fc1.layer_norm_weight') + if ln2 is not None: + yield f'{prefix}post_attention_layernorm.weight', ln2.data.clone( + ) + else: + # Export LoRA weights only + yield from self._export_lora_layer(mg_attn, 
mg_mlp, prefix) + + def _export_lora_layer( + self, + mg_attn, + mg_mlp, + prefix: str, + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Export LoRA weights from a layer.""" + # Check if LoRA is applied + from twinkle.megatron.tuners import LoraParallelLinear + + # Attention LoRA + if isinstance(mg_attn.linear_qkv, LoraParallelLinear): + lora_A = deep_getattr(mg_attn.linear_qkv, + f'lora_A.{self._adapter_name}.weight') + lora_B = deep_getattr(mg_attn.linear_qkv, + f'lora_B.{self._adapter_name}.weight') + + if lora_A is not None and lora_B is not None: + lora_A, _ = self._get_weight(lora_A.data, + 'linear_qkv.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, + 'linear_qkv.lora_B.weight') + + if lora_A is not None: + self._peft_target_modules.update( + {'q_proj', 'k_proj', 'v_proj'}) + # Split lora_B for Q, K, V + for key in ['q_proj', 'k_proj', 'v_proj']: + yield f'{prefix}self_attn.{key}.lora_A.weight', lora_A.clone( + ) + + num_kv_heads = self.config.num_key_value_heads + head_dim = self.config.hidden_size // self.config.num_attention_heads + heads_per_group = self.config.num_attention_heads // num_kv_heads + q_dim = heads_per_group * head_dim + + lora_B = lora_B.reshape(num_kv_heads, -1, lora_B.shape[-1]) + yield f'{prefix}self_attn.q_proj.lora_B.weight', lora_B[:, :q_dim, :].reshape( + -1, lora_B.shape[-1]).clone() + yield f'{prefix}self_attn.k_proj.lora_B.weight', lora_B[:, + q_dim: + -head_dim, :].reshape( + -1, + lora_B + . 
+ shape[ + -1] + ).clone( + ) + yield f'{prefix}self_attn.v_proj.lora_B.weight', lora_B[:, -head_dim:, :].reshape( + -1, lora_B.shape[-1]).clone() + + # O projection LoRA + if isinstance(mg_attn.linear_proj, LoraParallelLinear): + lora_A = deep_getattr(mg_attn.linear_proj, + f'lora_A.{self._adapter_name}.weight') + lora_B = deep_getattr(mg_attn.linear_proj, + f'lora_B.{self._adapter_name}.weight') + + if lora_A is not None and lora_B is not None: + lora_A, _ = self._get_weight(lora_A.data, + 'linear_proj.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, + 'linear_proj.lora_B.weight') + + if lora_A is not None: + self._peft_target_modules.add('o_proj') + yield f'{prefix}self_attn.o_proj.lora_A.weight', lora_A.clone( + ) + yield f'{prefix}self_attn.o_proj.lora_B.weight', lora_B.clone( + ) + + # MLP LoRA + if hasattr(mg_mlp, 'linear_fc1') and isinstance( + mg_mlp.linear_fc1, LoraParallelLinear): + lora_A = deep_getattr(mg_mlp.linear_fc1, + f'lora_A.{self._adapter_name}.weight') + lora_B = deep_getattr(mg_mlp.linear_fc1, + f'lora_B.{self._adapter_name}.weight') + + if lora_A is not None and lora_B is not None: + lora_A, _ = self._get_weight(lora_A.data, + 'linear_fc1.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, + 'linear_fc1.lora_B.weight') + + if lora_A is not None: + self._peft_target_modules.update({'gate_proj', 'up_proj'}) + for key in ['gate_proj', 'up_proj']: + yield f'{prefix}mlp.{key}.lora_A.weight', lora_A.clone( + ) + + lora_B = lora_B.reshape(2, -1, lora_B.shape[-1]) + yield f'{prefix}mlp.gate_proj.lora_B.weight', lora_B[ + 0].clone() + yield f'{prefix}mlp.up_proj.lora_B.weight', lora_B[ + 1].clone() + + if hasattr(mg_mlp, 'linear_fc2') and isinstance( + mg_mlp.linear_fc2, LoraParallelLinear): + lora_A = deep_getattr(mg_mlp.linear_fc2, + f'lora_A.{self._adapter_name}.weight') + lora_B = deep_getattr(mg_mlp.linear_fc2, + f'lora_B.{self._adapter_name}.weight') + + if lora_A is not None and lora_B is not None: + lora_A, _ = 
self._get_weight(lora_A.data, + 'linear_fc2.lora_A.weight') + lora_B, _ = self._get_weight(lora_B.data, + 'linear_fc2.lora_B.weight') + + if lora_A is not None: + self._peft_target_modules.add('down_proj') + yield f'{prefix}mlp.down_proj.lora_A.weight', lora_A.clone( + ) + yield f'{prefix}mlp.down_proj.lora_B.weight', lora_B.clone( + ) + + def save_weights( + self, + mg_models: Union[nn.Module, List[nn.Module]], + output_dir: str, + is_peft_format: bool = False, + ) -> None: + """Save Megatron model weights in HuggingFace format. + + Args: + mg_models: Megatron model(s) to save. + output_dir: Directory to save weights. + is_peft_format: Whether saving in PEFT format. + + Note: + For DP > 1, only DP rank 0 writes to disk. All ranks participate + in tensor gather operations for TP. + """ + torch.cuda.empty_cache() + + # Determine if this rank should write + should_write = is_last_rank() + + # Only the writing rank creates the saver + saver = None + if should_write: + saver = StreamingSafetensorSaver( + save_dir=output_dir, + max_shard_size=self.config.max_shard_size, + is_peft_format=is_peft_format, + ) + + # All ranks participate in export (needed for TP gather) + for key, tensor in self.export_weights( + mg_models, + target_device='cpu', + only_last_rank=True, + is_peft_format=is_peft_format, + tqdm_desc='Saving: ', + ): + if saver is not None and tensor is not None: + saver.add_tensor(key, tensor) + + if saver is not None: + saver.finalize() + + # Save config on writing rank only + if should_write: + if is_peft_format and not isinstance(mg_models, (list, tuple)): + mg_models = [mg_models] + + if is_peft_format and hasattr(mg_models[0], 'peft_config'): + peft_config = copy(mg_models[0].peft_config.get( + self._adapter_name)) + if peft_config is not None: + peft_config.target_modules = list( + self._peft_target_modules) + peft_config.modules_to_save = list( + self._peft_modules_to_save) + peft_config.save_pretrained(output_dir) + elif not is_peft_format and 
self.hf_config is not None: + # Save HF config + self.hf_config.vocab_size = self.config.padded_vocab_size + self.hf_config.save_pretrained(output_dir) + + # Synchronize all ranks before continuing + if dist.is_initialized(): + dist.barrier() + + +class TwinkleBridgeAdapter: + """Adapter for weight loading using TwinkleGPTBridge. + + Provides a simple interface for loading HF weights into Megatron models. + """ + def __init__( + self, + hf_config: Any, + tp_size: int = 1, + pp_size: int = 1, + ep_size: int = 1, + etp_size: Optional[int] = None, + model_path: Optional[str] = None, + padded_vocab_size: Optional[int] = None, + **kwargs, + ): + """Initialize the bridge adapter.""" + self.hf_config = hf_config + self.model_path = model_path + + # Create bridge config + self.config = BridgeConfig.from_hf_config( + hf_config=hf_config, + tp_size=tp_size, + pp_size=pp_size, + ep_size=ep_size, + padded_vocab_size=padded_vocab_size, + ) + if etp_size is not None: + self.config.etp_size = etp_size + + self._bridge = None + + def _get_bridge(self) -> TwinkleGPTBridge: + """Get or create the bridge instance.""" + if self._bridge is None: + self._bridge = TwinkleGPTBridge( + config=self.config, + hf_config=self.hf_config, + ) + return self._bridge + + def load_weights( + self, + mg_model: nn.Module, + model_path: Optional[str] = None, + is_peft_format: bool = False, + adapter_name: str = 'default', + ) -> None: + """Load HuggingFace weights into Megatron model.""" + model_path = model_path or self.model_path + if model_path is None: + raise ValueError('model_path must be provided') + + bridge = self._get_bridge() + bridge.load_weights(mg_model, model_path, is_peft_format, adapter_name) + + def save_weights( + self, + mg_models: Union[nn.Module, List[nn.Module]], + output_dir: str, + is_peft_format: bool = False, + ) -> None: + """Save Megatron model weights in HuggingFace format.""" + bridge = self._get_bridge() + bridge.save_weights(mg_models, output_dir, is_peft_format) + + 
+class TwinkleBridgeInitializer: + """ + Megatron model initializer. + + This class provides complete model initialization flow including: + - Megatron parallel state initialization + - Model creation from HuggingFace config + - Weight loading using TwinkleGPTBridge + + Example: + initializer = TwinkleBridgeInitializer( + tp_size=2, + pp_size=1, + params_dtype=torch.bfloat16, + ) + model = initializer.create_model('Qwen/Qwen2.5-7B-Instruct') + """ + def __init__( + self, + tp_size: int = 1, + pp_size: int = 1, + cp_size: int = 1, + ep_size: int = 1, + etp_size: Optional[int] = None, + params_dtype=None, + use_cpu_initialization: bool = False, + attention_backend: str = 'flash', + sequence_parallel: bool = False, + recompute_granularity: Optional[str] = 'selective', + recompute_modules: Optional[list] = None, + recompute_method: Optional[str] = None, + recompute_num_layers: Optional[int] = None, + ): + """Initialize TwinkleBridgeInitializer. + + Args: + tp_size: Tensor parallel size. + pp_size: Pipeline parallel size. + cp_size: Context parallel size. + ep_size: Expert parallel size. + etp_size: Expert tensor parallel size. + params_dtype: Parameter dtype (default: torch.bfloat16). + use_cpu_initialization: Initialize on CPU first. + attention_backend: Attention backend. + sequence_parallel: Enable sequence parallelism. Required for MoE with TP > 1. + recompute_granularity: Activation recomputation strategy. + 'selective' (default): Only recompute core attention (memory efficient). + 'full': Recompute entire transformer layer (most memory efficient). + None: No recomputation (fastest but highest memory). + recompute_modules: Modules to recompute when using 'selective' granularity. + Default: ['core_attn'] for efficient memory/compute trade-off. + recompute_method: Method for full recompute ('uniform' or 'block'). + Required when recompute_granularity='full'. + recompute_num_layers: Number of layers to recompute for 'full' mode. 
+            Required when recompute_granularity='full'.
+        """
+        self.tp_size = tp_size
+        self.pp_size = pp_size
+        self.cp_size = cp_size
+        self.ep_size = ep_size
+        # Expert tensor-parallel size defaults to the plain TP size.
+        self.etp_size = etp_size or tp_size
+        self.params_dtype = params_dtype if params_dtype is not None else torch.bfloat16
+        self.use_cpu_initialization = use_cpu_initialization
+        self.attention_backend = attention_backend
+        self.sequence_parallel = sequence_parallel
+        self.recompute_granularity = recompute_granularity
+        # NOTE: `or` means an explicit empty list also falls back to ['core_attn'].
+        self.recompute_modules = recompute_modules or ['core_attn']
+        self.recompute_method = recompute_method
+        self.recompute_num_layers = recompute_num_layers
+
+        # Populated later by create_model() / load_weights().
+        self._model = None
+        self._bridge = None
+        self._hf_config = None
+        self._model_path = None
+
+    def _download_model(self, model_path: str) -> str:
+        """Resolve a model ID to a local directory.
+
+        Returns *model_path* unchanged when it is already a directory;
+        otherwise downloads via ModelScope, falling back to HuggingFace Hub
+        when modelscope is not installed.
+
+        NOTE(review): appears unused — create_model() resolves paths through
+        HubOperation.download_model instead; confirm before removing.
+        """
+        if os.path.isdir(model_path):
+            return model_path
+
+        try:
+            from modelscope import snapshot_download
+            return snapshot_download(model_path)
+        except ImportError:
+            from huggingface_hub import snapshot_download
+            return snapshot_download(model_path)
+
+    def _initialize_megatron(self, hf_config: Any = None):
+        """Initialize Megatron parallel state.
+
+        This sets up the required process groups for tensor, pipeline,
+        and data parallelism using Megatron's parallel state module directly.
+
+        Handles both local (torchrun) and Ray execution modes:
+        - Local: Uses torchrun's environment variables (already set)
+        - Ray: Uses RayHelper's environment variables (RANK, WORLD_SIZE, etc.)
+
+        Args:
+            hf_config: Optional HuggingFace config for additional model parameters.
+                NOTE(review): currently unused inside this method — confirm.
+ """ + import os + import torch.distributed as dist + from datetime import timedelta + from megatron.core import parallel_state as mpu + from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + + # Check if already initialized + try: + if mpu.is_initialized(): + return + except AssertionError: + pass + + # Determine execution mode + twinkle_mode = os.environ.get('TWINKLE_MODE', 'local') + + # Initialize distributed if not already + if not dist.is_initialized(): + if twinkle_mode == 'ray': + # Ray mode: use environment variables set by RayHelper + rank = int(os.environ.get('RANK', '0')) + world_size = int(os.environ.get('WORLD_SIZE', '1')) + master_addr = os.environ.get('MASTER_ADDR', 'localhost') + master_port = os.environ.get('MASTER_PORT', '29500') + local_rank = int(os.environ.get('LOCAL_RANK', '0')) + + # Set CUDA device before init_process_group + torch.cuda.set_device(local_rank) + + # Initialize process group with explicit parameters + dist.init_process_group( + backend='nccl', + init_method=f'tcp://{master_addr}:{master_port}', + rank=rank, + world_size=world_size, + timeout=timedelta(minutes=10), + ) + else: + # Local mode (torchrun): environment variables are already set + dist.init_process_group(backend='nccl') + + # Initialize Megatron parallel state directly + mpu.initialize_model_parallel( + tensor_model_parallel_size=self.tp_size, + pipeline_model_parallel_size=self.pp_size, + context_parallel_size=self.cp_size, + expert_model_parallel_size=self.ep_size, + ) + + # Initialize CUDA RNG tracker for tensor parallel random states + # This is required when use_cpu_initialization=False (GPU initialization) + model_parallel_cuda_manual_seed(42) + + def _create_model_from_config( + self, + hf_config: Any, + padded_vocab_size: int, + ) -> nn.Module: + """Create Megatron GPT model from HuggingFace config. + + Args: + hf_config: HuggingFace model configuration. + padded_vocab_size: Padded vocabulary size. 
+ + Returns: + Megatron GPT model. + """ + import torch.distributed as dist + from megatron.core import parallel_state as mpu + from megatron.core.transformer import TransformerConfig + from megatron.core.transformer.enums import AttnBackend + from megatron.core.models.gpt import GPTModel + from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_with_transformer_engine_spec, ) + + # Convert HF config to Megatron config + from ..utils import convert_hf_config + mg_config_dict = convert_hf_config(hf_config) + + # Build TransformerConfig + num_attention_heads = mg_config_dict['num_attention_heads'] + num_query_groups = mg_config_dict.get('num_query_groups', + num_attention_heads) + num_layers = mg_config_dict['num_layers'] + + # Configure activation recomputation + recompute_method = self.recompute_method + recompute_num_layers = self.recompute_num_layers + + # Auto-configure for 'full' recomputation if not specified + if self.recompute_granularity == 'full': + if recompute_method is None: + recompute_method = 'uniform' + if recompute_num_layers is None: + # Recompute all layers for maximum memory savings + recompute_num_layers = num_layers // self.pp_size + + # Create finalize_model_grads function for DP gradient synchronization + # Megatron's native finalize_model_grads requires DDP-wrapped models with ddp_config. + # For PEFT/LoRA models, we use a custom implementation that handles non-DDP models. + from megatron.core.distributed import finalize_model_grads as _native_finalize_model_grads + + def finalize_model_grads_for_lora(model, + num_tokens=None, + pg_collection=None): + """Finalize model grads that handles both DDP and PEFT/LoRA models. + + For DDP-wrapped models: Delegates to Megatron's native finalize_model_grads + For PEFT/LoRA models: Manually all-reduce gradients across DP ranks + + This is necessary because PEFT models don't have ddp_config attribute + that Megatron's native implementation expects. 
+ """ + from megatron.core import parallel_state as mpu + + # Check if model is DDP-wrapped (has ddp_config) + if hasattr(model[0], 'ddp_config'): + # Use native implementation for DDP models + return _native_finalize_model_grads(model, num_tokens, + pg_collection) + + # For PEFT/LoRA models, call finish_grad_sync on each chunk + # The model should have finish_grad_sync added by MegatronModel.add_adapter_to_model + for model_chunk in model: + if hasattr(model_chunk, 'finish_grad_sync'): + model_chunk.finish_grad_sync() + + # MoE configuration + num_experts = mg_config_dict.get('num_experts', 0) or 0 + moe_ffn_hidden_size = mg_config_dict.get('moe_ffn_hidden_size') + moe_router_topk = mg_config_dict.get('moe_router_topk', 2) or 2 + moe_shared_expert_intermediate_size = mg_config_dict.get( + 'moe_shared_expert_intermediate_size') + + # Build MoE-related kwargs + moe_kwargs = {} + if num_experts > 0: + moe_kwargs.update({ + 'num_moe_experts': + num_experts, + 'moe_router_topk': + moe_router_topk, + 'moe_router_load_balancing_type': + mg_config_dict.get('moe_router_load_balancing_type', + 'aux_loss'), + # MoE performance optimizations + 'moe_token_dispatcher_type': + mg_config_dict.get( + 'moe_token_dispatcher_type', 'alltoall' + ), # 'alltoall' is more efficient than 'allgather' + 'moe_grouped_gemm': + mg_config_dict.get( + 'moe_grouped_gemm', True + ), # Enable for better performance (requires grouped_gemm package) + 'moe_aux_loss_coeff': + mg_config_dict.get( + 'moe_aux_loss_coeff', + 0.0), # Auxiliary load balancing loss coefficient + }) + + # FFN hidden size for MoE + if moe_ffn_hidden_size: + moe_kwargs['moe_ffn_hidden_size'] = moe_ffn_hidden_size + + # Shared expert configuration + if moe_shared_expert_intermediate_size: + moe_kwargs[ + 'moe_shared_expert_intermediate_size'] = moe_shared_expert_intermediate_size + + # Router score function (sigmoid for Qwen3, softmax for others) + if mg_config_dict.get('moe_router_score_function'): + 
moe_kwargs['moe_router_score_function'] = mg_config_dict[ + 'moe_router_score_function'] + + # Expert bias for sigmoid router + if mg_config_dict.get('moe_router_enable_expert_bias'): + moe_kwargs['moe_router_enable_expert_bias'] = mg_config_dict[ + 'moe_router_enable_expert_bias'] + + # Sequence parallel requires TP > 1 + # Auto-enable for MoE with TP > 1 (required by Megatron) + use_sequence_parallel = self.sequence_parallel and self.tp_size > 1 + if num_experts > 0 and self.tp_size > 1 and not use_sequence_parallel: + use_sequence_parallel = True + print( + f'Auto-enabling sequence_parallel for MoE with TP={self.tp_size}' + ) + + # For MoE models, ffn_hidden_size should be moe_ffn_hidden_size if not specified + ffn_hidden_size = mg_config_dict.get('ffn_hidden_size') + if ffn_hidden_size is None: + ffn_hidden_size = moe_ffn_hidden_size or ( + 4 * mg_config_dict['hidden_size']) + + # For models with non-standard head dimensions (like Qwen3-30B-A3B) + kv_channels = mg_config_dict.get('kv_channels') + + # Activation function for SwiGLU (required by Megatron when gated_linear_unit=True) + use_swiglu = mg_config_dict.get('swiglu', True) + activation_func = torch.nn.functional.silu if use_swiglu else torch.nn.functional.gelu + + # Enable bias_activation_fusion for SwiGLU + # Note: Only works with TransformerEngine and no bias in linear layers + has_bias = not mg_config_dict.get('disable_bias_linear', True) + bias_activation_fusion = use_swiglu and not has_bias + + config = TransformerConfig( + num_layers=num_layers, + hidden_size=mg_config_dict['hidden_size'], + num_attention_heads=num_attention_heads, + num_query_groups=num_query_groups, + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, + tensor_model_parallel_size=self.tp_size, + pipeline_model_parallel_size=self.pp_size, + context_parallel_size=self.cp_size, + expert_model_parallel_size=self.ep_size, + sequence_parallel=use_sequence_parallel, + params_dtype=self.params_dtype, + pipeline_dtype=self. 
+ params_dtype, # Required when using pipeline parallelism + use_cpu_initialization=self.use_cpu_initialization, + add_qkv_bias=mg_config_dict.get('add_qkv_bias', False), + add_bias_linear=not mg_config_dict.get('disable_bias_linear', + True), + gated_linear_unit=use_swiglu, + activation_func=activation_func, # SiLU for SwiGLU, GELU otherwise + bias_activation_fusion= + bias_activation_fusion, # Fused SwiGLU for performance + normalization='RMSNorm', + layernorm_epsilon=mg_config_dict.get('norm_epsilon', 1e-6), + qk_layernorm=mg_config_dict.get('qk_layernorm', False), + hidden_dropout=0.0, + attention_dropout=0.0, + # Performance optimizations + masked_softmax_fusion=True, # Fused attention softmax + bias_dropout_fusion=True, # Fused bias + dropout + apply_rope_fusion=True, # Fused RoPE application + attention_softmax_in_fp32=True, # Numerical stability + attention_backend=AttnBackend.flash, # FlashAttention for speed + # Activation recomputation for memory efficiency + recompute_granularity=self.recompute_granularity, + recompute_modules=self.recompute_modules + if self.recompute_granularity == 'selective' else None, + recompute_method=recompute_method, + recompute_num_layers=recompute_num_layers, + # Critical: Set finalize_model_grads_func for DP gradient synchronization + # Uses custom wrapper that handles both DDP and PEFT/LoRA models + finalize_model_grads_func=finalize_model_grads_for_lora, + # MoE configuration + **moe_kwargs, + ) + + # Save transformer config for later use (e.g., DDP wrapping) + self._transformer_config = config + + # Get layer spec - enable moe_grouped_gemm for MoE models + moe_grouped_gemm = num_experts > 0 + try: + layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=mg_config_dict.get('num_experts'), + moe_grouped_gemm=moe_grouped_gemm, + qk_layernorm=mg_config_dict.get('qk_layernorm', False), + ) + except (ImportError, AttributeError): + raise RuntimeError("TransformerEngine is not installed or not compatible with 
this version of Megatron-Core.") + + # Create model + max_seq_length = getattr(hf_config, 'max_position_embeddings', 4096) + rotary_base = mg_config_dict.get('rotary_base', 10000) + + model = GPTModel( + config=config, + transformer_layer_spec=layer_spec, + vocab_size=padded_vocab_size, + max_sequence_length=max_seq_length, + pre_process=mpu.is_pipeline_first_stage(), + post_process=mpu.is_pipeline_last_stage(), + parallel_output=True, + share_embeddings_and_output_weights=getattr( + hf_config, 'tie_word_embeddings', False), + position_embedding_type='rope', + rotary_base=rotary_base, + ) + + return model + + def _pad_vocab_size(self, vocab_size: int) -> int: + """Pad vocab size for tensor parallelism.""" + divisor = self.tp_size * 128 + return ((vocab_size + divisor - 1) // divisor) * divisor + + def create_model( + self, + model_path: str, + load_weights: bool = True, + ) -> nn.Module: + """Create Megatron model from HuggingFace checkpoint. + + Args: + model_path: Path to HuggingFace model or model ID. + load_weights: Whether to load weights. + + Returns: + Megatron model. 
+ """ + from transformers import AutoConfig + + # Download model if needed + model_path = HubOperation.download_model(model_path) + self._model_path = model_path + + # Load HF config first (needed for initialization) + self._hf_config = AutoConfig.from_pretrained(model_path, + trust_remote_code=True) + + # Initialize Megatron parallel state with hf_config for proper args setup + self._initialize_megatron(self._hf_config) + + # Calculate padded vocab size + padded_vocab_size = self._pad_vocab_size(self._hf_config.vocab_size) + + # Create model + self._model = self._create_model_from_config(self._hf_config, + padded_vocab_size) + + # Load weights + if load_weights: + bridge_adapter = TwinkleBridgeAdapter( + hf_config=self._hf_config, + tp_size=self.tp_size, + pp_size=self.pp_size, + ep_size=self.ep_size, + etp_size=self.etp_size, + model_path=model_path, + padded_vocab_size=padded_vocab_size, + ) + bridge_adapter.load_weights(self._model, model_path) + self._bridge = bridge_adapter._get_bridge() + + # Synchronize all ranks after model creation and weight loading + # This is critical for Pipeline Parallel to ensure all ranks are ready + # before any collective communication operations + if dist.is_initialized(): + dist.barrier() + + return self._model + + @property + def hf_config(self): + """Get the HuggingFace config.""" + return self._hf_config + + @property + def bridge(self): + """Get the bridge instance.""" + return self._bridge + + def load_weights(self, model: nn.Module, model_path: str): + """Load weights into an existing model. + + Args: + model: Megatron model. + model_path: Path to HuggingFace checkpoint. 
+ """ + if self._bridge is None and self._hf_config is None: + raise ValueError('Must call create_model first') + + padded_vocab_size = self._pad_vocab_size(self._hf_config.vocab_size) + bridge_adapter = TwinkleBridgeAdapter( + hf_config=self._hf_config, + tp_size=self.tp_size, + pp_size=self.pp_size, + ep_size=self.ep_size, + model_path=model_path, + padded_vocab_size=padded_vocab_size, + ) + bridge_adapter.load_weights(model, model_path) + + def save_weights(self, + models: Union[nn.Module, List[nn.Module]], + output_dir: str, + is_peft_format: bool = False): + """Save weights in HuggingFace format. + + Args: + models: Megatron model(s). + output_dir: Output directory. + is_peft_format: Whether to save in PEFT format. + """ + if self._bridge is None: + raise ValueError('Must load weights first') + + if not isinstance(models, (list, tuple)): + models = [models] + + self._bridge.save_weights(models, + output_dir, + is_peft_format=is_peft_format) + + +# Legacy functions for backward compatibility +def create_megatron_args(*args, **kwargs) -> SimpleNamespace: + """Legacy function - use BridgeConfig instead.""" + return SimpleNamespace(**kwargs) + + +def set_megatron_args(args: SimpleNamespace) -> None: + """Legacy function - no longer needed with TwinkleGPTBridge.""" + pass + + +def restore_megatron_args() -> None: + """Legacy function - no longer needed with TwinkleGPTBridge.""" + pass + + +def mock_megatron_args(args: SimpleNamespace): + """Legacy function - no longer needed with TwinkleGPTBridge.""" + from contextlib import contextmanager + + @contextmanager + def noop(): + yield args + + return noop() + + +def load_hf_weights_to_megatron( + mg_model: nn.Module, + model_path: str, + hf_config: Any, + tp_size: int = 1, + pp_size: int = 1, + ep_size: int = 1, + padded_vocab_size: Optional[int] = None, +) -> None: + """Convenience function to load HF weights into Megatron model.""" + adapter = TwinkleBridgeAdapter( + hf_config=hf_config, + tp_size=tp_size, + 
pp_size=pp_size, + ep_size=ep_size, + model_path=model_path, + padded_vocab_size=padded_vocab_size, + ) + adapter.load_weights(mg_model, model_path) diff --git a/src/twinkle/megatron/model/initializer.py b/src/twinkle/megatron/model/initializer.py new file mode 100644 index 00000000..215a4eb5 --- /dev/null +++ b/src/twinkle/megatron/model/initializer.py @@ -0,0 +1,335 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Megatron model initialization from HuggingFace checkpoints.""" +from dataclasses import fields +from typing import Any, Dict, Optional, Type + +import torch +import torch.distributed as dist +import torch.nn as nn +from packaging import version + +# Direct imports - assume megatron is installed +import megatron.core +from megatron.core import parallel_state as mpu +from megatron.core.models.gpt import GPTModel +from megatron.core.transformer import TransformerConfig + +from ..utils import convert_hf_config + +mcore_013 = version.parse( + megatron.core.__version__) >= version.parse('0.13.0rc0') + + +def _get_transformer_config_fields() -> set: + """Get valid field names for TransformerConfig. + + Returns: + Set of valid field names. + """ + return {f.name for f in fields(TransformerConfig)} + + +class MegatronModelInitializer: + """Initialize Megatron-Core models from HuggingFace checkpoints. + + This class handles: + - Converting HuggingFace config to Megatron TransformerConfig + - Creating Megatron model architecture + - Loading HuggingFace weights into Megatron model + """ + def __init__( + self, + tp_size: int = 1, + pp_size: int = 1, + cp_size: int = 1, + ep_size: int = 1, + etp_size: Optional[int] = None, + vp_size: Optional[int] = None, + sequence_parallel: bool = False, + params_dtype: torch.dtype = torch.bfloat16, + use_cpu_initialization: bool = True, + ): + """Initialize MegatronModelInitializer. + + Args: + tp_size: Tensor parallel size. + pp_size: Pipeline parallel size. + cp_size: Context parallel size. 
+ ep_size: Expert parallel size. + etp_size: Expert tensor parallel size (defaults to tp_size). + vp_size: Virtual pipeline parallel size. + sequence_parallel: Enable sequence parallelism. + params_dtype: Parameter data type. + use_cpu_initialization: Initialize model on CPU first. + """ + self.tp_size = tp_size + self.pp_size = pp_size + self.cp_size = cp_size + self.ep_size = ep_size + self.etp_size = etp_size or tp_size + self.vp_size = vp_size + self.sequence_parallel = sequence_parallel + self.params_dtype = params_dtype + self.use_cpu_initialization = use_cpu_initialization + + # Cache valid TransformerConfig fields + self._valid_config_fields = _get_transformer_config_fields() + + def create_transformer_config( + self, + hf_config: Any, + **overrides, + ) -> 'TransformerConfig': + """Create Megatron TransformerConfig from HuggingFace config. + + Args: + hf_config: HuggingFace model config. + **overrides: Config overrides. + + Returns: + Megatron TransformerConfig. + """ + # Convert HuggingFace config to dict + mg_config_dict = convert_hf_config(hf_config) + + # Apply overrides + mg_config_dict.update(overrides) + + # Build config kwargs with only valid fields + config_kwargs = { + # Required fields + 'num_layers': mg_config_dict['num_layers'], + 'hidden_size': mg_config_dict['hidden_size'], + 'num_attention_heads': mg_config_dict['num_attention_heads'], + # Parallel settings + 'tensor_model_parallel_size': self.tp_size, + 'pipeline_model_parallel_size': self.pp_size, + 'context_parallel_size': self.cp_size, + 'expert_model_parallel_size': self.ep_size, + 'sequence_parallel': self.sequence_parallel, + 'params_dtype': self.params_dtype, + 'use_cpu_initialization': self.use_cpu_initialization, + } + + # Optional fields - only add if valid for this Megatron version + optional_fields = { + 'num_query_groups': + mg_config_dict.get('num_query_groups', + mg_config_dict['num_attention_heads']), + 'ffn_hidden_size': + mg_config_dict.get('ffn_hidden_size', + 4 * 
mg_config_dict['hidden_size']), + 'num_moe_experts': + mg_config_dict.get('num_experts'), + 'moe_router_topk': + mg_config_dict.get('moe_router_topk', 2) + if mg_config_dict.get('num_experts') else None, + 'layernorm_epsilon': + mg_config_dict.get('norm_epsilon', 1e-6), + 'add_qkv_bias': + mg_config_dict.get('add_qkv_bias', False), + 'add_bias_linear': + not mg_config_dict.get('disable_bias_linear', True), + 'gated_linear_unit': + mg_config_dict.get('swiglu', True), + 'qk_layernorm': + mg_config_dict.get('qk_layernorm', False), + 'normalization': + 'RMSNorm', + } + + # Add optional fields that are valid for this Megatron version + for key, value in optional_fields.items(): + if key in self._valid_config_fields and value is not None: + config_kwargs[key] = value + + # Store rotary settings for GPTModel (not TransformerConfig) + self._rotary_base = mg_config_dict.get('rotary_base', 10000) + self._rotary_percent = mg_config_dict.get('rotary_percent', 1.0) + self._position_embedding_type = mg_config_dict.get( + 'position_embedding_type', 'rope') + + # Create TransformerConfig + config = TransformerConfig(**config_kwargs) + + return config + + def create_gpt_model( + self, + hf_config: Any, + vocab_size: Optional[int] = None, + max_sequence_length: Optional[int] = None, + **config_overrides, + ) -> 'GPTModel': + """Create Megatron GPT model from HuggingFace config. + + Args: + hf_config: HuggingFace model config. + vocab_size: Override vocab size. + max_sequence_length: Override max sequence length. + **config_overrides: Config overrides. + + Returns: + Megatron GPTModel. + """ + # Create config (also sets self._rotary_base, etc.) 
+ config = self.create_transformer_config(hf_config, **config_overrides) + + # Get vocab size + if vocab_size is None: + vocab_size = hf_config.vocab_size + + # Pad vocab size for tensor parallelism + padded_vocab_size = self._pad_vocab_size(vocab_size) + + # Get max sequence length + if max_sequence_length is None: + max_sequence_length = getattr(hf_config, 'max_position_embeddings', + 4096) + + # Get tie_word_embeddings setting + tie_word_embeddings = getattr(hf_config, 'tie_word_embeddings', False) + + # Create model with rotary settings passed directly to GPTModel + model = GPTModel( + config=config, + transformer_layer_spec=self._get_layer_spec(config), + vocab_size=padded_vocab_size, + max_sequence_length=max_sequence_length, + pre_process=mpu.is_pipeline_first_stage(), + post_process=mpu.is_pipeline_last_stage(), + parallel_output=True, + share_embeddings_and_output_weights=tie_word_embeddings, + position_embedding_type=self._position_embedding_type, + rotary_percent=self._rotary_percent, + rotary_base=self._rotary_base, + ) + + return model + + def _pad_vocab_size(self, vocab_size: int) -> int: + """Pad vocab size for tensor parallelism. + + Args: + vocab_size: Original vocab size. + + Returns: + Padded vocab size. + """ + # Pad to multiple of tp_size * 128 for efficient parallelism + divisor = self.tp_size * 128 + return ((vocab_size + divisor - 1) // divisor) * divisor + + def _get_layer_spec(self, config: 'TransformerConfig'): + """Get transformer layer specification. + + Args: + config: Transformer config. + + Returns: + Layer specification (ModuleSpec or TransformerBlockSubmodules). 
+ """ + from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_with_transformer_engine_spec, + get_gpt_layer_local_spec, + ) + + # Determine if this is a MoE model + num_experts = getattr(config, 'num_moe_experts', None) + moe_grouped_gemm = getattr(config, 'moe_grouped_gemm', False) + qk_layernorm = getattr(config, 'qk_layernorm', False) + multi_latent_attention = getattr(config, 'multi_latent_attention', + False) + + # Try TE (TransformerEngine) layers first for better performance + try: + return get_gpt_layer_with_transformer_engine_spec( + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + qk_layernorm=qk_layernorm, + multi_latent_attention=multi_latent_attention, + ) + except (ImportError, AttributeError): + raise RuntimeError("TransformerEngine is not installed or not compatible with this version of Megatron-Core.") + + def load_from_hf( + self, + model: nn.Module, + hf_model_path: str, + hf_config: Any, + ) -> None: + """Load HuggingFace checkpoint into Megatron model. + + Args: + model: The Megatron model. + hf_model_path: Path to HuggingFace checkpoint or model ID. + hf_config: HuggingFace model config. 
+ """ + import os + + # Resolve model path if it's a model ID (not a local path) + if not os.path.isdir(hf_model_path): + from twinkle.hub import HubOperation + hf_model_path = HubOperation.download_model(hf_model_path) + + # Calculate padded vocab size + padded_vocab_size = self._pad_vocab_size(hf_config.vocab_size) + + # Use TwinkleBridgeAdapter + from .bridge import TwinkleBridgeAdapter + adapter = TwinkleBridgeAdapter( + hf_config=hf_config, + tp_size=self.tp_size, + pp_size=self.pp_size, + ep_size=self.ep_size, + model_path=hf_model_path, + padded_vocab_size=padded_vocab_size, + ) + adapter.load_weights(model, hf_model_path) + + +def initialize_megatron_model( + hf_model_path: str, + tp_size: int = 1, + pp_size: int = 1, + cp_size: int = 1, + ep_size: int = 1, + params_dtype: torch.dtype = torch.bfloat16, + load_weights: bool = True, +) -> nn.Module: + """Convenience function to initialize Megatron model from HuggingFace checkpoint. + + Args: + hf_model_path: Path to HuggingFace checkpoint. + tp_size: Tensor parallel size. + pp_size: Pipeline parallel size. + cp_size: Context parallel size. + ep_size: Expert parallel size. + params_dtype: Parameter data type. + load_weights: Whether to load weights. + + Returns: + Initialized Megatron model. + """ + from transformers import AutoConfig + + # Load HuggingFace config + hf_config = AutoConfig.from_pretrained(hf_model_path) + + # Create initializer + initializer = MegatronModelInitializer( + tp_size=tp_size, + pp_size=pp_size, + cp_size=cp_size, + ep_size=ep_size, + params_dtype=params_dtype, + ) + + # Create model + model = initializer.create_gpt_model(hf_config) + + # Load weights + if load_weights: + initializer.load_from_hf(model, hf_model_path, hf_config) + + return model diff --git a/src/twinkle/megatron/model/qwen3.py b/src/twinkle/megatron/model/qwen3.py new file mode 100644 index 00000000..0acd3f67 --- /dev/null +++ b/src/twinkle/megatron/model/qwen3.py @@ -0,0 +1,65 @@ +# Copyright (c) twinkle authors. 
class Qwen3ModelMeta:
    """Metadata for Qwen3-family models.

    Collects the HF architecture names, checkpoint key layout and
    Megatron ``TransformerConfig`` overrides shared by Qwen2/Qwen3
    dense and MoE variants.
    """

    # Supported architectures
    DENSE_ARCHITECTURES = [
        'Qwen3ForCausalLM', 'Qwen2ForCausalLM', 'Qwen2.5ForCausalLM'
    ]
    MOE_ARCHITECTURES = ['Qwen3MoeForCausalLM', 'Qwen2MoeForCausalLM']
    ALL_ARCHITECTURES = DENSE_ARCHITECTURES + MOE_ARCHITECTURES

    # HuggingFace checkpoint key prefixes
    HF_LAYERS_PREFIX = 'model.layers'
    HF_EMBED_KEY = 'model.embed_tokens.weight'
    HF_FINAL_LAYERNORM_KEY = 'model.norm.weight'
    HF_LM_HEAD_KEY = 'lm_head.weight'

    # Qwen3 specific settings
    DEFAULT_CONFIG = {
        'qk_layernorm': True,
        'swiglu': True,
        'disable_bias_linear': True,
        'rotary_interleaved': False,
    }

    # MoE specific settings (merged on top of DEFAULT_CONFIG)
    MOE_CONFIG = {
        'use_shared_expert_gate': True,
    }

    @classmethod
    def is_qwen3(cls, architecture: str) -> bool:
        """Check if architecture is a Qwen3-family model (dense or MoE)."""
        return architecture in cls.ALL_ARCHITECTURES

    @classmethod
    def is_qwen3_moe(cls, architecture: str) -> bool:
        """Check if architecture is a Qwen3-family MoE model."""
        return architecture in cls.MOE_ARCHITECTURES


def get_model_default_config(architecture: str) -> Dict[str, Any]:
    """Get default config overrides for a model architecture.

    Args:
        architecture: Model architecture name (HF ``architectures`` entry).

    Returns:
        A fresh dict of defaults for Megatron TransformerConfig; empty for
        unknown architectures. Always a new dict, so callers may mutate the
        result without corrupting the class-level defaults.
    """
    if Qwen3ModelMeta.is_qwen3_moe(architecture):
        # Dict-unpacking already produces a new dict here.
        return {**Qwen3ModelMeta.DEFAULT_CONFIG, **Qwen3ModelMeta.MOE_CONFIG}
    if Qwen3ModelMeta.is_qwen3(architecture):
        # Copy: returning DEFAULT_CONFIG itself would let a caller's mutation
        # leak into every subsequent call.
        return dict(Qwen3ModelMeta.DEFAULT_CONFIG)
    return {}
    def __init__(
        self,
        base_layer,
        adapter_name: str,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        init_lora_weights: bool = True,
        use_rslora: bool = False,
        use_dora: bool = False,
        lora_bias: bool = False,
        **kwargs,
    ):
        """Initialize LoraParallelLinear.

        Args:
            base_layer: The Megatron parallel linear layer to wrap.
            adapter_name: Name of the LoRA adapter.
            r: LoRA rank.
            lora_alpha: LoRA alpha scaling factor.
            lora_dropout: Dropout probability for LoRA layers.
            fan_in_fan_out: Whether the layer uses fan-in/fan-out convention.
            init_lora_weights: Whether to initialize LoRA weights.
            use_rslora: Use rank-stabilized LoRA scaling.
            use_dora: Use DoRA (not supported yet).
            lora_bias: Whether to add bias to LoRA layers.

        Raises:
            ValueError: If ``use_dora`` is True.
        """
        config = base_layer.config
        # MegatronModule needs the wrapped layer's TransformerConfig.
        super().__init__(config=config)
        # NOTE(review): warnings emitted while PEFT inspects the (to it,
        # unfamiliar) Megatron layer are suppressed here — assumed benign.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            LoraLayer.__init__(self, base_layer=base_layer)

        if use_dora:
            raise ValueError(
                f'{self.__class__.__name__} does not support DoRA yet, please set it to False'
            )

        # Row-parallel base layers shard the *input* dimension, so the LoRA A
        # matrix must be the parallel one (see update_layer).
        self.is_parallel_a = isinstance(
            base_layer, (TERowParallelLinear, TERowParallelGroupedLinear))
        self.is_grouped = isinstance(base_layer, TEGroupedLinear)
        self.fan_in_fan_out = fan_in_fan_out
        self._active_adapter = adapter_name
        self.is_expert = getattr(base_layer, 'is_expert', False)
        self.sequence_parallel = getattr(base_layer, 'sequence_parallel',
                                         False)

        # Expert layers live in the expert tensor-parallel group, whose size
        # may differ from the regular TP group's.
        if self.is_expert:
            self.tp_size = get_expert_tensor_parallel_world_size()
        else:
            self.tp_size = get_tensor_model_parallel_world_size()

        # Build the adapter's A/B matrices for this layer.
        self.update_layer(
            adapter_name,
            r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            init_lora_weights=init_lora_weights,
            use_rslora=use_rslora,
            lora_bias=lora_bias,
        )

        self.is_target_conv_1d_layer = False
+ """ + if r <= 0: + raise ValueError( + f'`r` should be a positive integer value but the value passed is {r}' + ) + + self.r[adapter_name] = r + self.lora_alpha[adapter_name] = lora_alpha + + if lora_dropout > 0.0: + lora_dropout_layer = nn.Dropout(p=lora_dropout) + else: + lora_dropout_layer = nn.Identity() + + self.lora_dropout[adapter_name] = lora_dropout_layer + + # Build LoRA A and B matrices with proper parallelism + kwargs = { + 'skip_bias_add': False, + 'init_method': self.config.init_method, + 'config': self.config, + 'is_expert': self.is_expert, + } + if mcore_013: + kwargs['tp_group'] = self.base_layer.tp_group + + if isinstance(self.base_layer, TopKRouter): + # Router layer - no parallelism needed + router_shape = self.base_layer.weight.shape + lora_a = TELinear( + input_size=router_shape[1], + output_size=r, + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs, + ) + lora_b = TELinear( + input_size=r, + output_size=router_shape[0], + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs, + ) + elif self.is_parallel_a: + # Row parallel layer - LoRA A is parallel, LoRA B is not + in_features = self.in_features * self.tp_size + if self.is_grouped: + lora_a = TERowParallelGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=in_features, + output_size=r, + bias=False, + **kwargs, + ) + lora_b = TEGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=r, + output_size=self.out_features, + bias=lora_bias, + parallel_mode=None, + **kwargs, + ) + else: + lora_a = TERowParallelLinear( + input_size=in_features, + output_size=r, + bias=False, + input_is_parallel=True, + **kwargs, + ) + lora_b = TELinear( + input_size=r, + output_size=self.out_features, + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs, + ) + lora_a.parallel_mode = self.base_layer.parallel_mode + else: + # Column parallel layer - LoRA A is not parallel, LoRA B is 
parallel + out_features = self.out_features * self.tp_size + if self.is_grouped: + lora_a = TEGroupedLinear(num_gemms=self.base_layer.num_gemms, + input_size=self.in_features, + output_size=r, + bias=lora_bias, + parallel_mode=None, + **kwargs) + lora_b = TEColumnParallelGroupedLinear( + num_gemms=self.base_layer.num_gemms, + input_size=r, + output_size=out_features, + bias=lora_bias, + **kwargs, + ) + else: + lora_a = TELinear(input_size=self.in_features, + output_size=r, + bias=lora_bias, + parallel_mode=None, + skip_weight_param_allocation=False, + **kwargs) + lora_b = TEColumnParallelLinear( + input_size=r, + output_size=out_features, + bias=lora_bias, + gather_output=False, + **kwargs, + ) + lora_b.parallel_mode = self.base_layer.parallel_mode + + # Disable overlap for LoRA layers + for lora in [lora_a, lora_b]: + if isinstance( + lora, + (TERowParallelLinear, + TEColumnParallelLinear)) and lora.parallel_mode is None: + lora.ub_overlap_rs_fprop = False + lora.ub_overlap_ag_dgrad = False + lora.ub_overlap_ag_fprop = False + lora.ub_overlap_rs_dgrad = False + + lora_a.sequence_parallel = False + lora_b.sequence_parallel = False + + self.lora_A[adapter_name] = lora_a + self.lora_B[adapter_name] = lora_b + + if hasattr(self, 'lora_bias'): + self.lora_bias[adapter_name] = lora_bias + + if use_rslora: + self.scaling[adapter_name] = lora_alpha / (r**0.5) + else: + self.scaling[adapter_name] = lora_alpha / r + + if init_lora_weights: + self.reset_lora_parameters(adapter_name, init_lora_weights) + + self._move_adapter_to_device_of_base_layer(adapter_name) + self.set_adapter(self.active_adapters) + + def reset_lora_parameters(self, adapter_name: str, + init_lora_weights: bool): + """Reset LoRA parameters to initial values. + + Args: + adapter_name: Name of the adapter. + init_lora_weights: Initialization method. 
+ """ + if init_lora_weights is False: + return + + if adapter_name in self.lora_A.keys(): + lora_a = self.lora_A[adapter_name] + lora_b = self.lora_B[adapter_name] + + if isinstance(lora_a, TEGroupedLinear): + weights_a = [ + getattr(lora_a, f'weight{i}') + for i in range(lora_a.num_gemms) + ] + else: + weights_a = [lora_a.weight] + + if isinstance(lora_b, TEGroupedLinear): + weights_b = [ + getattr(lora_b, f'weight{i}') + for i in range(lora_b.num_gemms) + ] + else: + weights_b = [lora_b.weight] + + for weight_a in weights_a: + if init_lora_weights is True: + nn.init.kaiming_uniform_(weight_a, a=math.sqrt(5)) + elif init_lora_weights.lower() == 'gaussian': + nn.init.normal_(weight_a, std=1 / self.r[adapter_name]) + else: + raise ValueError( + f'Unknown initialization {init_lora_weights=}') + + for weight_b in weights_b: + nn.init.zeros_(weight_b) + + if adapter_name in self.lora_embedding_A.keys(): + nn.init.zeros_(self.lora_embedding_A[adapter_name]) + nn.init.normal_(self.lora_embedding_B[adapter_name]) + + @contextmanager + def _patch_router_gating(self): + """Context manager to patch router gating with LoRA.""" + origin_gating = self.base_layer.__class__.gating + + def gating(_self, x): + result = origin_gating(_self, x) + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + x = x.to(result.dtype) + + lora_result = F.linear(dropout(x), + lora_A.weight.to(result.dtype)) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = F.linear(lora_result, + lora_B.weight.to(result.dtype)) + if isinstance(lora_result, tuple): + lora_result = lora_result[0] + lora_result = lora_result * scaling + + result = result + lora_result + return result + + self.base_layer.__class__.gating = gating + try: + yield + finally: + 
    def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any):
        """Forward pass with LoRA adaptation.

        Args:
            x: Input tensor.
            *args: Additional positional arguments.
            **kwargs: Additional keyword arguments.

        Returns:
            Tuple of (output tensor, bias).

        Raises:
            ValueError: If the wrapped base layer type is unsupported.
        """
        previous_dtype = x.dtype
        if self.disable_adapters and self.merged:
            self.unmerge()

        if isinstance(self.base_layer, TELayerNormColumnParallelLinear):
            if self.disable_adapters or self.merged:
                self.base_layer.return_layernorm_output = False
                result, bias = self.base_layer(x, *args, **kwargs)
            else:
                # The LoRA branch must see the post-layernorm input, so ask
                # the fused layer to return it and rebind x to it.
                self.base_layer.return_layernorm_output = True
                (result, x), bias = self.base_layer(x, *args, **kwargs)
        elif isinstance(self.base_layer, (TELinear, TEGroupedLinear)):
            result, bias = self.base_layer(x, *args, **kwargs)
        elif isinstance(self.base_layer, TopKRouter):
            # Router LoRA is injected by patching gating() for the call.
            with self._patch_router_gating():
                result, bias = self.base_layer(x, *args, **kwargs)
        else:
            raise ValueError(
                f'Unsupported base layer type: {type(self.base_layer)}')

        if not isinstance(
                self.base_layer,
                TopKRouter) and not self.disable_adapters and not self.merged:
            # With sequence parallelism, a column-parallel base layer receives
            # a sequence-sharded input; the (non-SP) LoRA branch needs it full.
            if self.sequence_parallel and self.base_layer.parallel_mode == 'column':
                x = gather_from_sequence_parallel_region(x)

            for active_adapter in self.active_adapters:
                if active_adapter not in self.lora_A.keys():
                    continue

                lora_A = self.lora_A[active_adapter]
                lora_B = self.lora_B[active_adapter]
                dropout = self.lora_dropout[active_adapter]
                scaling = self.scaling[active_adapter]
                dtype = lora_A.weight0.dtype if isinstance(
                    lora_A, TEGroupedLinear) else lora_A.weight.dtype
                x = x.to(dtype)

                # Grouped linears need the extra routing args forwarded.
                lora_result = lora_A(
                    dropout(x), *args, **kwargs) if isinstance(
                        lora_A, TEGroupedLinear) else lora_A(dropout(x))
                if isinstance(lora_result, tuple):
                    lora_result = lora_result[0]

                lora_result = lora_B(
                    lora_result, *args, **kwargs) if isinstance(
                        lora_B, TEGroupedLinear) else lora_B(lora_result)
                if isinstance(lora_result, tuple):
                    lora_result = lora_result[0]

                lora_result = lora_result * scaling

                # Mirror the base layer's SP output layout for row-parallel.
                if self.sequence_parallel and self.base_layer.parallel_mode == 'row':
                    lora_result = scatter_to_sequence_parallel_region(
                        lora_result)

                result = result + lora_result

        result = result.to(previous_dtype)
        return result, bias
    def sharded_state_dict(
        self,
        prefix: str = '',
        sharded_offsets: Tuple[Tuple[int, int, int]] = (),
        metadata: Optional[dict] = None,
    ) -> ShardedStateDict:
        """Get sharded state dict for distributed checkpointing.

        Args:
            prefix: Key prefix.
            sharded_offsets: Sharding offsets.
            metadata: Additional metadata.

        Returns:
            Sharded state dictionary.
        """
        from ..utils import tuners_sharded_state_dict

        sharded_state_dict = tuners_sharded_state_dict(self, prefix,
                                                       sharded_offsets,
                                                       metadata)

        # fc1 of a gated MLP interleaves gate/up projections; the swiglu
        # factory makes those halves resharding-safe in the checkpoint.
        if prefix.endswith('linear_fc1.'):
            if isinstance(self.base_layer,
                          TEGroupedLinear) and self.config.gated_linear_unit:
                # One sharded entry per local expert, offset into the global
                # expert dimension.
                num_global_experts = (
                    parallel_state.get_expert_model_parallel_world_size() *
                    self.base_layer.num_gemms)
                local_expert_indices_offset = (
                    parallel_state.get_expert_model_parallel_rank() *
                    self.base_layer.num_gemms)
                ep_axis = len(sharded_offsets)
                for i in range(self.base_layer.num_gemms):
                    new_sharded_offsets = (
                        *sharded_offsets,
                        (ep_axis, local_expert_indices_offset + i,
                         num_global_experts),
                    )
                    for k in (f'{prefix}base_layer.weight{i}',
                              f'{prefix}base_layer.bias{i}'):
                        if k in sharded_state_dict:
                            sharded_state_dict[
                                k] = apply_swiglu_sharded_factory(
                                    sharded_state_dict[k], new_sharded_offsets)
            else:
                for k, v in sharded_state_dict.items():
                    if k in [
                            f'{prefix}base_layer.weight',
                            f'{prefix}base_layer.bias'
                    ]:
                        sharded_state_dict[k] = apply_swiglu_sharded_factory(
                            sharded_state_dict[k], sharded_offsets)
        return sharded_state_dict
    def merge(self,
              safe_merge: bool = False,
              adapter_names: Optional[list[str]] = None) -> None:
        """Merge the active adapter weights into the base weights.

        Args:
            safe_merge: If True, merge into cloned weights first and check
                them for non-finite values before committing.
            adapter_names: List of adapter names to merge.

        Raises:
            ValueError: If ``safe_merge`` detects non-finite merged weights.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            return

        base_layer = self.get_base_layer()
        origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device

        # CPU-resident weights are temporarily moved to GPU for the matmuls.
        if origin_device.type == 'cpu':
            device = torch.cuda.current_device() if torch.cuda.is_available(
            ) else 'cpu'
            self.to(device=device)

        for active_adapter in adapter_names:
            if active_adapter in self.lora_A.keys():
                if self.is_grouped:
                    orig_weights = [
                        getattr(base_layer, f'weight{i}')
                        for i in range(base_layer.num_gemms)
                    ]
                else:
                    orig_weights = [base_layer.weight]

                if safe_merge:
                    # Work on clones; only commit after the finiteness check.
                    orig_weights = [
                        weight.data.clone() for weight in orig_weights
                    ]
                    delta_weights = self.get_delta_weights(active_adapter)
                    for orig_weight, delta_weight in zip(
                            orig_weights, delta_weights):
                        orig_weight += delta_weight
                    if not all(
                            torch.isfinite(orig_weights[i]).all()
                            for i in range(len(orig_weights))):
                        raise ValueError(
                            f'NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken'
                        )
                    if self.is_grouped:
                        for i in range(base_layer.num_gemms):
                            weight = getattr(base_layer, f'weight{i}')
                            weight.data = orig_weights[i]
                    else:
                        base_layer.weight.data = orig_weights[0]
                else:
                    # Unsafe path merges in place.
                    delta_weights = self.get_delta_weights(active_adapter)
                    for orig_weight, delta_weight in zip(
                            orig_weights, delta_weights):
                        orig_weight.data += delta_weight

                self.merged_adapters.append(active_adapter)

        if origin_device.type == 'cpu':
            self.to(device=origin_device)
    def unmerge(self) -> None:
        """Unmerge all merged adapter weights from the base weights."""
        if not self.merged:
            return

        base_layer = self.get_base_layer()
        origin_device = base_layer.weight0.device if self.is_grouped else base_layer.weight.device

        # CPU-resident weights are temporarily moved to GPU for the matmuls.
        if origin_device.type == 'cpu':
            device = torch.cuda.current_device() if torch.cuda.is_available(
            ) else 'cpu'
            self.to(device=device)

        for active_adapter in self.merged_adapters:
            if active_adapter in self.lora_A.keys():
                if self.is_grouped:
                    orig_weights = [
                        getattr(base_layer, f'weight{i}')
                        for i in range(base_layer.num_gemms)
                    ]
                else:
                    orig_weights = [base_layer.weight]

                # Subtract the exact delta that merge() added.
                delta_weights = self.get_delta_weights(active_adapter)
                for orig_weight, delta_weight in zip(orig_weights,
                                                     delta_weights):
                    orig_weight.data -= delta_weight

        self.merged_adapters = []

        if origin_device.type == 'cpu':
            self.to(device=origin_device)
+ """ + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + linear_cls = (TELayerNormColumnParallelLinear, TELinear, TEGroupedLinear, + TopKRouter) + if isinstance(target_base_layer, linear_cls): + new_module = LoraParallelLinear(base_layer=target, + adapter_name=adapter_name, + **kwargs) + + return new_module + + +# Register dispatch function with PEFT +model.dispatch_megatron = dispatch_megatron diff --git a/src/twinkle/megatron/utils.py b/src/twinkle/megatron/utils.py new file mode 100644 index 00000000..7f465230 --- /dev/null +++ b/src/twinkle/megatron/utils.py @@ -0,0 +1,1036 @@ +# Copyright (c) twinkle authors. All rights reserved. +"""Utility functions for Megatron-Core integration.""" +from contextlib import contextmanager +from typing import Any, Dict, List, Optional, Tuple +import threading + +import torch +import torch.nn as nn +import torch.distributed as dist + +import megatron.core +from megatron.core import parallel_state as mpu +from megatron.core.extensions.transformer_engine import ( + TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear +) +from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding +from megatron.core.transformer.moe.router import TopKRouter +from megatron.core.transformer.transformer_block import get_num_layers_to_build +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset +from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default +from packaging import version +from peft import LoraConfig, get_peft_model + +mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0') + + +# Config mapping from HuggingFace to Megatron +CONFIG_MAPPING = { + 'num_layers': ['num_hidden_layers'], + 'hidden_size': ['hidden_size'], + 'mlp_ffn_hidden_size': ['intermediate_size_mlp'], + 'ffn_hidden_size': 
['intermediate_size'], + 'num_attention_heads': ['num_attention_heads'], + 'num_query_groups': ['num_key_value_heads'], + 'max_position_embeddings': ['max_position_embeddings'], + 'norm_epsilon': ['rms_norm_eps'], + 'rotary_base': ['rope_theta'], + 'padded_vocab_size': ['vocab_size'], + 'attention_dropout': ['attention_dropout'], + 'untie_embeddings_and_output_weights': ['tie_word_embeddings'], + 'swiglu': ['hidden_act'], + 'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'], + 'disable_bias_linear': ['mlp_bias'], + 'kv_channels': ['head_dim', 'v_head_dim'], + 'architectures': ['architectures'], + # moe + 'moe_ffn_hidden_size': ['moe_intermediate_size'], + 'moe_shared_expert_intermediate_size': ['shared_expert_intermediate_size'], + 'moe_router_topk': ['num_experts_per_tok', 'moe_topk', 'moe_k'], + 'moe_router_num_groups': ['n_group'], + 'moe_router_group_topk': ['topk_group'], + 'num_experts': ['num_experts', 'n_routed_experts', 'moe_num_experts', 'num_local_experts'], + 'moe_router_pre_softmax': ['norm_topk_prob'], + # deepseek + 'q_lora_rank': ['q_lora_rank'], + 'kv_lora_rank': ['kv_lora_rank'], + 'moe_router_score_function': ['scoring_func'], + 'moe_router_bias_update_rate': ['aux_loss_alpha'], + 'qk_head_dim': ['qk_nope_head_dim'], + 'qk_pos_emb_head_dim': ['qk_rope_head_dim'], + 'moe_router_topk_scaling_factor': ['routed_scaling_factor'], + 'qk_layernorm': ['use_qk_norm'], + # other + 'original_max_position_embeddings': ['original_max_position_embeddings'], + 'partial_rotary_factor': ['partial_rotary_factor'], + 'first_k_dense_replace': ['first_k_dense_replace', 'moe_layer_start_index'], + 'n_shared_experts': ['n_shared_experts', 'num_shared_expert', 'moe_num_shared_experts'], + 'window_size': ['sliding_window'], + 'layer_types': ['layer_types'], +} + + +class TenantProcessGroupManager: + """Manager for multi-tenant process groups. 
+ + In a multi-tenant scenario, multiple users may share the same base model in a single + process, each with their own LoRA adapters. To avoid communication interference between + tenants, we need to maintain separate process groups for each tenant. + + This class provides: + 1. Per-tenant process group isolation + 2. Context managers to temporarily switch active process groups + 3. Patching of Megatron's communication operations to use tenant-specific groups + + Example: + # Create tenant-specific groups + manager = TenantProcessGroupManager() + manager.register_tenant('user_1', tp_ranks=[0, 1], dp_ranks=[0, 2]) + manager.register_tenant('user_2', tp_ranks=[2, 3], dp_ranks=[1, 3]) + + # Use tenant context for operations + with manager.tenant_context('user_1'): + # All Megatron communications will use user_1's process groups + model.forward(...) + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls): + """Singleton pattern for global access.""" + if cls._instance is None: + with cls._lock: + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance._initialized = False + return cls._instance + + def __init__(self): + if self._initialized: + return + self._initialized = True + + # Tenant ID -> Process Groups mapping + self._tenant_groups: Dict[str, Dict[str, dist.ProcessGroup]] = {} + # Current active tenant (thread-local) + self._active_tenant = threading.local() + # Original Megatron parallel state functions (for patching) + self._original_functions = {} + # Whether patching is active + self._patched = False + + def register_tenant( + self, + tenant_id: str, + tp_ranks: Optional[List[int]] = None, + pp_ranks: Optional[List[int]] = None, + dp_ranks: Optional[List[int]] = None, + ep_ranks: Optional[List[int]] = None, + cp_ranks: Optional[List[int]] = None, + ) -> None: + """Register a tenant with specific process group ranks. + + Args: + tenant_id: Unique identifier for the tenant. 
+ tp_ranks: Ranks for tensor parallel group. + pp_ranks: Ranks for pipeline parallel group. + dp_ranks: Ranks for data parallel group. + ep_ranks: Ranks for expert parallel group. + cp_ranks: Ranks for context parallel group. + """ + if tenant_id in self._tenant_groups: + return # Already registered + + groups = {} + + # Create process groups for each parallelism dimension + if tp_ranks and len(tp_ranks) > 1: + groups['tp'] = dist.new_group(tp_ranks) + if pp_ranks and len(pp_ranks) > 1: + groups['pp'] = dist.new_group(pp_ranks) + if dp_ranks and len(dp_ranks) > 1: + groups['dp'] = dist.new_group(dp_ranks) + if ep_ranks and len(ep_ranks) > 1: + groups['ep'] = dist.new_group(ep_ranks) + if cp_ranks and len(cp_ranks) > 1: + groups['cp'] = dist.new_group(cp_ranks) + + self._tenant_groups[tenant_id] = groups + + def unregister_tenant(self, tenant_id: str) -> None: + """Unregister a tenant and destroy its process groups. + + Args: + tenant_id: Tenant to unregister. + """ + if tenant_id in self._tenant_groups: + groups = self._tenant_groups.pop(tenant_id) + for group in groups.values(): + dist.destroy_process_group(group) + + def get_tenant_group(self, tenant_id: str, group_type: str) -> Optional[dist.ProcessGroup]: + """Get process group for a tenant. + + Args: + tenant_id: Tenant identifier. + group_type: Type of group ('tp', 'pp', 'dp', 'ep', 'cp'). + + Returns: + Process group or None if not found. + """ + if tenant_id in self._tenant_groups: + return self._tenant_groups[tenant_id].get(group_type) + return None + + @property + def active_tenant(self) -> Optional[str]: + """Get the currently active tenant ID.""" + return getattr(self._active_tenant, 'id', None) + + @contextmanager + def tenant_context(self, tenant_id: str): + """Context manager to set active tenant for communications. + + All Megatron communication operations within this context will use + the tenant-specific process groups. 
    def _patch_megatron_parallel_state(self) -> None:
        """Patch Megatron's parallel_state to use tenant-specific groups.

        This patches the following functions for full TP/PP/DP/EP/CP support:
        - get_tensor_model_parallel_group / get_tensor_model_parallel_world_size / get_tensor_model_parallel_rank
        - get_data_parallel_group / get_data_parallel_world_size / get_data_parallel_rank
        - get_pipeline_model_parallel_group / get_pipeline_model_parallel_world_size / get_pipeline_model_parallel_rank
        - get_expert_model_parallel_group / get_expert_model_parallel_world_size / get_expert_model_parallel_rank
        - get_context_parallel_group / get_context_parallel_world_size / get_context_parallel_rank

        NOTE(review): only the ``mpu`` module attributes are rebound; call
        sites that did ``from megatron.core.parallel_state import get_...``
        keep the original function — verify all relevant callers go through
        the module.
        """
        if self._patched:
            return

        # Save original functions so patched wrappers can fall back to them
        # (and unpatch can restore them).
        self._original_functions = {
            # TP functions
            'get_tensor_model_parallel_group': mpu.get_tensor_model_parallel_group,
            'get_tensor_model_parallel_world_size': mpu.get_tensor_model_parallel_world_size,
            'get_tensor_model_parallel_rank': mpu.get_tensor_model_parallel_rank,
            # DP functions
            'get_data_parallel_group': mpu.get_data_parallel_group,
            'get_data_parallel_world_size': mpu.get_data_parallel_world_size,
            'get_data_parallel_rank': mpu.get_data_parallel_rank,
            # PP functions
            'get_pipeline_model_parallel_group': mpu.get_pipeline_model_parallel_group,
            'get_pipeline_model_parallel_world_size': mpu.get_pipeline_model_parallel_world_size,
            'get_pipeline_model_parallel_rank': mpu.get_pipeline_model_parallel_rank,
            'is_pipeline_first_stage': mpu.is_pipeline_first_stage,
            'is_pipeline_last_stage': mpu.is_pipeline_last_stage,
            # EP functions
            'get_expert_model_parallel_group': mpu.get_expert_model_parallel_group,
            'get_expert_model_parallel_world_size': mpu.get_expert_model_parallel_world_size,
            'get_expert_model_parallel_rank': mpu.get_expert_model_parallel_rank,
            # CP functions
            'get_context_parallel_group': mpu.get_context_parallel_group,
            'get_context_parallel_world_size': mpu.get_context_parallel_world_size,
            'get_context_parallel_rank': mpu.get_context_parallel_rank,
        }

        # Captured by the closures below; active_tenant is thread-local, so
        # each thread resolves its own tenant at call time.
        manager = self

        def _make_group_getter(group_type: str, original_func_name: str):
            """Create patched group getter function."""
            def patched_func(*args, **kwargs):
                tenant = manager.active_tenant
                if tenant and tenant in manager._tenant_groups:
                    group = manager.get_tenant_group(tenant, group_type)
                    if group is not None:
                        return group
                # No tenant group registered for this dim: use Megatron's own.
                return manager._original_functions[original_func_name](*args, **kwargs)
            return patched_func

        def _make_world_size_getter(group_type: str, original_func_name: str):
            """Create patched world_size getter function."""
            def patched_func(*args, **kwargs):
                tenant = manager.active_tenant
                if tenant and tenant in manager._tenant_groups:
                    group = manager.get_tenant_group(tenant, group_type)
                    if group is not None:
                        return dist.get_world_size(group)
                return manager._original_functions[original_func_name](*args, **kwargs)
            return patched_func

        def _make_rank_getter(group_type: str, original_func_name: str):
            """Create patched rank getter function."""
            def patched_func(*args, **kwargs):
                tenant = manager.active_tenant
                if tenant and tenant in manager._tenant_groups:
                    group = manager.get_tenant_group(tenant, group_type)
                    if group is not None:
                        return dist.get_rank(group)
                return manager._original_functions[original_func_name](*args, **kwargs)
            return patched_func

        # Apply patches for TP
        mpu.get_tensor_model_parallel_group = _make_group_getter('tp', 'get_tensor_model_parallel_group')
        mpu.get_tensor_model_parallel_world_size = _make_world_size_getter('tp', 'get_tensor_model_parallel_world_size')
        mpu.get_tensor_model_parallel_rank = _make_rank_getter('tp', 'get_tensor_model_parallel_rank')

        # Apply patches for DP
        mpu.get_data_parallel_group = _make_group_getter('dp', 'get_data_parallel_group')
        mpu.get_data_parallel_world_size = _make_world_size_getter('dp', 'get_data_parallel_world_size')
        mpu.get_data_parallel_rank = _make_rank_getter('dp', 'get_data_parallel_rank')

        # Apply patches for PP
        mpu.get_pipeline_model_parallel_group = _make_group_getter('pp', 'get_pipeline_model_parallel_group')
        mpu.get_pipeline_model_parallel_world_size = _make_world_size_getter('pp', 'get_pipeline_model_parallel_world_size')
        mpu.get_pipeline_model_parallel_rank = _make_rank_getter('pp', 'get_pipeline_model_parallel_rank')

        # Patch is_pipeline_first/last_stage: stage position is derived from
        # the rank within the tenant's PP group.
        def patched_is_pipeline_first_stage(*args, **kwargs):
            tenant = manager.active_tenant
            if tenant and tenant in manager._tenant_groups:
                group = manager.get_tenant_group(tenant, 'pp')
                if group is not None:
                    return dist.get_rank(group) == 0
            return manager._original_functions['is_pipeline_first_stage'](*args, **kwargs)

        def patched_is_pipeline_last_stage(*args, **kwargs):
            tenant = manager.active_tenant
            if tenant and tenant in manager._tenant_groups:
                group = manager.get_tenant_group(tenant, 'pp')
                if group is not None:
                    return dist.get_rank(group) == dist.get_world_size(group) - 1
            return manager._original_functions['is_pipeline_last_stage'](*args, **kwargs)

        mpu.is_pipeline_first_stage = patched_is_pipeline_first_stage
        mpu.is_pipeline_last_stage = patched_is_pipeline_last_stage

        # Apply patches for EP
        mpu.get_expert_model_parallel_group = _make_group_getter('ep', 'get_expert_model_parallel_group')
        mpu.get_expert_model_parallel_world_size = _make_world_size_getter('ep', 'get_expert_model_parallel_world_size')
        mpu.get_expert_model_parallel_rank = _make_rank_getter('ep', 'get_expert_model_parallel_rank')

        # Apply patches for CP
        mpu.get_context_parallel_group = _make_group_getter('cp', 'get_context_parallel_group')
        mpu.get_context_parallel_world_size = _make_world_size_getter('cp', 'get_context_parallel_world_size')
        mpu.get_context_parallel_rank = _make_rank_getter('cp', 'get_context_parallel_rank')

        self._patched = True
+ for name, func in self._original_functions.items(): + setattr(mpu, name, func) + + self._patched = False + self._original_functions = {} + + def _patch_tensor_parallel_comms(self) -> None: + """Patch tensor parallel communication operations. + + This patches critical TP communication functions in megatron.core.tensor_parallel: + - mappings.copy_to_tensor_model_parallel_region + - mappings.reduce_from_tensor_model_parallel_region + - mappings.scatter_to_tensor_model_parallel_region + - mappings.gather_from_tensor_model_parallel_region + """ + try: + from megatron.core.tensor_parallel import mappings + except ImportError: + return + + if hasattr(self, '_tp_comms_patched') and self._tp_comms_patched: + return + + # Save original functions + self._original_tp_functions = {} + + # The mappings module uses get_tensor_model_parallel_group() internally, + # which we've already patched. No additional patches needed here. + # The patched group getters will be used automatically. + + self._tp_comms_patched = True + + def _patch_expert_parallel_comms(self) -> None: + """Patch expert parallel communication operations for MoE models. + + For MoE models, expert parallel communications use: + - get_expert_model_parallel_group + - get_expert_tensor_parallel_group (if using expert tensor parallelism) + + Since we've patched the group getters, the communications will + automatically use tenant-specific groups. + """ + # Expert parallel communications use the patched group getters + # No additional patches needed + pass + + def _patch_context_parallel_comms(self) -> None: + """Patch context parallel communication operations. + + Context parallelism communications include: + - Ring attention communications + - CP all-to-all operations + + These use get_context_parallel_group() which we've patched. 
+ """ + # CP communications use the patched group getters + # No additional patches needed + pass + + +# Global instance for easy access +_tenant_manager: Optional[TenantProcessGroupManager] = None + + +def get_tenant_manager() -> TenantProcessGroupManager: + """Get the global tenant process group manager. + + + Returns: + The singleton TenantProcessGroupManager instance. + """ + global _tenant_manager + if _tenant_manager is None: + _tenant_manager = TenantProcessGroupManager() + return _tenant_manager + +def find_layers(model: nn.Module, cond_fn) -> List[str]: + """Find all layers in model matching condition function. + + + + Args: + model: The model to search. + cond_fn: Callable(name, module) -> bool. + + Returns: + List of matching layer names. + """ + result = [] + for name, module in model.named_modules(): + if cond_fn(name, module): + result.append(name) + return result + + +def find_all_linears(model: nn.Module) -> List[str]: + """Find all linear layers suitable for LoRA in a Megatron model. + + + + Args: + model: The Megatron model. + + Returns: + List of layer names suitable for LoRA. + """ + def _cond(name: str, module: nn.Module) -> bool: + if name == 'output_layer': + return False + if isinstance(module, (TELinear, TELayerNormColumnParallelLinear, TEGroupedLinear, nn.Linear)): + return True + return False + + return find_layers(model, _cond) + + +def find_router(model: nn.Module) -> List[str]: + """Find all MoE router layers in a Megatron model. + + + + Args: + model: The Megatron model. + + Returns: + List of router layer names. + """ + return find_layers(model, lambda name, module: isinstance(module, TopKRouter)) + + +def find_embedding(model: nn.Module) -> List[str]: + """Find all embedding layers in a Megatron model. + + + + Args: + model: The Megatron model. + + Returns: + List of embedding layer names. 
+ """ + return find_layers(model, lambda name, module: isinstance(module, LanguageModelEmbedding)) + + +def get_target_modules(model: nn.Module, target_modules: List[str]) -> List[str]: + """Expand target module specifications to actual module names. + + + + Args: + model: The Megatron model. + target_modules: List of target module specs, may include 'all-linear', etc. + + Returns: + Expanded list of target module names. + """ + result = target_modules.copy() + if 'all-linear' in result: + result.remove('all-linear') + result += find_all_linears(model) + if 'all-embedding' in result: + result.remove('all-embedding') + result += find_embedding(model) + if 'all-router' in result: + result.remove('all-router') + result += find_router(model) + return list(set(result)) + + +def set_linear_is_expert(model: nn.Module): + """Mark expert linear layers in MoE models. + + + + Args: + model: The Megatron model. + """ + for name, module in model.named_modules(): + if '.local_experts.' in name and isinstance( + module, (TELinear, TELayerNormColumnParallelLinear) + ): + module.is_expert = True + elif isinstance(module, TEGroupedLinear): + module.is_expert = True + + +def deep_getattr(obj: Any, attr: str, default: Any = None) -> Any: + """Get nested attribute using dot notation. + + Args: + obj: The object. + attr: Dot-separated attribute path. + default: Default value if attribute not found. + + Returns: + The attribute value or default. + """ + try: + for a in attr.split('.'): + obj = getattr(obj, a) + return obj + except AttributeError: + return default + + +# ============================================================================= + +# ============================================================================= +def _convert_hf_config(config, _internal_call: bool = False) -> Dict[str, Any]: + """Convert HuggingFace config to Megatron config dict. + + + + Args: + config: HuggingFace model config. + _internal_call: Internal flag for recursion. 
+ + Returns: + Megatron-compatible config dict. + """ + megatron_config = {} + for k, hf_keys in CONFIG_MAPPING.items(): + for hf_k in hf_keys: + if hasattr(config, hf_k): + hf_v = getattr(config, hf_k) + if hf_v is None: + continue + if k == 'rotary_base': + megatron_config[k] = int(hf_v) + elif k in {'untie_embeddings_and_output_weights', 'disable_bias_linear', 'moe_router_pre_softmax'}: + megatron_config[k] = not hf_v + elif k == 'swiglu': + if hf_v == 'silu': + megatron_config[k] = True + else: + if k == 'kv_lora_rank': + megatron_config['multi_latent_attention'] = True + elif k == 'architectures': + if _internal_call: + k = 'llm_architectures' + megatron_config[k] = hf_v + break + + # Handle nested configs + for key in ['text_config', 'llm_config', 'thinker_config']: + if hasattr(config, key): + megatron_config.update(_convert_hf_config(getattr(config, key), _internal_call=True)) + + # Compat llama3 rope scaling + if getattr(config, 'rope_scaling', None) is not None: + if isinstance(config.rope_scaling, int): + megatron_config['rope_scaling'] = {'factor': config.rope_scaling, 'type': 'linear'} + elif isinstance(config.rope_scaling, dict): + megatron_config['rope_scaling'] = config.rope_scaling + + return megatron_config + + +def convert_hf_config(config) -> Dict[str, Any]: + """Convert HuggingFace config to Megatron-compatible config. + + + + Args: + config: HuggingFace model config. + + Returns: + Megatron-compatible config dict. 
+ """ + res = _convert_hf_config(config) + + # Process architectures + architectures = res.get('architectures') + if isinstance(architectures, list) and architectures: + architectures = architectures[0] + res['architectures'] = architectures + + llm_architectures = res.get('llm_architectures') or architectures + if isinstance(llm_architectures, list) and llm_architectures: + llm_architectures = llm_architectures[0] + res['llm_architectures'] = llm_architectures + + # Process MoE settings + first_k_dense_replace = res.pop('first_k_dense_replace', None) + n_shared_experts = res.pop('n_shared_experts', None) + + # ==== Qwen2/Qwen2.5 Model specific settings ==== + if llm_architectures == 'Qwen2ForCausalLM': + # Qwen2/Qwen2.5 uses bias=True for Q, K, V projections (hardcoded in transformers) + # but the config doesn't have 'attention_bias' field + if 'add_qkv_bias' not in res: + res['add_qkv_bias'] = True + res['swiglu'] = True + + # ==== Qwen3 Dense Model specific settings ==== + if llm_architectures == 'Qwen3ForCausalLM': + res['qk_layernorm'] = True + # Qwen3 uses SwiGLU activation + res['swiglu'] = True + # Qwen3 typically doesn't use bias in linear layers + res['disable_bias_linear'] = True + + # ==== Qwen3 MoE Model specific settings ==== + if llm_architectures == 'Qwen3MoeForCausalLM': + res['qk_layernorm'] = True + res['swiglu'] = True + res['disable_bias_linear'] = True + # Qwen3 MoE uses shared expert gate + res['use_shared_expert_gate'] = True + # Remove ffn_hidden_size as MoE uses moe_ffn_hidden_size + res.pop('ffn_hidden_size', None) + + # DeepSeek models + if llm_architectures in {'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM'}: + res['qk_layernorm'] = True + res['moe_router_load_balancing_type'] = 'seq_aux_loss' + res.pop('num_query_groups', None) + + # Handle rope scaling + rope_scaling = res.get('rope_scaling') or {} + if 'partial_rotary_factor' not in res and 'partial_rotary_factor' in rope_scaling: + res['partial_rotary_factor'] = 
rope_scaling['partial_rotary_factor'] + if rope_scaling.get('mrope_section') is not None: + res['position_embedding_type'] = 'mrope' + res['mrope_section'] = rope_scaling['mrope_section'] + + # MoE layer frequency + if first_k_dense_replace is not None: + res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}' + if res.get('moe_router_score_function', 'softmax') == 'sigmoid' and 'moe_router_enable_expert_bias' not in res: + res['moe_router_enable_expert_bias'] = True + if n_shared_experts is not None and 'moe_shared_expert_intermediate_size' not in res: + res['moe_shared_expert_intermediate_size'] = n_shared_experts * res.get('moe_ffn_hidden_size', res.get('ffn_hidden_size', 0)) + + return res + + +@contextmanager +def patch_deepcopy(): + """Context manager to handle tp_group in deepcopy operations. + + + + WHY THIS IS NECESSARY: + ---------------------- + Megatron-Core's TransformerEngine linear layers (TELinear, TEColumnParallelLinear, etc.) + store a reference to their tensor parallel process group in the `tp_group` attribute. + + When PEFT's get_peft_model() is called, it internally uses copy.deepcopy() to create + copies of certain modules. However, torch.distributed.ProcessGroup objects cannot be + pickled or deepcopied because: + + 1. ProcessGroup objects contain native CUDA/NCCL handles that are process-specific + 2. These handles cannot be serialized and recreated in a different memory context + 3. Attempting to deepcopy them raises: "RuntimeError: Cannot pickle ProcessGroup" + + This patch temporarily sets tp_group to None during deepcopy, then restores it + after the copy is complete. This allows PEFT to work with Megatron modules while + preserving the correct process group references. + + USAGE: + ------ + ```python + with patch_deepcopy(): + model = get_peft_model(megatron_model, lora_config) + ``` + + Without this patch, the above code would fail with a pickling error. 
+ """ + import copy + _origin_deepcopy = copy.deepcopy + + def new_deepcopy(x, *args, **kwargs): + if getattr(x, 'tp_group', None) is not None: + origin_tp_group = x.tp_group + x.tp_group = None + res = _origin_deepcopy(x, *args, **kwargs) + x.tp_group = origin_tp_group + res.tp_group = origin_tp_group + return res + else: + return _origin_deepcopy(x, *args, **kwargs) + + copy.deepcopy = new_deepcopy + try: + yield + finally: + copy.deepcopy = _origin_deepcopy + + +# ============================================================================= + +# ============================================================================= +def tuners_sharded_state_dict( + module: nn.Module, + prefix: str = '', + sharded_offsets: Tuple[Tuple[int, int, int]] = (), + metadata: Optional[dict] = None, +) -> Dict[str, Any]: + """Generate sharded state dict for PEFT tuners. + + + + Args: + module: The module to generate state dict for. + prefix: Key prefix. + sharded_offsets: Sharding offsets for distributed checkpointing. + metadata: Additional metadata. + + Returns: + Sharded state dictionary. + """ + sharded_state_dict = {} + # Save parameters + module._save_to_state_dict(sharded_state_dict, '', keep_vars=True) + sharded_state_dict = make_sharded_tensors_for_checkpoint( + sharded_state_dict, prefix, sharded_offsets=sharded_offsets + ) + # Recurse into submodules + for name, child in module.named_children(): + if 'Dict' in child.__class__.__name__: + modules = child.named_children() + else: + modules = [(None, child)] + for n, m in modules: + _prefix = f'{prefix}{name}.' if n is None else f'{prefix}{name}.{n}.' 
+ sharded_state_dict.update(sharded_state_dict_default(m, _prefix, sharded_offsets, metadata)) + return sharded_state_dict + + +def prepare_mcore_model( + model: nn.Module, + train_type: str = 'lora', + lora_config: Optional[Dict[str, Any]] = None, + freeze_parameters: Optional[List[str]] = None, + tenant_id: Optional[str] = None, +) -> nn.Module: + """Prepare Megatron-Core model for training. + + Args: + model: The Megatron model. + train_type: Training type ('full' or 'lora'). + lora_config: LoRA configuration dict. + freeze_parameters: List of parameter names to freeze. + tenant_id: Optional tenant ID for multi-tenant isolation. + + Returns: + Prepared model. + """ + # Set up tenant context if provided + context = contextmanager(lambda: (yield))() + if tenant_id is not None: + manager = get_tenant_manager() + context = manager.tenant_context(tenant_id) + + with context: + if train_type == 'full': + if freeze_parameters: + for name, param in model.named_parameters(): + if any(fp in name for fp in freeze_parameters): + param.requires_grad = False + elif train_type == 'lora': + set_linear_is_expert(model) + if lora_config is not None: + model = prepare_lora_model(model, lora_config) + return model + + +def prepare_lora_model( + model: nn.Module, + lora_config: Dict[str, Any], +) -> nn.Module: + """Add LoRA adapters to Megatron model. + + Args: + model: The Megatron model. + lora_config: LoRA configuration dict with keys: + - r: LoRA rank + - lora_alpha: LoRA alpha + - lora_dropout: Dropout rate + - target_modules: Target module names + - use_rslora: Use rank-stabilized LoRA + + Returns: + Model with LoRA adapters. 
+ """ + set_linear_is_expert(model) + + target_modules = get_target_modules(model, lora_config.get('target_modules', ['all-linear'])) + + peft_config = LoraConfig( + task_type='CAUSAL_LM', + r=lora_config.get('r', 8), + lora_alpha=lora_config.get('lora_alpha', 32), + lora_dropout=lora_config.get('lora_dropout', 0.0), + target_modules=target_modules, + bias=lora_config.get('bias', 'none'), + use_rslora=lora_config.get('use_rslora', False), + ) + + with patch_deepcopy(): + model = get_peft_model(model, peft_config) + + return model + + +# ============================================================================= + +# ============================================================================= +def get_local_layer_specs(config, layer_specs: List, vp_stage: Optional[int] = None): + """Get local layer specifications for current pipeline rank. + + + + Args: + config: Megatron transformer config. + layer_specs: Full list of layer specifications. + vp_stage: Virtual pipeline stage index. + + Returns: + Local layer specifications for this rank. + """ + kwargs = {'vp_stage': vp_stage} if mcore_013 else {} + num_layers_to_build = get_num_layers_to_build(config, **kwargs) + + if getattr(config, 'pipeline_model_parallel_layout', None) is not None: + from megatron.core.transformer.enums import LayerType + local_layer_specs = [ + layer_specs[layer_id] for layer_id in config.pipeline_model_parallel_layout.get_layer_id_list( + layer_type=LayerType.decoder, **kwargs) + ] + else: + offset = get_transformer_layer_offset(config, **kwargs) + local_layer_specs = layer_specs[offset:offset + num_layers_to_build] + return local_layer_specs + + +def get_padding_to( + tensor_model_parallel_size: int = 1, + context_parallel_size: int = 1, + sequence_parallel: bool = False, + fp8_format: Optional[str] = None, + fp8_recipe: Optional[str] = None, + attention_backend: Optional[str] = None, +) -> Optional[int]: + """Get padding size for sequence length. 
+ + Args: + tensor_model_parallel_size: TP size. + context_parallel_size: CP size. + sequence_parallel: Whether sequence parallel is enabled. + fp8_format: FP8 format if used. + fp8_recipe: FP8 recipe if used. + attention_backend: Attention backend type. + + Returns: + Padding size or None. + """ + padding_to = None + if tensor_model_parallel_size > 1 and sequence_parallel: + padding_to = tensor_model_parallel_size + if context_parallel_size > 1: + padding_to = (padding_to or 1) * context_parallel_size + origin_padding_to = padding_to + + if fp8_recipe == 'blockwise': + padding_to = (padding_to or 1) * 128 + elif fp8_format is not None: + padding_to = max((padding_to or 1) * 8, 16) + + if attention_backend == 'fused': + padding_to = max(padding_to or 1, ((origin_padding_to) or 1) * 64) + + return padding_to + + +# ============================================================================= + +# ============================================================================= +def forward_step_helper(model: nn.Module, inputs: Dict[str, Any], config) -> Optional[torch.Tensor]: + """Helper for pipeline parallel forward step. + + Handles communication between pipeline stages. + + Args: + model: The model. + inputs: Input dict with position_ids, etc. + config: Configuration with parallel settings. + + Returns: + Output tensor for last stage, None otherwise. 
+ """ + from megatron.core.inference.communication_utils import recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank + + if mpu.is_pipeline_first_stage(): + # Get micro_batch_size from input tensor, not config + # For padding_free (qkv_format 'thd'), use 1 + micro_batch_size = 1 + if not getattr(config, 'padding_free', False): + # Infer batch size from input_ids or position_ids + if 'input_ids' in inputs: + micro_batch_size = inputs['input_ids'].shape[0] + elif 'position_ids' in inputs: + micro_batch_size = inputs['position_ids'].shape[0] + else: + micro_batch_size = 1 + seq_length = inputs['position_ids'].shape[-1] + if config.sequence_parallel: + seq_length //= mpu.get_tensor_model_parallel_world_size() + recv_shape_buffer = torch.tensor( + [seq_length, micro_batch_size, config.hidden_size], + device=torch.cuda.current_device(), + dtype=torch.int64 + ) + else: + recv_shape_buffer = torch.empty((3,), device=torch.cuda.current_device(), dtype=torch.int64) + recv_from_prev_pipeline_rank_(recv_shape_buffer) + + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(recv_shape_buffer) + shape = recv_shape_buffer.tolist() + + if not mpu.is_pipeline_first_stage(): + dtype = config.params_dtype + recv_buffer = torch.empty(shape, device=torch.cuda.current_device(), dtype=dtype) + recv_from_prev_pipeline_rank_(recv_buffer) + model.set_input_tensor(recv_buffer) + + output_tensor = model(**inputs) + + if not mpu.is_pipeline_last_stage(): + send_to_next_pipeline_rank(output_tensor) + output_tensor = None + + return output_tensor + + +class MegatronTrainerState: + """Lightweight trainer state for Megatron training. + + Provides compatibility with transformers TrainerState interface. + + Attributes: + global_step: The current training step. + max_steps: The total number of training steps. 
+ """ + + def __init__(self, global_step: int = 0, max_steps: int = 0): + self.global_step = global_step + self.max_steps = max_steps + + def update(self, global_step: Optional[int] = None, max_steps: Optional[int] = None): + if global_step is not None: + self.global_step = global_step + if max_steps is not None: + self.max_steps = max_steps + + def __repr__(self) -> str: + return f'MegatronTrainerState(global_step={self.global_step}, max_steps={self.max_steps})' + + +def get_model_parameter_info(model: nn.Module) -> Dict[str, Any]: + """Get parameter count information for a model. + + Args: + model: The model. + + Returns: + Dict with total_params, trainable_params, frozen_params. + """ + total_params = sum(p.numel() for p in model.parameters()) + trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + frozen_params = total_params - trainable_params + + return { + 'total_params': total_params, + 'trainable_params': trainable_params, + 'frozen_params': frozen_params, + 'trainable_ratio': trainable_params / total_params if total_params > 0 else 0, + } diff --git a/src/twinkle/model/__init__.py b/src/twinkle/model/__init__.py index 0bdf65ef..b6bd403d 100644 --- a/src/twinkle/model/__init__.py +++ b/src/twinkle/model/__init__.py @@ -2,3 +2,4 @@ from .transformers import TransformersModel from .base import TwinkleModel from .multi_lora_transformers import MultiLoraTransformersModel +from .megatron import MegatronModel diff --git a/src/twinkle/model/megatron.py b/src/twinkle/model/megatron.py new file mode 100644 index 00000000..be2926cb --- /dev/null +++ b/src/twinkle/model/megatron.py @@ -0,0 +1,1551 @@ +# Copyright (c) twinkle authors. All rights reserved. 
+"""Megatron-Core model wrapper for twinkle training framework.""" +import contextlib +import json +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Type, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler + +import twinkle +from twinkle import DeviceMesh, remote_class, remote_function, template +from twinkle.data_format import InputFeature, Trajectory +from twinkle.hub import HubOperation +from twinkle.loss import Loss, MegatronCrossEntropyLoss +from twinkle.processor import InputProcessor +from twinkle.template import Template +from twinkle.utils.plugin import Plugin + +from .base import TwinkleModel +from .strategy import MegatronStrategy + +try: + import megatron.core + from megatron.core import parallel_state as mpu + from megatron.core.distributed import DistributedDataParallel as MegatronDDP + from packaging import version + MEGATRON_AVAILABLE = True + mcore_013 = version.parse( + megatron.core.__version__) >= version.parse('0.13.0rc0') +except ImportError: + MEGATRON_AVAILABLE = False + mcore_013 = False + + +@dataclass +class MegatronOptimizerGroup: + """Optimizer group for Megatron training. + + Similar to OptimizerGroup but adapted for Megatron's distributed training. 
+ """ + adapter_name: str = None + adapter_config: Any = None + optimizer: Optimizer = None + lr_scheduler: LRScheduler = None + inputs: Dict[str, Any] = None + outputs: Dict[str, Any] = None + loss_instance: Loss = None + loss_value: Any = None + template: Template = None + processor: InputProcessor = None + gradient_accumulation_steps: int = 1 + cur_step: int = 0 + dp_group = None + # Megatron optimizer specific fields + is_megatron_optimizer: bool = False + _last_grad_norm: float = 0.0 + _last_step_success: bool = True + + def do_grad_sync(self, + gradient_accumulation_steps: Optional[int] = None + ) -> bool: + """Check if gradient synchronization should happen.""" + if gradient_accumulation_steps is None: + gradient_accumulation_steps = self.gradient_accumulation_steps + return self.cur_step % gradient_accumulation_steps == 0 and self.cur_step > 0 + + +_default_adapter_name = '' + + +def check_megatron_available(): + """Check if Megatron-Core is available.""" + if not MEGATRON_AVAILABLE: + raise ImportError( + 'Megatron-Core is not installed. Please install it with: ' + 'pip install megatron-core') + + +@remote_class(execute='all') +class MegatronModel(TwinkleModel, nn.Module): + """Megatron-Core model wrapper for twinkle training framework. + + Note: Uses execute='all' to create workers on all ranks, which is required + for Megatron's TP/DP parallelism where all ranks must participate in + collective operations like gradient all-reduce. + + This class provides a similar API to TransformersModel but uses Megatron-Core + as the training backend, supporting TP/PP/CP/EP parallelism. + + Args: + pretrained_model_name_or_path: HuggingFace model path or ID. + device_mesh: Twinkle DeviceMesh for distributed training. + tensor_model_parallel_size: Tensor parallel size. + pipeline_model_parallel_size: Pipeline parallel size. + context_parallel_size: Context parallel size. + expert_model_parallel_size: Expert parallel size. 
+ sequence_parallel: Enable sequence parallelism. + mixed_precision: Mixed precision mode. + use_distributed_optimizer: Use Megatron's distributed optimizer. + **kwargs: Additional arguments passed to model initialization. + """ + def __init__( + self, + pretrained_model_name_or_path: str, + device_mesh: Optional[DeviceMesh] = None, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + context_parallel_size: int = 1, + expert_model_parallel_size: int = 1, + sequence_parallel: bool = False, + mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16', + use_distributed_optimizer: bool = True, + load_weights: bool = True, + use_megatron_bridge: + bool = True, # Use bridge-based initialization (recommended) + recompute_granularity: Optional[ + str] = 'selective', # Activation checkpointing + recompute_modules: Optional[list] = None, # Modules to recompute + **kwargs, + ): + check_megatron_available() + nn.Module.__init__(self) + + self.model_id = pretrained_model_name_or_path + self.device_mesh = device_mesh + self.mixed_precision = mixed_precision + self.use_megatron_bridge = use_megatron_bridge + self.recompute_granularity = recompute_granularity + self.recompute_modules = recompute_modules + + # Load HuggingFace config first + model_path = HubOperation.download_model(pretrained_model_name_or_path) + self._load_hf_config(model_path) + + # Store model_path for later use + self._model_path = model_path + + # Create Megatron strategy + self.strategy = MegatronStrategy( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + context_parallel_size=context_parallel_size, + expert_model_parallel_size=expert_model_parallel_size, + sequence_parallel=sequence_parallel, + use_distributed_optimizer=use_distributed_optimizer, + mixed_precision=mixed_precision, + ) + + # Initialize parallel state (skip if using bridge init, as it handles this) + if not use_megatron_bridge: + 
self.strategy.initialize() + + # Create Megatron model + self.model = self._create_megatron_model(model_path, load_weights, + **kwargs) + + self._model_wrapped = False + # This correctly handles vocab sharding in Tensor Parallelism + self.optimizer_group: Dict[str, MegatronOptimizerGroup] = { + _default_adapter_name: + MegatronOptimizerGroup(loss_instance=MegatronCrossEntropyLoss()) + } + + def _load_hf_config(self, model_path: str): + """Load HuggingFace model config.""" + from transformers import AutoConfig + self.hf_config = AutoConfig.from_pretrained(model_path) + + def _create_megatron_model( + self, + model_path: str, + load_weights: bool = True, + **kwargs, + ) -> nn.Module: + """Create Megatron model from HuggingFace checkpoint. + + Args: + model_path: Path to HuggingFace model. + load_weights: Whether to load weights. + **kwargs: Additional arguments. + + Returns: + Megatron model on GPU. + """ + params_dtype = torch.bfloat16 + if self.mixed_precision == 'fp16': + params_dtype = torch.float16 + elif self.mixed_precision == 'no': + params_dtype = torch.float32 + + if self.use_megatron_bridge: + # Use bridge-based initialization (recommended) + # This ensures all patches are applied and config is correctly generated + return self._create_megatron_model_with_bridge( + model_path, load_weights, params_dtype, **kwargs) + else: + # Use twinkle's native initialization + return self._create_megatron_model_native(model_path, load_weights, + params_dtype, **kwargs) + + def _create_megatron_model_with_bridge( + self, + model_path: str, + load_weights: bool, + params_dtype: torch.dtype, + **kwargs, + ) -> nn.Module: + """Create Megatron model using bridge-based initialization flow. 
+ + This approach uses TwinkleBridgeInitializer for independent initialization + It includes: + - Proper config conversion from HuggingFace to Megatron format + - Correct Megatron initialization (initialize_megatron) + - Correct model creation + - Weight loading with TwinkleGPTBridge + + Args: + model_path: Path to HuggingFace model. + load_weights: Whether to load weights. + params_dtype: Parameter dtype. + **kwargs: Additional arguments. + + Returns: + Megatron model on GPU. + """ + from twinkle.megatron.model.bridge import TwinkleBridgeInitializer + + # Create bridge-based initializer + self._bridge_initializer = TwinkleBridgeInitializer( + tp_size=self.strategy.tp_size, + pp_size=self.strategy.pp_size, + cp_size=self.strategy.cp_size, + ep_size=self.strategy.ep_size, + params_dtype=params_dtype, + use_cpu_initialization=False, + attention_backend='flash', # Use flash for training performance + sequence_parallel=self.strategy.sequence_parallel, + recompute_granularity=self.recompute_granularity, + recompute_modules=self.recompute_modules, + recompute_method=getattr(self, 'recompute_method', None), + recompute_num_layers=getattr(self, 'recompute_num_layers', None), + ) + + # Create model (this calls initialize_megatron internally) + model = self._bridge_initializer.create_model( + model_path, load_weights=load_weights) + + # Update strategy state since bridge has initialized Megatron + self.strategy._initialized = True + self.strategy._parallel_state = mpu + + # Save transformer config for DDP wrapping + self._transformer_config = getattr(self._bridge_initializer, + '_transformer_config', None) + + # Move to GPU + model = self._move_model_to_gpu(model) + + return model + + def _create_megatron_model_native( + self, + model_path: str, + load_weights: bool, + params_dtype: torch.dtype, + **kwargs, + ) -> nn.Module: + """Create Megatron model using twinkle's native initialization. + + This is the fallback method when bridge is not available. 
        Args:
            model_path: Path to HuggingFace model.
            load_weights: Whether to load weights.
            params_dtype: Parameter dtype.
            **kwargs: Additional arguments.

        Returns:
            Megatron model on GPU.
        """
        from twinkle.megatron.model.initializer import MegatronModelInitializer

        initializer = MegatronModelInitializer(
            tp_size=self.strategy.tp_size,
            pp_size=self.strategy.pp_size,
            cp_size=self.strategy.cp_size,
            ep_size=self.strategy.ep_size,
            sequence_parallel=self.strategy.sequence_parallel,
            params_dtype=params_dtype,
        )

        # Create model
        model = initializer.create_gpt_model(self.hf_config, **kwargs)

        # Load weights
        if load_weights:
            initializer.load_from_hf(model, model_path, self.hf_config)

        model = self._move_model_to_gpu(model)

        return model

    def _move_model_to_gpu(self, model: nn.Module) -> nn.Module:
        """Move model to correct GPU device.

        This method handles moving parameters, buffers, and any cached tensors
        (like RoPE embeddings) to the correct device for distributed training.
        """
        # Determine the target device based on local rank
        # NOTE(review): derives local rank as global_rank % device_count —
        # assumes ranks are packed densely with one process per GPU per node;
        # confirm for heterogeneous / multi-node launches.
        local_rank = dist.get_rank() % torch.cuda.device_count(
        ) if dist.is_initialized() else 0
        device = torch.device(f'cuda:{local_rank}')

        # Set CUDA device explicitly
        torch.cuda.set_device(local_rank)

        # Move all parameters and buffers to GPU
        model = model.to(device)

        # Force synchronize to ensure all transfers complete
        if torch.cuda.is_available():
            torch.cuda.synchronize(device)

        return model

    def _lazy_wrap_model(self):
        """Lazily wrap model with distributed wrapper.

        Note: This should only be called after prepare_training() has been
        executed on all workers. Direct calls from forward() may cause
        deadlocks if not all DP ranks are participating.
+ """ + if not self._model_wrapped: + # Find an optimizer from any adapter group (prefer default, then first available) + optimizer = None + optimizer_adapter = None + + if _default_adapter_name in self.optimizer_group: + optimizer = self.optimizer_group[ + _default_adapter_name].optimizer + optimizer_adapter = _default_adapter_name + else: + for name, group in self.optimizer_group.items(): + if group.optimizer is not None: + optimizer = group.optimizer + optimizer_adapter = name + break + + if optimizer is not None: + self.model, optimizer = self.strategy.wrap_model( + self.model, optimizer) + self.optimizer_group[optimizer_adapter].optimizer = optimizer + self._model_wrapped = True + + @remote_function(dispatch='all') + def prepare_training(self, **kwargs): + """Prepare model for training. + + Note: In Ray-based Megatron training, we skip DDP wrapping to avoid + deadlocks from collective operations. Each DP replica trains independently. + This method still calls _lazy_wrap_model for any non-DDP setup needed. + """ + self._lazy_wrap_model() + + @remote_function() + def forward(self, *, inputs: Union[InputFeature, List[InputFeature], + Trajectory, List[Trajectory]], + **kwargs): + """Forward pass with Megatron model. + + Args: + inputs: Model inputs. + **kwargs: Additional arguments including adapter_name. + + Returns: + Model outputs. 
+ """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + self._lazy_wrap_model() + + # Encode inputs if needed + if isinstance(inputs, dict) and 'input_ids' not in inputs: + if optimizer_config.template is not None: + inputs = optimizer_config.template.encode(inputs) + if isinstance(inputs, list) and 'input_ids' not in inputs[0]: + if optimizer_config.template is not None: + inputs = optimizer_config.template.batch_encode(inputs) + + # Process inputs + processor: InputProcessor = optimizer_config.processor + if processor is not None: + inputs: Dict[str, Any] = processor(inputs) + + labels = inputs.get('labels', None) + if 'labels' in inputs: + try: + del inputs['labels'] + except (TypeError, KeyError): + pass # Some dict-like types don't support deletion + + # Forward through model + outputs = self._forward_step(inputs) + + inputs['labels'] = labels + optimizer_config.inputs = inputs + optimizer_config.outputs = outputs + return outputs + + def _forward_step(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Execute forward step with pipeline parallelism support. + + Args: + inputs: Processed inputs. + + Returns: + Model outputs. 
+ """ + # Handle pipeline parallelism + if self.strategy.pp_size > 1: + return self._forward_step_pipeline(inputs) + else: + return self._forward_step_simple(inputs) + + def _forward_step_simple(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Simple forward step without pipeline parallelism.""" + model = self.strategy.unwrap_model(self.model) + + # Prepare inputs for Megatron + input_ids = inputs.get('input_ids') + attention_mask = inputs.get('attention_mask') + position_ids = inputs.get('position_ids') + + # Create position_ids if not provided + if position_ids is None and input_ids is not None: + position_ids = torch.arange( + input_ids.shape[1], + device=input_ids.device, + dtype=torch.long, + ).unsqueeze(0).expand(input_ids.shape[0], -1) + + # Forward pass + outputs = model( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + ) + + return {'logits': outputs} + + def _forward_step_pipeline(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Forward step with pipeline parallelism. + + Note: For PP > 1, the forward pass is handled by Megatron's pipeline scheduler + in forward_backward(). This method is for simple forward-only inference. + For training, use forward_backward() which uses get_forward_backward_func(). + """ + from twinkle.megatron.utils import forward_step_helper + + model = self.strategy.unwrap_model(self.model) + + # Use pipeline forward helper + output = forward_step_helper( + model, + inputs, + model.config, + ) + + if output is not None: + return {'logits': output} + return {} + + @remote_function() + def forward_only(self, *, inputs: Union[InputFeature, List[InputFeature], + List[Trajectory]], **kwargs): + """Forward pass without gradient computation. + + Args: + inputs: Model inputs. + **kwargs: Additional arguments. + + Returns: + Model outputs. 
+ """ + with torch.no_grad(): + return self.forward(inputs=inputs, **kwargs) + + @remote_function(collect='mean') + def calculate_loss(self, **kwargs): + """Calculate loss from forward outputs. + + Args: + **kwargs: Additional arguments including adapter_name. + + Returns: + Loss value as numpy array. + """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + loss_instance: Loss = optimizer_config.loss_instance + + inputs = optimizer_config.inputs + outputs = optimizer_config.outputs + + assert inputs is not None and outputs is not None, \ + 'Cannot calculate loss of empty inputs and outputs' + + loss_value = loss_instance(inputs, outputs, **kwargs) + optimizer_config.loss_value = loss_value + return loss_value.detach().cpu().float().numpy() + + @remote_function() + def backward(self, **kwargs): + """Backward pass. + + Args: + **kwargs: Additional arguments. + """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + loss_value = optimizer_config.loss_value + + assert loss_value is not None, 'Do forwarding and calculating loss before backward' + + _gas = optimizer_config.gradient_accumulation_steps + if 'gradient_accumulation_steps' in kwargs: + _gas = kwargs['gradient_accumulation_steps'] + + loss_value = loss_value / _gas + loss_value.backward() + optimizer_config.cur_step += 1 + + @remote_function(dispatch='all', collect='mean', sync=True) + def forward_backward(self, + *, + inputs: Union[InputFeature, List[InputFeature], + Trajectory, List[Trajectory]], + num_microbatches: int = 1, + **kwargs): + """Combined forward and backward pass using Megatron's scheduler. + + Note: sync=True is required for Ray mode because Megatron's pipeline + parallel uses NCCL P2P communication that requires all ranks to enter + the function simultaneously. 
        Always uses Megatron's get_forward_backward_func() which handles:
        - Pipeline scheduling (1F1B, interleaved, or no-pipeline)
        - Communication between stages (using proper process groups for multi-tenant isolation)
        - Gradient accumulation across microbatches

        Args:
            inputs: Model inputs. Can be:
                - A single batch dict (num_microbatches=1)
                - A list of batch dicts (num_microbatches=len(inputs))
                - An iterator yielding batch dicts
            num_microbatches: Number of microbatches to process in one call.
                If inputs is a list, this is inferred from len(inputs).
                Using num_microbatches > 1 enables Megatron's native gradient
                accumulation with better memory management and compute overlap.
            **kwargs: Additional arguments.

        Returns:
            Average loss value across all microbatches.
        """
        from functools import partial
        from megatron.core.pipeline_parallel import get_forward_backward_func

        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
        optimizer_config = self.optimizer_group[adapter_name]
        self._lazy_wrap_model()

        # Encode inputs if needed
        if isinstance(inputs, dict) and 'input_ids' not in inputs:
            if optimizer_config.template is not None:
                inputs = optimizer_config.template.encode(inputs)
        if isinstance(inputs, list) and 'input_ids' not in inputs[0]:
            if optimizer_config.template is not None:
                inputs = optimizer_config.template.batch_encode(inputs)

        # Process inputs (collate list to batched dict)
        processor = optimizer_config.processor
        if processor is not None:
            inputs = processor(inputs)

        # Store labels before removing from inputs
        labels = inputs.get('labels', None)
        if 'labels' in inputs:
            try:
                del inputs['labels']
            except (TypeError, KeyError):
                pass  # Some dict-like types don't support deletion

        # Move labels to GPU if needed
        if labels is not None and not isinstance(labels, torch.Tensor):
            labels = torch.tensor(labels, device=torch.cuda.current_device())
        elif labels is not None:
            labels = labels.to(torch.cuda.current_device())

        # Get parallelism settings for sequence padding and splitting
        cp_size = self.strategy.cp_size
        tp_size = self.strategy.tp_size
        # Check actual sequence_parallel setting from model config
        # Bridge may auto-enable sequence_parallel for MoE models
        model = self.strategy.unwrap_model(self.model)
        if hasattr(model, 'config') and hasattr(model.config, 'sequence_parallel'):
            sequence_parallel = model.config.sequence_parallel
        else:
            sequence_parallel = self.strategy.sequence_parallel
        cp_rank = mpu.get_context_parallel_rank() if cp_size > 1 else 0

        # Get sequence length and batch size
        original_seq_length = inputs['input_ids'].shape[1] if 'input_ids' in inputs else 1
        micro_batch_size = inputs['input_ids'].shape[0] if 'input_ids' in inputs else 1

        # Calculate padded seq_length based on parallelism requirements
        # 1. For CP > 1: seq_len must be divisible by 2 * cp_size
        # 2. For sequence_parallel with TP > 1: seq_len must be divisible by tp_size
        if cp_size > 1:
            divisor = 2 * cp_size
        elif sequence_parallel and tp_size > 1:
            divisor = tp_size
        else:
            divisor = 1

        if divisor > 1 and original_seq_length % divisor != 0:
            seq_length = original_seq_length + (divisor - original_seq_length % divisor)
        else:
            seq_length = original_seq_length

        def split_tensor_for_cp(tensor, dim=-1):
            """
            Split tensor along sequence dimension for Context Parallel.

            With causal masking, split into 2*CP chunks and assign alternating
            chunks to balance workload across CP ranks.
            For CP rank i: chunks [i, 2*CP-1-i]
            """
            if tensor is None or cp_size <= 1:
                return tensor

            if dim < 0:
                dim = (dim + tensor.ndim) % tensor.ndim

            seq_len = tensor.shape[dim]

            # Reshape to [batch, 2*cp_size, seq_per_chunk, ...]
            view_shape = list(tensor.shape)
            view_shape[dim:dim + 1] = [2 * cp_size, seq_len // (2 * cp_size)]
            reshaped = tensor.view(*view_shape)

            # Select chunks [cp_rank, 2*cp_size-1-cp_rank]
            index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)],
                                 device='cpu',
                                 pin_memory=True).cuda(non_blocking=True)
            selected = reshaped.index_select(dim, index)

            # Reshape back: [batch, 2*seq_per_chunk, ...]
            out_shape = list(tensor.shape)
            out_shape[dim] = seq_len // cp_size
            return selected.reshape(*out_shape)

        # Define forward step function for Megatron
        # forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func))
        def forward_step_func(data_iterator, model):
            batch = next(data_iterator)
            input_ids = batch.get('input_ids')
            position_ids = batch.get('position_ids')
            attention_mask = batch.get('attention_mask')
            batch_labels = batch.get('labels', labels)  # Use batch labels or passed labels

            # Pad sequence for parallel compatibility
            # 1. For CP > 1: Megatron's RoPE requires seq_len % (2 * cp_size) == 0
            # 2. For sequence_parallel: seq_len must be divisible by TP size
            if input_ids is not None:
                seq_len = input_ids.shape[1]

                # Calculate required divisor based on parallelism settings
                if cp_size > 1:
                    divisor = 2 * cp_size
                elif sequence_parallel and tp_size > 1:
                    divisor = tp_size
                else:
                    divisor = 1

                if divisor > 1 and seq_len % divisor != 0:
                    pad_len = divisor - (seq_len % divisor)
                    # Pad input_ids
                    input_ids = torch.nn.functional.pad(input_ids,
                                                        (0, pad_len),
                                                        value=0)
                    # Pad labels if present (-100 keeps padded tokens out of the loss)
                    if batch_labels is not None:
                        batch_labels = torch.nn.functional.pad(batch_labels,
                                                               (0, pad_len),
                                                               value=-100)
                    # Pad attention_mask if present
                    if attention_mask is not None:
                        attention_mask = torch.nn.functional.pad(
                            attention_mask, (0, pad_len), value=0)
                    # Pad position_ids if present
                    if position_ids is not None:
                        position_ids = torch.nn.functional.pad(position_ids,
                                                               (0, pad_len),
                                                               value=0)

            # Create position_ids if not provided
            if position_ids is None and input_ids is not None:
                position_ids = torch.arange(
                    input_ids.shape[1],
                    device=input_ids.device,
                    dtype=torch.long,
                ).unsqueeze(0).expand(input_ids.shape[0], -1)

            # Split tensors for Context Parallel
            # Each CP rank processes a portion of the sequence
            if cp_size > 1:
                input_ids = split_tensor_for_cp(input_ids, dim=-1)
                position_ids = split_tensor_for_cp(position_ids, dim=-1)
                attention_mask = split_tensor_for_cp(attention_mask, dim=-1)
                batch_labels = split_tensor_for_cp(batch_labels, dim=-1)

            # Forward pass with labels - Megatron will compute loss internally
            # This uses Megatron's compute_language_model_loss which properly handles
            # vocab parallel cross entropy
            output_tensor = model(
                input_ids=input_ids,
                position_ids=position_ids,
                attention_mask=attention_mask,
                labels=batch_labels,  # Pass labels to let Megatron compute loss
            )

            # Megatron's compute_language_model_loss returns per-token loss [batch, seq]
            # We need to aggregate it with loss_mask
            def megatron_loss_func(labels_for_mask, cp_size, output_tensor):
                # output_tensor is per-token loss [batch, seq]
                # Create loss mask from labels (ignore -100)
                loss_mask = (labels_for_mask != -100).float()

                # Flatten and compute mean
                losses = output_tensor.float().view(-1)
                loss_mask_flat = loss_mask.view(-1)

                # Compute local sum and count
                local_loss_sum = torch.sum(losses * loss_mask_flat)
                local_count = loss_mask_flat.sum()

                # For CP > 1, aggregate loss across CP ranks
                if cp_size > 1:
                    # All-reduce the count across CP ranks
                    total_count = local_count.clone()
                    torch.distributed.all_reduce(
                        total_count,
                        op=torch.distributed.ReduceOp.SUM,
                        group=mpu.get_context_parallel_group()
                    )

                    # All-reduce the loss sum
                    total_loss_sum = local_loss_sum.clone()
                    torch.distributed.all_reduce(
                        total_loss_sum,
                        op=torch.distributed.ReduceOp.SUM,
                        group=mpu.get_context_parallel_group()
                    )

                    # Return global mean, divided by cp_size to counteract Megatron's multiplication
                    loss = (total_loss_sum / total_count.clamp(min=1)) / cp_size
                else:
                    loss = local_loss_sum / local_count.clamp(min=1)

                return loss, {'loss': loss.detach()}

            return output_tensor, partial(megatron_loss_func, batch_labels,
                                          cp_size)

        # Get Megatron's forward-backward function
        # This automatically selects the right scheduler based on PP config:
        # - PP > 1: forward_backward_pipelining_without_interleaving (or with interleaving if VPP)
        # - PP = 1: forward_backward_no_pipelining
        forward_backward_func = get_forward_backward_func()

        # Create single-item iterator
        data_iter = iter([inputs])

        # Run forward-backward with Megatron's scheduler
        # Megatron handles all communication internally using proper process groups
        # NOTE(review): the num_microbatches *parameter* is accepted by this
        # method but the call below hard-codes num_microbatches=1 and wraps
        # inputs as a single-item iterator — list inputs are not split into
        # microbatches yet; confirm whether this is intentional.
        losses = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=data_iter,
            model=[self.model],
            num_microbatches=1,
            seq_length=seq_length,
            micro_batch_size=micro_batch_size,
            forward_only=False,
        )
        # Extract loss from results (only last PP stage returns non-empty)
        loss = 0.0

        if losses:
            for loss_dict in losses:
                if isinstance(loss_dict, dict) and 'loss' in loss_dict:
                    loss = loss_dict['loss']
                    break
                elif isinstance(loss_dict, torch.Tensor):
                    loss = loss_dict
                    break

        # For PP > 1, broadcast loss from last PP stage to all ranks
        # Note: mpu is imported at module level, no need to reimport
        if mpu.get_pipeline_model_parallel_world_size() > 1:
            if isinstance(loss, torch.Tensor):
                loss_tensor = loss.detach().clone()
            else:
                loss_tensor = torch.tensor(loss,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())

            # Broadcast from last PP stage (rank with pipeline_model_parallel_rank == pp_size - 1)
            src_rank = mpu.get_pipeline_model_parallel_last_rank()
            pp_group = mpu.get_pipeline_model_parallel_group()

            torch.distributed.broadcast(loss_tensor,
                                        src=src_rank,
                                        group=pp_group)

            loss = loss_tensor.item()

        optimizer_config.cur_step += 1

        # Critical: Synchronize all DP replicas before returning
        # This ensures all DP replicas complete the same training step before
        # moving to the next batch, preventing P2P communication deadlocks
        dp_world_size = mpu.get_data_parallel_world_size()
        if dp_world_size > 1:
            # Use barrier on DP+CP group to synchronize all replicas
            dp_cp_group = mpu.get_data_parallel_group(with_context_parallel=True)
            dist.barrier(group=dp_cp_group)

        # Return a plain float / numpy scalar so the Ray collector can average it.
        if isinstance(loss, torch.Tensor):
            return loss.detach().cpu().float().numpy()
        return float(loss)

    @remote_function(dispatch='all')
    def clip_grad_norm(self,
                       max_grad_norm: float = 1.0,
                       norm_type: int = 2,
                       **kwargs):
        """Clip gradient norm.

        Args:
            max_grad_norm: Maximum gradient norm.
            norm_type: Type of norm to use.
            **kwargs: Additional arguments.

        Returns:
            Total norm of gradients.
+ """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + # Check if using Megatron optimizer (handles clip_grad internally) + is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', + False) + if is_megatron_opt: + # Megatron optimizer handles gradient clipping in step() + # Return the grad_norm from last step if available + return getattr(optimizer_config, '_last_grad_norm', 0.0) + + parameters = self._get_trainable_parameters(adapter_name).values() + + return torch.nn.utils.clip_grad_norm_( + parameters, max_grad_norm, + norm_type=norm_type).detach().cpu().numpy() + + @remote_function(dispatch='all') + def step(self, **kwargs): + """Optimizer step. + + For DDP-wrapped models: + - Gradients are synchronized automatically during backward via DDP + + For non-DDP models (e.g., PEFT/LoRA): + - Gradients are NOT synchronized across DP ranks + - Each DP replica trains independently with different data + - This is a common pattern for PEFT training where the overhead of + gradient averaging is not worth the benefit + + Note: Uses dispatch='all' to ensure all workers execute this method. + + Args: + **kwargs: Additional arguments. 
+ """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + if not optimizer_config.do_grad_sync( + kwargs.get('gradient_accumulation_steps')): + return + + # For DDP-wrapped models, gradients are already synchronized during backward + if self._is_model_ddp_wrapped(): + # For Megatron DDP, ensure gradient buffers are finalized + if hasattr(self.model, 'finish_grad_sync'): + self.model.finish_grad_sync() + # For non-DDP models (e.g., PEFT), we skip gradient synchronization + # Each DP replica trains independently, which is acceptable for PEFT + + optimizer = optimizer_config.optimizer + assert optimizer is not None, 'Set optimizer correctly before stepping' + + # Check if using Megatron optimizer (has different step() signature) + is_megatron_opt = getattr(optimizer_config, 'is_megatron_optimizer', + False) + if is_megatron_opt: + # Megatron optimizer step() returns (success, grad_norm, num_zeros) + success, grad_norm, num_zeros = optimizer.step() + # Store grad_norm for later retrieval + optimizer_config._last_grad_norm = grad_norm if grad_norm is not None else 0.0 + optimizer_config._last_step_success = success + else: + optimizer.step(**kwargs) + + def _is_model_ddp_wrapped(self) -> bool: + """Check if model is wrapped with DDP. + + Returns: + True if model is wrapped with DDP (either Megatron DDP, LoRA DDP, or PyTorch DDP). + """ + from torch.nn.parallel import DistributedDataParallel as TorchDDP + return isinstance(self.model, (MegatronDDP, TorchDDP)) + + def _get_unwrapped_model(self) -> nn.Module: + """Get the unwrapped model. + + Returns: + The base model without DDP wrapper. + """ + return self.strategy.unwrap_model(self.model) + + @remote_function(dispatch='all') + def zero_grad(self, **kwargs): + """Zero gradients. + + For DDP-wrapped models, also zeros the DDP gradient buffers. 
+ + Note: For DDP-wrapped models, zero_grad_buffer() is always called + because it's essential for the next training iteration. The + do_grad_sync check only affects the optimizer.zero_grad() call. + + Args: + **kwargs: Additional arguments. + """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + # For DDP-wrapped models, ALWAYS zero the gradient buffer + # This is essential because Megatron's forward_backward_func uses + # the buffer's state to track gradient accumulation + if self._is_model_ddp_wrapped() and hasattr(self.model, + 'zero_grad_buffer'): + self.model.zero_grad_buffer() + + if not optimizer_config.do_grad_sync( + kwargs.get('gradient_accumulation_steps')): + return + + optimizer = optimizer_config.optimizer + if optimizer is not None: + # Clear set_to_none for better compatibility + optimizer.zero_grad(set_to_none=True) + + @remote_function() + def lr_step(self, **kwargs): + """Learning rate scheduler step. + + Args: + **kwargs: Additional arguments. + """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + if not optimizer_config.do_grad_sync( + kwargs.get('gradient_accumulation_steps')): + return + + lr_scheduler = optimizer_config.lr_scheduler + if lr_scheduler is not None: + lr_scheduler.step(**kwargs) + + @remote_function(dispatch='all') + def set_loss(self, loss_cls: Union[Type[Loss], str], **kwargs): + """Set loss function. + + NOTE: For MegatronModel, the loss is computed internally by Megatron's + GPTModel when labels are passed. This method is kept for API compatibility + but the provided loss_cls is NOT used during forward_backward. + + Megatron internally uses vocab_parallel_cross_entropy which correctly + handles tensor parallelism. This design ensures Loss classes don't need + to be aware of the training backend (Megatron vs Transformers). 
        Args:
            loss_cls: Loss class or string name (not used for Megatron).
            **kwargs: Additional arguments.
        """
        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
        optimizer_config = self.optimizer_group[adapter_name]

        # Resolve a string name against twinkle.loss first, then the plugin registry.
        if isinstance(loss_cls, str):
            if hasattr(twinkle.loss, loss_cls):
                loss_cls = getattr(twinkle.loss, loss_cls)
            else:
                loss_cls = Plugin.load_plugin(loss_cls, Loss)
        # Keep for API compatibility, but not used in forward_backward
        optimizer_config.loss_instance = loss_cls()

    @remote_function(dispatch='all')
    def set_optimizer(self, optimizer_cls: Union[Type[Optimizer], str],
                      **kwargs):
        """Set optimizer.

        Args:
            optimizer_cls: Optimizer class or string name.
                - Standard PyTorch optimizers: 'AdamW', 'Adam', 'SGD', etc.
                - 'MegatronDistributed': Use Megatron's distributed optimizer
            **kwargs: Additional arguments.
                - For standard optimizers: lr, weight_decay, etc.
                - For MegatronDistributed: use_distributed_optimizer, clip_grad, etc.
        """
        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
        optimizer_config = self.optimizer_group[adapter_name]

        # Check if requesting Megatron distributed optimizer
        if optimizer_cls == 'MegatronDistributed' or kwargs.pop(
                'use_megatron_optimizer', False):
            optimizer_config.optimizer = self._create_megatron_optimizer(
                **kwargs)
            optimizer_config.is_megatron_optimizer = True
            return

        # Resolve a string name against torch.optim first, then the plugin registry.
        if isinstance(optimizer_cls, str):
            if hasattr(torch.optim, optimizer_cls):
                optimizer_cls = getattr(torch.optim, optimizer_cls)
            else:
                optimizer_cls = Plugin.load_plugin(optimizer_cls, Optimizer)

        optimizer_config.optimizer = optimizer_cls(
            self._get_trainable_parameters(adapter_name).values(), **kwargs)
        optimizer_config.is_megatron_optimizer = False

    def _create_megatron_optimizer(self, **kwargs):
        """Create Megatron distributed optimizer.

        This provides significant memory savings for large models by sharding
        optimizer states across DP replicas.
        Args:
            **kwargs: Optimizer configuration options.
                - lr: Learning rate (default: 1e-4)
                - weight_decay: Weight decay (default: 0.0)
                - use_distributed_optimizer: Shard optimizer states (default: True)
                - clip_grad: Gradient clipping threshold (default: 1.0)
                - bf16: Use bf16 training (default: True)
                - adam_beta1, adam_beta2, adam_eps: Adam parameters

        Returns:
            MegatronOptimizer instance.
        """
        from megatron.core.optimizer import get_megatron_optimizer, OptimizerConfig

        # Build optimizer config
        lr = kwargs.get('lr', 1e-4)
        use_distributed_optimizer = kwargs.get('use_distributed_optimizer',
                                               True)

        opt_config = OptimizerConfig(
            optimizer='adam',
            lr=lr,
            min_lr=kwargs.get('min_lr', 0.0),
            weight_decay=kwargs.get('weight_decay', 0.0),
            adam_beta1=kwargs.get('adam_beta1', 0.9),
            adam_beta2=kwargs.get('adam_beta2', 0.999),
            adam_eps=kwargs.get('adam_eps', 1e-8),
            clip_grad=kwargs.get('clip_grad', 1.0),
            bf16=kwargs.get('bf16', True),
            use_distributed_optimizer=use_distributed_optimizer,
            overlap_param_gather=kwargs.get('overlap_param_gather', False),
            log_num_zeros_in_grad=kwargs.get('log_num_zeros_in_grad', False),
        )

        # For PEFT models, we need to handle the case where model is not DDP-wrapped
        # We create a temporary wrapper to satisfy Megatron's optimizer requirements
        model_chunks = [self.model]

        # Check if model has ddp_config (required for distributed optimizer)
        if not hasattr(self.model, 'ddp_config') and use_distributed_optimizer:
            # For PEFT models without DDP, fall back to non-distributed optimizer
            # but still use Megatron's optimized implementation
            opt_config.use_distributed_optimizer = False
            if mpu.get_data_parallel_rank() == 0:
                # NOTE(review): uses print instead of the project logger —
                # consider logger.info if one is available in this module.
                print(
                    'Note: Falling back to non-distributed optimizer for PEFT model. '
                    'For distributed optimizer, wrap model with MegatronDDP.')

        try:
            optimizer = get_megatron_optimizer(
                config=opt_config,
                model_chunks=model_chunks,
            )
            return optimizer
        except Exception as e:
            # Fallback to simple FP32 optimizer if Megatron optimizer fails
            # NOTE(review): broad except is a deliberate best-effort fallback,
            # but it can mask configuration errors; the warning below is the
            # only trace.
            if mpu.get_data_parallel_rank() == 0:
                print(
                    f'Warning: Failed to create Megatron optimizer ({e}), falling back to PyTorch AdamW'
                )

            params = [p for p in self.model.parameters() if p.requires_grad]
            return torch.optim.AdamW(params,
                                     lr=lr,
                                     weight_decay=kwargs.get(
                                         'weight_decay', 0.0))

    def _get_trainable_parameters(
        self,
        adapter_name: str = _default_adapter_name
    ) -> Dict[str, nn.Parameter]:
        """Get trainable parameters.

        For the default adapter this returns ALL requires_grad parameters;
        for a named adapter only parameters whose name matches the
        ``.lora_<kind>.<adapter_name>.`` pattern.

        Args:
            adapter_name: Name of adapter.

        Returns:
            Dict mapping parameter names to parameters.
        """
        is_default = adapter_name == _default_adapter_name
        pattern = re.compile(rf'\.lora_\w+\.{re.escape(adapter_name)}\.')

        params = {}
        model = self.strategy.unwrap_model(self.model)
        for name, param in model.named_parameters():
            if param.requires_grad and (pattern.search(name) or is_default):
                params[name] = param
        return params

    @remote_function(dispatch='all')
    def set_lr_scheduler(self, scheduler_cls: Union[Type[LRScheduler], str],
                         **kwargs):
        """Set learning rate scheduler.

        Args:
            scheduler_cls: Scheduler class or string name.
            **kwargs: Additional arguments.
+ """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + if isinstance(scheduler_cls, str): + if hasattr(torch.optim.lr_scheduler, scheduler_cls): + scheduler_cls = getattr(torch.optim.lr_scheduler, + scheduler_cls) + else: + scheduler_cls = Plugin.load_plugin(scheduler_cls, LRScheduler) + + optimizer = optimizer_config.optimizer + assert optimizer is not None, 'Set optimizer before setting lr_scheduler' + optimizer_config.lr_scheduler = scheduler_cls(optimizer, **kwargs) + + @remote_function(dispatch='all', sync=True) + def save(self, output_dir: str, **kwargs): + """Save model checkpoint. + + Args: + output_dir: Output directory. + **kwargs: Additional arguments. + """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + save_format = kwargs.pop('save_format', 'hf') # 'hf' or 'megatron' + + if save_format == 'hf': + self._save_hf_format(output_dir, adapter_name) + else: + self._save_megatron_format(output_dir, adapter_name) + + self._save_tokenizer(output_dir, adapter_name) + + def _save_hf_format(self, output_dir: str, adapter_name: str): + """Save in HuggingFace format using bridge adapter. 
        For distributed training:
        - All PP ranks participate in export (each has different layers)
        - Only DP rank 0 actually writes to disk
        - Uses barrier for synchronization

        For LoRA training:
        - Saves in PEFT format (adapter_model.safetensors + adapter_config.json)
        """
        from twinkle.megatron.model.bridge import TwinkleBridgeAdapter
        import os

        # Check if this is LoRA training (has adapter_name other than default)
        # NOTE(review): the code actually tests "non-empty", not
        # "other than default" — the default adapter name would also satisfy
        # this check if it is a non-empty string; confirm intended semantics.
        is_lora = adapter_name and adapter_name != ''
        is_peft_format = is_lora

        # Create output directory on rank 0 only
        try:
            from megatron.core import parallel_state as mpu
            dp_rank = mpu.get_data_parallel_rank() if mpu.is_initialized(
            ) else 0
        except (ImportError, AssertionError):
            dp_rank = 0

        if dp_rank == 0:
            os.makedirs(output_dir, exist_ok=True)

        # Synchronize before saving
        if dist.is_initialized():
            dist.barrier()

        # Calculate padded vocab size
        padded_vocab_size = self._pad_vocab_size(self.hf_config.vocab_size) \
            if hasattr(self, '_pad_vocab_size') else None

        # Use TwinkleBridgeAdapter for weight conversion
        # All ranks participate - bridge handles which ranks write
        adapter = TwinkleBridgeAdapter(
            hf_config=self.hf_config,
            tp_size=self.strategy.tp_size,
            pp_size=self.strategy.pp_size,
            ep_size=self.strategy.ep_size,
            model_path=self._model_path
            if hasattr(self, '_model_path') else self.model_id,
            padded_vocab_size=padded_vocab_size,
        )

        # Get the model (unwrap if DDP wrapped)
        model = self.strategy.unwrap_model(self.model)

        # Use bridge to save weights
        adapter.save_weights([model],
                             output_dir,
                             is_peft_format=is_peft_format)

        # Save config on rank 0 only
        if dp_rank == 0:
            self.hf_config.save_pretrained(output_dir)

    def _pad_vocab_size(self, vocab_size: int) -> int:
        """Pad vocab size up to the next multiple of tp_size * 128."""
        divisor = self.strategy.tp_size * 128
        return ((vocab_size + divisor - 1) // divisor) * divisor

    def _save_megatron_format(self, output_dir: str,
                              adapter_name: str):
        """Save in Megatron checkpoint format (one shard file per rank)."""
        import os
        os.makedirs(output_dir, exist_ok=True)

        model = self.strategy.unwrap_model(self.model)
        state_dict = self._get_trainable_parameters(adapter_name)

        # Convert to CPU
        cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}

        # Save with rank info for distributed checkpointing
        rank = dist.get_rank() if dist.is_initialized() else 0
        checkpoint_path = os.path.join(output_dir, f'model_rank{rank}.pt')
        torch.save(cpu_state_dict, checkpoint_path)

    def _save_tokenizer(self,
                        output_dir: str,
                        adapter_name: str = _default_adapter_name):
        """Save tokenizer (best-effort: skipped when no template is set)."""
        optimizer_config = self.optimizer_group.get(adapter_name)
        if optimizer_config and optimizer_config.template:
            optimizer_config.template.tokenizer.save_pretrained(output_dir)

    @remote_function(execute='first')
    def get_state_dict(self, **kwargs):
        """Get trainable state dict.

        Args:
            **kwargs: Additional arguments.

        Returns:
            State dict of trainable parameters.
        """
        adapter_name = kwargs.pop('adapter_name', _default_adapter_name)
        return self._get_trainable_parameters(adapter_name)

    # Class-level flag: ensures the PEFT monkey-patch below is applied once.
    _peft_patched = False

    @classmethod
    def _patch_peft_for_megatron(cls):
        """Patch PEFT's BaseTuner to handle Megatron's TransformerConfig.

        Megatron's TransformerConfig doesn't have a .get() method like HuggingFace
        configs. This patch handles the AttributeError that occurs when PEFT tries
        to check tie_word_embeddings.
+ """ + if cls._peft_patched: + return + + from typing import List + import torch.nn as nn + from peft.tuners.tuners_utils import BaseTuner + + _origin_get_tied_target_modules = BaseTuner._get_tied_target_modules + + def _get_tied_target_modules(self, model: nn.Module) -> List[str]: + try: + return _origin_get_tied_target_modules(self, model) + except AttributeError: + # Megatron's TransformerConfig doesn't have .get() method + # Check share_embeddings_and_output_weights instead + tied_target_modules = [] + if getattr(model, 'share_embeddings_and_output_weights', + False): + for target_module in self.targeted_module_names: + module_name = target_module.split('.')[-1] + if module_name in [ + 'output_layer', 'embedding', 'word_embeddings' + ]: + tied_target_modules.append(target_module) + return tied_target_modules + + BaseTuner._get_tied_target_modules = _get_tied_target_modules + cls._peft_patched = True + + @remote_function(dispatch='all', sync=True) + def add_adapter_to_model( + self, + adapter_name: str, + config_or_dir: Union[Any, str], + **kwargs, + ): + """Add LoRA adapter to model. + + Args: + adapter_name: Name of the adapter. + config_or_dir: LoRA config or path to saved adapter. + **kwargs: Additional arguments. 
+ """ + from twinkle.megatron.utils import (prepare_lora_model, patch_deepcopy, + get_target_modules, + set_linear_is_expert) + + # Patch PEFT BaseTuner to handle Megatron's TransformerConfig + # which doesn't have a .get() method like HuggingFace configs + self._patch_peft_for_megatron() + + assert adapter_name, 'Use a non-empty adapter_name' + + model = self.strategy.unwrap_model(self.model) + + # Mark expert layers for MoE models + set_linear_is_expert(model) + + if isinstance(config_or_dir, str): + # Load from path + config_or_dir = HubOperation.download_model(config_or_dir) + from peft import PeftModel + model = PeftModel.from_pretrained(model, + config_or_dir, + adapter_name=adapter_name, + is_trainable=kwargs.get( + 'is_trainable', True)) + else: + # Create from config + from peft import LoraConfig, get_peft_model + + if not isinstance(config_or_dir, LoraConfig): + # Convert dict to LoraConfig + config_or_dir = LoraConfig(**config_or_dir) + + # Expand target_modules (e.g., 'all-linear' -> actual module names) + if config_or_dir.target_modules: + if isinstance(config_or_dir.target_modules, str): + target_modules = [config_or_dir.target_modules] + else: + target_modules = list(config_or_dir.target_modules) + + expanded_modules = get_target_modules(model, target_modules) + config_or_dir.target_modules = expanded_modules + + with patch_deepcopy(): + model = get_peft_model(model, + config_or_dir, + adapter_name=adapter_name) + + # Update model reference + if self._model_wrapped: + if isinstance(self.model, MegatronDDP): + self.model.module = model + else: + self.model = model + + # Add finish_grad_sync method for Megatron's finalize_model_grads compatibility + # This is needed because Megatron's forward_backward_func calls finish_grad_sync + # on model chunks, but PEFT models don't have this method by default + if not hasattr(self.model, 'finish_grad_sync'): + + def finish_grad_sync(): + """Synchronize gradients across DP ranks for non-DDP models. 
+ + This is a compatibility shim for Megatron's finalize_model_grads. + For PEFT/LoRA models, we manually all-reduce only trainable (LoRA) gradients. + + Optimizations: + 1. Only process gradients of trainable parameters (LoRA weights) + 2. Skip if DP size is 1 (no synchronization needed) + 3. Use coalesced all-reduce for efficiency + """ + dp_world_size = mpu.get_data_parallel_world_size() + if dp_world_size <= 1: + return # No sync needed for DP=1 + + dp_cp_group = mpu.get_data_parallel_group( + with_context_parallel=True) + grads = [] + + # Only collect gradients from trainable parameters (LoRA weights) + # This is much faster than iterating all parameters + for param in self.model.parameters(): + if param.requires_grad and param.grad is not None: + grads.append(param.grad.data) + + if not grads: + return # No gradients to sync + + # Coalesced all-reduce for efficiency + from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + coalesced = _flatten_dense_tensors(grads) + dist.all_reduce(coalesced, + op=dist.ReduceOp.AVG, + group=dp_cp_group) + + # Copy back synchronized gradients + for grad, synced in zip( + grads, _unflatten_dense_tensors(coalesced, grads)): + grad.copy_(synced) + + self.model.finish_grad_sync = finish_grad_sync + + # Create optimizer group for adapter + self.optimizer_group[adapter_name] = MegatronOptimizerGroup() + self.optimizer_group[adapter_name].adapter_name = adapter_name + self.optimizer_group[adapter_name].adapter_config = config_or_dir + self.optimizer_group[ + adapter_name].gradient_accumulation_steps = kwargs.get( + 'gradient_accumulation_steps', 1) + + # Copy settings from default + default_config = self.optimizer_group.get(_default_adapter_name) + if default_config: + if default_config.template: + self.optimizer_group[ + adapter_name].template = default_config.template + if default_config.processor: + self.optimizer_group[ + adapter_name].processor = default_config.processor + if default_config.loss_instance: + 
self.optimizer_group[ + adapter_name].loss_instance = default_config.loss_instance + + @remote_function(dispatch='all') + def set_template(self, template_cls: Union[Type[template.Template], str], + **kwargs): + """Set template for input encoding. + + Args: + template_cls: Template class or string name. + **kwargs: Additional arguments. + """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + if isinstance(template_cls, str): + if hasattr(template, template_cls): + template_cls = getattr(template, template_cls) + else: + template_cls = Plugin.load_plugin(template_cls, + template.Template) + optimizer_config.template = template_cls(self.model_id, **kwargs) + + @remote_function(dispatch='all') + def set_processor(self, processor_cls: Union[Type[InputProcessor], str], + **kwargs): + """Set input processor. + + Args: + processor_cls: Processor class or string name. + **kwargs: Additional arguments. + """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + if isinstance(processor_cls, str): + if hasattr(twinkle.processor, processor_cls): + processor_cls = getattr(twinkle.processor, processor_cls) + else: + processor_cls = Plugin.load_plugin(processor_cls, + InputProcessor) + optimizer_config.processor = processor_cls( + device_mesh=self.device_mesh, **kwargs) + + @remote_function(execute='first') + def get_train_configs(self, **kwargs): + """Get training configuration summary. + + Args: + **kwargs: Additional arguments. + + Returns: + Configuration summary string. 
+ """ + adapter_name = kwargs.pop('adapter_name', _default_adapter_name) + optimizer_config = self.optimizer_group[adapter_name] + + expr = f'Backend: Megatron-Core\n' + expr += f'TP size: {self.strategy.tp_size}\n' + expr += f'PP size: {self.strategy.pp_size}\n' + expr += f'CP size: {self.strategy.cp_size}\n' + expr += f'EP size: {self.strategy.ep_size}\n' + expr += f'Sequence Parallel: {self.strategy.sequence_parallel}\n' + + if optimizer_config.adapter_config is not None: + config = optimizer_config.adapter_config.__dict__ + config = { + key: str(value) + for key, value in config.items() if value is not None + } + expr += f'Adapter config:\n{json.dumps(config, indent=2, ensure_ascii=False)}\n' + + if optimizer_config.optimizer: + expr += f'Optimizer: {optimizer_config.optimizer.__class__.__name__}\n' + expr += f'Learning rate: {optimizer_config.optimizer.defaults.get("lr", "N/A")}\n' + if optimizer_config.lr_scheduler: + expr += f'LR scheduler: {optimizer_config.lr_scheduler.__class__.__name__}\n' + expr += f'Gradient accumulation steps: {optimizer_config.gradient_accumulation_steps}\n' + + return expr + + def __repr__(self): + return (f"MegatronModel(model_id='{self.model_id}', " + f'tp={self.strategy.tp_size}, pp={self.strategy.pp_size}, ' + f'cp={self.strategy.cp_size}, ep={self.strategy.ep_size})') diff --git a/src/twinkle/model/strategy/__init__.py b/src/twinkle/model/strategy/__init__.py index 6c67d231..79af2791 100644 --- a/src/twinkle/model/strategy/__init__.py +++ b/src/twinkle/model/strategy/__init__.py @@ -1,3 +1,4 @@ # Copyright (c) ModelScope Contributors. All rights reserved. from .base import TrainStrategy from .accelerate import AccelerateStrategy +from .megatron import MegatronStrategy diff --git a/src/twinkle/model/strategy/megatron.py b/src/twinkle/model/strategy/megatron.py new file mode 100644 index 00000000..5f0ad751 --- /dev/null +++ b/src/twinkle/model/strategy/megatron.py @@ -0,0 +1,765 @@ +# Copyright (c) twinkle authors. 
All rights reserved.
"""Megatron training strategy for distributed model parallelism."""
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch.nn as nn

from .base import TrainStrategy

try:
    from twinkle import DeviceMesh
except ImportError:
    DeviceMesh = None

# Megatron-Core is optional: record availability and its version so the
# strategy can fail with a clear error only when actually used.
try:
    import megatron.core
    from megatron.core import parallel_state
    from megatron.core.distributed import DistributedDataParallel as MegatronDDP
    from packaging import version
    MEGATRON_AVAILABLE = True
    mcore_013 = version.parse(
        megatron.core.__version__) >= version.parse('0.13.0rc0')
except ImportError:
    MEGATRON_AVAILABLE = False
    mcore_013 = False


def check_megatron_available():
    """Check if Megatron-Core is available.

    Raises:
        ImportError: If megatron.core could not be imported.
    """
    if not MEGATRON_AVAILABLE:
        raise ImportError(
            'Megatron-Core is not installed. Please install it with: '
            'pip install megatron-core')


class MegatronStrategy(TrainStrategy):
    """Strategy for Megatron-Core based distributed training.

    Supports Tensor Parallel (TP), Pipeline Parallel (PP), Context Parallel (CP),
    Expert Parallel (EP), and Data Parallel (DP).

    This strategy integrates with twinkle's DeviceMesh to provide a unified
    interface for distributed training configuration.
    """
    def __init__(
        self,
        tensor_model_parallel_size: int = 1,
        pipeline_model_parallel_size: int = 1,
        context_parallel_size: int = 1,
        expert_model_parallel_size: int = 1,
        expert_tensor_parallel_size: Optional[int] = None,
        virtual_pipeline_model_parallel_size: Optional[int] = None,
        sequence_parallel: bool = False,
        use_distributed_optimizer: bool = True,
        mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16',
        params_dtype: Optional[str] = None,
        device_mesh: Optional['DeviceMesh'] = None,
        megatron_args: Optional[Dict[str, Any]] = None,
    ):
        """Initialize MegatronStrategy.

        Args:
            tensor_model_parallel_size: Degree of tensor model parallelism.
            pipeline_model_parallel_size: Degree of pipeline model parallelism.
            context_parallel_size: Degree of context parallelism.
            expert_model_parallel_size: Degree of expert model parallelism for MoE.
            expert_tensor_parallel_size: Degree of expert tensor parallelism.
            virtual_pipeline_model_parallel_size: Virtual pipeline model parallel size.
            sequence_parallel: Enable sequence parallelism.
            use_distributed_optimizer: Use Megatron's distributed optimizer.
            mixed_precision: Mixed precision mode.
            params_dtype: Parameter dtype string (e.g., 'bf16', 'fp32').
            device_mesh: Twinkle DeviceMesh for distributed configuration.
            megatron_args: Additional Megatron arguments.
        """
        check_megatron_available()

        # If device_mesh is provided, extract parallel sizes from it
        # (mesh dimensions override the explicit keyword defaults).
        if device_mesh is not None:
            tensor_model_parallel_size = self._get_dim_from_mesh(
                device_mesh, 'tp', tensor_model_parallel_size)
            pipeline_model_parallel_size = self._get_dim_from_mesh(
                device_mesh, 'pp', pipeline_model_parallel_size)
            context_parallel_size = self._get_dim_from_mesh(
                device_mesh, 'cp', context_parallel_size)
            expert_model_parallel_size = self._get_dim_from_mesh(
                device_mesh, 'ep', expert_model_parallel_size)

        self.tp_size = tensor_model_parallel_size
        self.pp_size = pipeline_model_parallel_size
        self.cp_size = context_parallel_size
        self.ep_size = expert_model_parallel_size
        # Expert TP defaults to the regular TP degree when not given.
        self.etp_size = expert_tensor_parallel_size or tensor_model_parallel_size
        self.vp_size = virtual_pipeline_model_parallel_size
        self.sequence_parallel = sequence_parallel
        self.use_distributed_optimizer = use_distributed_optimizer
        self.mixed_precision = mixed_precision
        self.params_dtype = params_dtype
        self.device_mesh = device_mesh
        # NOTE(review): megatron_args is stored but not consumed anywhere in
        # this file — presumably read by callers; verify.
        self.megatron_args = megatron_args or {}

        self._initialized = False
        self._parallel_state = None

    @staticmethod
    def _get_dim_from_mesh(device_mesh: 'DeviceMesh', dim_name: str,
                           default: int) -> int:
        """Get dimension size from device mesh.

        Args:
            device_mesh: The device mesh.
            dim_name: Name of the dimension.
            default: Default value if dimension not found.

        Returns:
            Dimension size.
        """
        if device_mesh is None:
            return default
        if hasattr(device_mesh, 'has_dim') and device_mesh.has_dim(dim_name):
            return device_mesh.get_dim_size(dim_name)
        return default

    @classmethod
    def from_device_mesh(
        cls,
        device_mesh: 'DeviceMesh',
        sequence_parallel: bool = False,
        use_distributed_optimizer: bool = True,
        mixed_precision: Literal['no', 'fp16', 'bf16'] = 'bf16',
        **kwargs,
    ) -> 'MegatronStrategy':
        """Create MegatronStrategy from twinkle DeviceMesh.

        Args:
            device_mesh: Twinkle DeviceMesh with dimension names like 'tp', 'pp', 'cp', 'ep', 'dp'.
            sequence_parallel: Enable sequence parallelism.
            use_distributed_optimizer: Use Megatron's distributed optimizer.
            mixed_precision: Mixed precision mode.
            **kwargs: Additional arguments.

        Returns:
            MegatronStrategy instance.
        """
        return cls(
            device_mesh=device_mesh,
            sequence_parallel=sequence_parallel,
            use_distributed_optimizer=use_distributed_optimizer,
            mixed_precision=mixed_precision,
            **kwargs,
        )

    def initialize(self, **kwargs) -> None:
        """Initialize Megatron parallel state.

        This method handles both local (torchrun) and Ray modes:

        **Local mode**:
        - torch.distributed is already initialized by torchrun
        - Just initialize mpu.initialize_model_parallel()

        **Ray mode**:
        - Read RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT from environment
        - Initialize torch.distributed with these values
        - Then initialize mpu.initialize_model_parallel()

        This allows the same MegatronModel code to work in both modes.
        """
        if self._initialized:
            return

        import os
        from datetime import timedelta

        # Determine execution mode
        twinkle_mode = os.environ.get('TWINKLE_MODE', 'local')

        # Initialize torch.distributed if not already done
        if not dist.is_initialized():
            if twinkle_mode == 'ray':
                # Ray mode: use environment variables set by RayHelper
                rank = int(os.environ.get('RANK', '0'))
                world_size = int(os.environ.get('WORLD_SIZE', '1'))
                master_addr = os.environ.get('MASTER_ADDR', 'localhost')
                master_port = os.environ.get('MASTER_PORT', '29500')
                local_rank = int(os.environ.get('LOCAL_RANK', '0'))

                # Set CUDA device before init_process_group
                torch.cuda.set_device(local_rank)

                # Initialize process group
                dist.init_process_group(
                    backend='nccl',
                    init_method=f'tcp://{master_addr}:{master_port}',
                    rank=rank,
                    world_size=world_size,
                    timeout=timedelta(minutes=10),
                )
            else:
                # Local mode: torchrun should have set up distributed
                # If not, initialize with default settings
                dist.init_process_group(backend='nccl')

        world_size = dist.get_world_size()

        # Validate parallel configuration
        total_model_parallel = self.tp_size * self.pp_size * self.cp_size
        if world_size % total_model_parallel != 0:
            raise ValueError(
                f'World size ({world_size}) must be divisible by '
                f'tp_size * pp_size * cp_size ({total_model_parallel})')

        # Initialize Megatron parallel state
        init_kwargs = {
            'tensor_model_parallel_size': self.tp_size,
            'pipeline_model_parallel_size': self.pp_size,
            'context_parallel_size': self.cp_size,
        }

        if self.vp_size is not None:
            init_kwargs['virtual_pipeline_model_parallel_size'] = self.vp_size

        # Handle MoE parallelism (expert TP kwarg only exists on mcore >= 0.13)
        if self.ep_size > 1:
            init_kwargs['expert_model_parallel_size'] = self.ep_size
            if mcore_013:
                init_kwargs['expert_tensor_parallel_size'] = self.etp_size

        parallel_state.initialize_model_parallel(**init_kwargs)

        self._parallel_state = parallel_state
        self._initialized = True

        # Set CUDA device (may be redundant in Ray mode, but safe)
        # NOTE(review): rank % device_count assumes one process per visible
        # GPU on every node — confirm for multi-node layouts.
        local_rank = dist.get_rank() % torch.cuda.device_count()
        torch.cuda.set_device(local_rank)

    def destroy(self) -> None:
        """Destroy parallel state and clean up resources."""
        if self._initialized and self._parallel_state is not None:
            self._parallel_state.destroy_model_parallel()
            self._initialized = False

    @property
    def tp_rank(self) -> int:
        """Get tensor parallel rank (0 before initialization)."""
        if not self._initialized:
            return 0
        return self._parallel_state.get_tensor_model_parallel_rank()

    @property
    def pp_rank(self) -> int:
        """Get pipeline parallel rank (0 before initialization)."""
        if not self._initialized:
            return 0
        return self._parallel_state.get_pipeline_model_parallel_rank()

    @property
    def dp_rank(self) -> int:
        """Get data parallel rank (0 before initialization)."""
        if not self._initialized:
            return 0
        return self._parallel_state.get_data_parallel_rank()

    @property
    def cp_rank(self) -> int:
        """Get context parallel rank (0 before initialization)."""
        if not self._initialized:
            return 0
        return self._parallel_state.get_context_parallel_rank()

    @property
    def ep_rank(self) -> int:
        """Get expert parallel rank (0 before initialization)."""
        if not self._initialized:
            return 0
        return self._parallel_state.get_expert_model_parallel_rank()

    @property
    def dp_size(self) -> int:
        """Get data parallel size.

        Before initialization this is derived from the world size divided by
        the model-parallel degrees; afterwards it comes from parallel_state.
        """
        if not self._initialized:
            world_size = dist.get_world_size() if dist.is_initialized() else 1
            return world_size // (self.tp_size * self.pp_size * self.cp_size)
        return self._parallel_state.get_data_parallel_world_size()

    @property
    def tp_group(self):
        """Get tensor parallel process group (None before initialization)."""
        if not self._initialized:
            return None
        return self._parallel_state.get_tensor_model_parallel_group()

    @property
    def dp_group(self):
        """Get data parallel process group (None before initialization)."""
        if not self._initialized:
            return None
        return self._parallel_state.get_data_parallel_group()

    @property
    def pp_group(self):
        """Get pipeline parallel process group (None before initialization)."""
        if not self._initialized:
            return None
        return self._parallel_state.get_pipeline_model_parallel_group()

    @property
    def cp_group(self):
        """Get context parallel process group (None before initialization)."""
        if not self._initialized:
            return None
        return self._parallel_state.get_context_parallel_group()

    @property
    def ep_group(self):
        """Get expert parallel process group (None before initialization)."""
        if not self._initialized:
            return None
        return self._parallel_state.get_expert_model_parallel_group()

    def is_pipeline_first_stage(self) -> bool:
        """Check if current rank is pipeline first stage (True before init)."""
        if not self._initialized:
            return True
        return self._parallel_state.is_pipeline_first_stage()

    def is_pipeline_last_stage(self) -> bool:
        """Check if current rank is pipeline last stage (True before init)."""
        if not self._initialized:
            return True
        return self._parallel_state.is_pipeline_last_stage()

    def is_data_parallel_main_rank(self) -> bool:
        """Check if current rank is the main rank in data parallel group."""
        if not self._initialized:
            return True
        return self.dp_rank == 0

    def get_params_dtype(self) -> torch.dtype:
        """Get parameter dtype based on configuration.

        An explicit `params_dtype` string wins over `mixed_precision`;
        unknown strings fall back to bfloat16.

        Returns:
            PyTorch dtype for model parameters.
        """
        if self.params_dtype is not None:
            dtype_map = {
                'fp32': torch.float32,
                'fp16': torch.float16,
                'bf16': torch.bfloat16,
            }
            return dtype_map.get(self.params_dtype, torch.bfloat16)

        if self.mixed_precision == 'bf16':
            return torch.bfloat16
        elif self.mixed_precision == 'fp16':
            return torch.float16
        return torch.float32

    def _get_transformer_config(self, model: nn.Module):
        """Get TransformerConfig from model, handling PEFT wrappers.

        A config is recognized by the presence of a
        `tensor_model_parallel_size` attribute.

        Args:
            model: The model (may be wrapped with PEFT).

        Returns:
            TransformerConfig if found, None otherwise.
        """
        # Direct config attribute
        config = getattr(model, 'config', None)
        if config is not None and hasattr(config,
                                          'tensor_model_parallel_size'):
            return config

        # PEFT model: model.base_model.model.config
        if hasattr(model, 'base_model'):
            base = model.base_model
            if hasattr(base, 'model'):
                config = getattr(base.model, 'config', None)
                if config is not None and hasattr(
                        config, 'tensor_model_parallel_size'):
                    return config
            # Try base.config
            config = getattr(base, 'config', None)
            if config is not None and hasattr(config,
                                              'tensor_model_parallel_size'):
                return config

        # Wrapped model: model.model.config
        if hasattr(model, 'model'):
            config = getattr(model.model, 'config', None)
            if config is not None and hasattr(config,
                                              'tensor_model_parallel_size'):
                return config

        # Recursive search through modules (last resort; `name` is unused)
        for name, module in model.named_modules():
            config = getattr(module, 'config', None)
            if config is not None and hasattr(config,
                                              'tensor_model_parallel_size'):
                return config

        return None

    def wrap_model(
        self,
        model: nn.Module,
        optimizer: Optional[torch.optim.Optimizer] = None,
        use_distributed_optimizer: bool = True,
    ) -> Tuple[nn.Module, Optional[torch.optim.Optimizer]]:
        """Wrap model with Megatron DDP for data parallelism.

        This method behaves differently based on twinkle's execution mode:

        **Local mode (torchrun)**:
        - Uses Megatron native DDP wrapping
        - All processes are synchronized by torchrun, so collective ops work

        **Ray mode**:
        - Currently skips DDP wrapping to avoid deadlocks
        - Ray's asynchronous actor model makes collective synchronization hard
        - Each DP replica trains independently

        **Transformers/Accelerate comparison**:
        - Accelerate's `prepare()` works in Ray because it's a local operation
        - Megatron DDP's `broadcast_params()` is a collective that needs sync

        Args:
            model: The Megatron model (already has TP/PP via TransformerConfig).
            optimizer: Optional optimizer.
            use_distributed_optimizer: Whether to use distributed optimizer.

        Returns:
            Tuple of (wrapped_model, optimizer).
        """
        if not self._initialized:
            self.initialize()

        # Determine execution mode
        import os
        twinkle_mode = os.environ.get('TWINKLE_MODE', 'local')

        # Check DP world size
        dp_group = self.dp_group
        dp_world_size = 1
        if dp_group is not None:
            dp_world_size = dist.get_world_size(dp_group)

        if dp_world_size <= 1:
            # No DP needed (single GPU or TP-only)
            return model, optimizer

        if twinkle_mode == 'ray':
            # In Ray mode, skip DDP for now due to collective sync issues
            # TODO: Implement Ray-compatible DDP with barrier synchronization
            import warnings
            warnings.warn(
                'Skipping Megatron DDP in Ray mode. Each DP replica trains independently. '
                'For synchronized training, use torchrun (TWINKLE_MODE=local).'
            )
            return model, optimizer

        # Local mode (torchrun): Use Megatron native DDP
        return self._wrap_with_megatron_ddp(model, optimizer,
                                            use_distributed_optimizer)

    def _wrap_with_megatron_ddp(
        self,
        model: nn.Module,
        optimizer: Optional[torch.optim.Optimizer],
        use_distributed_optimizer: bool,
    ) -> Tuple[nn.Module, Optional[torch.optim.Optimizer]]:
        """
        Wrap model with Megatron native DDP (for torchrun mode).

        Falls back to returning the unwrapped model (with a warning) when no
        TransformerConfig can be found or DDP construction fails.
        """
        from megatron.core.distributed import DistributedDataParallelConfig
        from megatron.core.transformer.module import Float16Module

        # Get TransformerConfig from model
        config = self._get_transformer_config(model)
        if config is None:
            import warnings
            warnings.warn(
                'Could not find TransformerConfig. Skipping DDP wrapping. '
                'Gradient sync will need to be done manually.')
            return model, optimizer

        # Ensure model is on GPU
        try:
            model_device = next(model.parameters()).device
            if model_device.type == 'cpu':
                local_rank = dist.get_rank() % torch.cuda.device_count()
                model = model.to(f'cuda:{local_rank}')
        except StopIteration:
            pass  # No parameters

        # Wrap with Float16Module for mixed precision (like Megatron's get_model)
        if (config.fp16
                or config.bf16) and not isinstance(model, Float16Module):
            # Check if the inner model (for PEFT) needs wrapping
            inner_model = model
            if hasattr(model, 'base_model') and hasattr(
                    model.base_model, 'model'):
                inner_model = model.base_model.model

            # Only wrap if not already wrapped
            if not isinstance(inner_model, Float16Module):
                # For PEFT models, we can't easily wrap the inner model
                # Just proceed without Float16Module
                if not hasattr(model, 'base_model'):
                    model = Float16Module(config, model)

        # Create DDP config
        ddp_config = DistributedDataParallelConfig(
            grad_reduce_in_fp32=True,
            overlap_grad_reduce=False,
            use_distributed_optimizer=use_distributed_optimizer,
        )

        # Wrap with MegatronDDP
        # TODO: multi-tenant ddp
        try:
            wrapped_model = MegatronDDP(
                config=config,
                ddp_config=ddp_config,
                module=model,
            )

            # Broadcast params from data parallel src rank
            # In torchrun mode, all ranks enter here simultaneously, so this works
            wrapped_model.broadcast_params()

            return wrapped_model, optimizer

        except Exception as e:
            import warnings
            warnings.warn(
                f'Failed to wrap with Megatron DDP: {e}. Using unwrapped model.'
            )
            return model, optimizer

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        """Unwrap the distributed model to get the base model.

        Handles one level of MegatronDDP or torch DDP wrapping.

        Args:
            model: The wrapped model.

        Returns:
            The unwrapped base model.
        """
        if isinstance(model, MegatronDDP):
            return model.module

        from torch.nn.parallel import DistributedDataParallel as TorchDDP
        if isinstance(model, TorchDDP):
            return model.module

        return model

    def get_model_config(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_layers: int,
        ffn_hidden_size: Optional[int] = None,
        num_query_groups: Optional[int] = None,
        vocab_size: int = 32000,
        max_position_embeddings: int = 4096,
        num_experts: Optional[int] = None,
        moe_router_topk: int = 2,
        **kwargs,
    ):
        """Create a Megatron TransformerConfig.

        Args:
            hidden_size: Hidden dimension size.
            num_attention_heads: Number of attention heads.
            num_layers: Number of transformer layers.
            ffn_hidden_size: FFN hidden size (default: 4 * hidden_size).
            num_query_groups: Number of KV heads for GQA.
            vocab_size: Vocabulary size.
            max_position_embeddings: Maximum sequence length.
            num_experts: Number of MoE experts.
            moe_router_topk: Top-k for MoE routing.
            **kwargs: Additional config arguments.

        Returns:
            Megatron TransformerConfig.
        """
        from megatron.core.transformer import TransformerConfig

        # NOTE(review): vocab_size and max_position_embeddings are accepted
        # but not forwarded to TransformerConfig — confirm intended.
        config = TransformerConfig(
            num_layers=num_layers,
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_query_groups=num_query_groups or num_attention_heads,
            ffn_hidden_size=ffn_hidden_size or 4 * hidden_size,
            use_cpu_initialization=True,
            params_dtype=self.get_params_dtype(),
            tensor_model_parallel_size=self.tp_size,
            pipeline_model_parallel_size=self.pp_size,
            context_parallel_size=self.cp_size,
            expert_model_parallel_size=self.ep_size,
            sequence_parallel=self.sequence_parallel,
            num_moe_experts=num_experts,
            moe_router_topk=moe_router_topk,
            **kwargs,
        )

        return config

    def sync_gradients(self, model: Optional[nn.Module] = None) -> None:
        """Synchronize gradients across data parallel group.

        For DDP-wrapped models, gradients are synchronized automatically.
        For non-DDP models (e.g., PEFT models), this performs manual all-reduce.

        Args:
            model: Optional model to sync gradients for. If None, only barrier.
        """
        if not self._initialized:
            return

        dp_group = self.dp_group
        if dp_group is None:
            return

        dp_size = dist.get_world_size(dp_group)
        if dp_size <= 1:
            return

        if model is not None:
            # Manual gradient synchronization for non-DDP models (e.g., PEFT)
            self.all_reduce_gradients(model)
        else:
            # Just barrier for DDP models
            dist.barrier(dp_group)

    def all_reduce_gradients(self, model: nn.Module) -> None:
        """All-reduce gradients of trainable parameters across data parallel group.

        This is used for PEFT/LoRA models that are not wrapped with DDP.
        Gradients are averaged across all DP ranks.

        Args:
            model: The model whose gradients to synchronize.
        """
        if not self._initialized:
            return

        dp_group = self.dp_group
        if dp_group is None:
            return

        dp_size = dist.get_world_size(dp_group)
        if dp_size <= 1:
            return

        # Collect gradients from trainable parameters
        grads = []
        for param in model.parameters():
            if param.requires_grad and param.grad is not None:
                grads.append(param.grad.data)

        if not grads:
            return

        # Flatten all gradients into a single tensor for efficient communication
        # This reduces the number of all-reduce operations
        flat_grads = torch.cat([g.contiguous().view(-1) for g in grads])

        # All-reduce and average
        dist.all_reduce(flat_grads, op=dist.ReduceOp.SUM, group=dp_group)
        flat_grads.div_(dp_size)

        # Unflatten back to original gradient tensors
        offset = 0
        for grad in grads:
            numel = grad.numel()
            grad.copy_(flat_grads[offset:offset + numel].view_as(grad))
            offset += numel

    def all_reduce(
        self,
        tensor: torch.Tensor,
        op: dist.ReduceOp = dist.ReduceOp.SUM,
        group: Optional[dist.ProcessGroup] = None,
    ) -> torch.Tensor:
        """All-reduce tensor across specified group (in place).

        Args:
            tensor: Input tensor.
            op: Reduce operation.
            group: Process group (defaults to data parallel group).

        Returns:
            Reduced tensor (same object; no-op before initialization).
        """
        if not self._initialized:
            return tensor

        if group is None:
            group = self.dp_group

        if group is not None:
            dist.all_reduce(tensor, op=op, group=group)

        return tensor

    def broadcast(
        self,
        tensor: torch.Tensor,
        src: int = 0,
        group: Optional[dist.ProcessGroup] = None,
    ) -> torch.Tensor:
        """Broadcast tensor from source rank (in place).

        Args:
            tensor: Input tensor.
            src: Source rank. NOTE(review): torch.distributed.broadcast
                interprets `src` as a *global* rank even when `group` is a
                subgroup — confirm callers pass a rank that is a member of
                the group.
            group: Process group (defaults to data parallel group).

        Returns:
            Broadcasted tensor (same object; no-op before initialization).
        """
        if not self._initialized:
            return tensor

        if group is None:
            group = self.dp_group

        if group is not None:
            dist.broadcast(tensor, src=src, group=group)

        return tensor

    def get_parallel_info(self) -> Dict[str, Any]:
        """Get parallelism configuration information.

        Returns:
            Dict with parallel configuration details.
        """
        return {
            'tp_size': self.tp_size,
            'pp_size': self.pp_size,
            'cp_size': self.cp_size,
            'ep_size': self.ep_size,
            'etp_size': self.etp_size,
            'vp_size': self.vp_size,
            'dp_size': self.dp_size,
            'sequence_parallel': self.sequence_parallel,
            'use_distributed_optimizer': self.use_distributed_optimizer,
            'mixed_precision': self.mixed_precision,
            'tp_rank': self.tp_rank,
            'pp_rank': self.pp_rank,
            'dp_rank': self.dp_rank,
            'cp_rank': self.cp_rank,
            'ep_rank': self.ep_rank,
        }

    def __repr__(self) -> str:
        return (f'MegatronStrategy(tp={self.tp_size}, pp={self.pp_size}, '
                f'cp={self.cp_size}, ep={self.ep_size}, dp={self.dp_size}, '
                f'sequence_parallel={self.sequence_parallel})')
diff --git a/src/twinkle/utils/framework.py b/src/twinkle/utils/framework.py
index 0c4f8bcf..1ade5acc 100644
--- a/src/twinkle/utils/framework.py
+++ b/src/twinkle/utils/framework.py
@@ -195,6 +195,7 @@ def to_local_tensor(tensor: 'torch.Tensor') -> 'torch.Tensor':
    Returns:
        A local torch.Tensor.
""" + import torch if hasattr(tensor, 'full_tensor'): # DTensor from torch.distributed.tensor return tensor.full_tensor() diff --git a/test_ray_configs.py b/test_ray_configs.py new file mode 100644 index 00000000..5bf6be09 --- /dev/null +++ b/test_ray_configs.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python +"""Test script for Ray mode with various parallelism configurations. + +Records loss, memory usage, and training time. +""" +import os +import sys +import time +import subprocess +import re + +# Test configurations: (tp_size, pp_size, num_gpus, name) +CONFIGS = [ + (2, 2, 4, "TP=2_PP=2"), + (4, 1, 4, "TP=4_PP=1"), + (1, 4, 4, "TP=1_PP=4"), + (2, 1, 2, "TP=2_PP=1"), +] + +MODEL = "ms://Qwen/Qwen2.5-0.5B-Instruct" +MAX_STEPS = 5 +TIMEOUT = 600 # 10 minutes per test + +def run_test(mode, tp_size, pp_size, num_gpus, name): + """Run a single test configuration.""" + env = os.environ.copy() + env["MEGATRON_LM_PATH"] = "/mnt/nas2/hujinghan.hjh/Megatron-LM" + env["PYTHONPATH"] = "/mnt/nas2/hujinghan.hjh/Megatron-LM:/mnt/nas2/hujinghan.hjh/twinkle/src:" + env.get("PYTHONPATH", "") + env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(num_gpus)) + env["TRUST_REMOTE_CODE"] = "1" + + log_file = f"/mnt/nas2/hujinghan.hjh/twinkle/test_{mode}_{name}.log" + + if mode == "ray": + cmd = [ + "/mnt/nas2/anaconda3/envs/hjh/bin/python", + "cookbook/megatron/lora.py", + "--mode", "ray", + "--tp_size", str(tp_size), + "--pp_size", str(pp_size), + "--num_gpus", str(num_gpus), + "--model", MODEL, + "--max_steps", str(MAX_STEPS), + ] + else: + # Find an available port + import socket + with socket.socket() as s: + s.bind(("", 0)) + port = s.getsockname()[1] + + cmd = [ + "/mnt/nas2/anaconda3/envs/hjh/bin/python", "-m", "torch.distributed.run", + "--nproc_per_node", str(num_gpus), + "--master_port", str(port), + "cookbook/megatron/lora.py", + "--tp_size", str(tp_size), + "--pp_size", str(pp_size), + "--model", MODEL, + "--max_steps", str(MAX_STEPS), + ] + + print(f"\n{'='*60}") + 
print(f"Running: {mode} mode, {name}") + print(f"Command: {' '.join(cmd)}") + print(f"Log: {log_file}") + print(f"{'='*60}") + + start_time = time.time() + + with open(log_file, "w") as f: + try: + result = subprocess.run( + cmd, + cwd="/mnt/nas2/hujinghan.hjh/twinkle", + env=env, + stdout=f, + stderr=subprocess.STDOUT, + timeout=TIMEOUT, + ) + success = result.returncode == 0 + except subprocess.TimeoutExpired: + print(f" TIMEOUT after {TIMEOUT}s") + success = False + except Exception as e: + print(f" ERROR: {e}") + success = False + + elapsed = time.time() - start_time + + # Parse results from log + losses = [] + memory = None + + with open(log_file, "r") as f: + content = f.read() + + # Extract losses + for match in re.finditer(r"Step (\d+), loss: ([\d.]+)", content): + step = int(match.group(1)) + loss = float(match.group(2)) + losses.append((step, loss)) + + # Check for completion + completed = "Training completed!" in content + + return { + "mode": mode, + "config": name, + "tp": tp_size, + "pp": pp_size, + "gpus": num_gpus, + "losses": losses, + "elapsed": elapsed, + "success": success and completed, + "log_file": log_file, + } + + +def cleanup(): + """Kill any lingering processes.""" + os.system("pkill -9 -f 'lora.py|MegatronModel|ray' 2>/dev/null") + time.sleep(5) + + +def main(): + results = [] + + for tp, pp, gpus, name in CONFIGS: + # Test Ray mode + cleanup() + ray_result = run_test("ray", tp, pp, gpus, name) + results.append(ray_result) + + # Test Local mode + cleanup() + local_result = run_test("local", tp, pp, gpus, name) + results.append(local_result) + + cleanup() + + # Print summary + print("\n" + "="*80) + print("SUMMARY") + print("="*80) + print(f"{'Mode':<8} {'Config':<15} {'GPUs':<6} {'Status':<10} {'Time(s)':<10} {'Step0 Loss':<12} {'Step5 Loss':<12}") + print("-"*80) + + for r in results: + status = "✅ OK" if r["success"] else "❌ FAIL" + step0_loss = r["losses"][0][1] if len(r["losses"]) > 0 else "N/A" + step5_loss = r["losses"][-1][1] if 
len(r["losses"]) > 5 else "N/A" + if isinstance(step0_loss, float): + step0_loss = f"{step0_loss:.4f}" + if isinstance(step5_loss, float): + step5_loss = f"{step5_loss:.4f}" + print(f"{r['mode']:<8} {r['config']:<15} {r['gpus']:<6} {status:<10} {r['elapsed']:<10.1f} {step0_loss:<12} {step5_loss:<12}") + + print("="*80) + + # Save results to file + with open("/mnt/nas2/hujinghan.hjh/twinkle/test_results.txt", "w") as f: + f.write("Ray Mode Parallelism Test Results\n") + f.write("="*80 + "\n\n") + for r in results: + f.write(f"Mode: {r['mode']}, Config: {r['config']}, GPUs: {r['gpus']}\n") + f.write(f"Success: {r['success']}, Time: {r['elapsed']:.1f}s\n") + f.write(f"Losses: {r['losses']}\n") + f.write(f"Log: {r['log_file']}\n") + f.write("-"*40 + "\n") + + +if __name__ == "__main__": + main() diff --git a/twinkle b/twinkle deleted file mode 120000 index 5a08ecd1..00000000 --- a/twinkle +++ /dev/null @@ -1 +0,0 @@ -src/twinkle \ No newline at end of file